From 50443f1fbf66475d32972cd505ba4bdc70ebaf6f Mon Sep 17 00:00:00 2001 From: cbroz1 Date: Thu, 25 Jan 2024 16:38:02 -0800 Subject: [PATCH 01/10] WIP: fix for #791 --- src/spyglass/utils/database_settings.py | 20 ++--- src/spyglass/utils/dj_merge_tables.py | 97 +++++++++++++++--------- src/spyglass/utils/dj_mixin.py | 99 +++++++++++++++++++++++++ 3 files changed, 170 insertions(+), 46 deletions(-) diff --git a/src/spyglass/utils/database_settings.py b/src/spyglass/utils/database_settings.py index 5a634c69c..da65914fa 100755 --- a/src/spyglass/utils/database_settings.py +++ b/src/spyglass/utils/database_settings.py @@ -14,6 +14,16 @@ CREATE_USR = "CREATE USER IF NOT EXISTS " TEMP_PASS = " IDENTIFIED BY 'temppass';" ESC = r"\_%" +SHARED_MODULES = [ + "common", + "spikesorting", + "decoding", + "position", + "position_linearization", + "ripple", + "lfp", + "waveform", +] class DatabaseSettings: @@ -40,15 +50,7 @@ def __init__( target_database : str, optional Default is mysql. Can also be docker container id """ - self.shared_modules = [ - f"common{ESC}", - f"spikesorting{ESC}", - f"decoding{ESC}", - f"position{ESC}", - f"position_linearization{ESC}", - f"ripple{ESC}", - f"lfp{ESC}", - ] + self.shared_modules = [f"{m}{ESC}" for m in SHARED_MODULES] self.user = user_name or dj.config["database.user"] self.host = ( host_name or dj.config["database.host"] or "lmf-db.cin.ucsf.edu" diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index c37122c70..63f8fb948 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -52,8 +52,27 @@ def __init__(self): + f"\n\tExpected: {self.primary_key}" + f"\n\tActual : {part.primary_key}" ) + self._analysis_nwbfile = None self._source_class_dict = {} + @property + def source_class_dict(self) -> dict: + if not self._source_class_dict: + module = getmodule(self) + self._source_class_dict = { + part_name: getattr(module, part_name) + for part_name in self.parts(camel_case=True) + } + return self._source_class_dict + + @property # CB: This is a property to avoid circular import + def analysis_nwbfile(self): + if self._analysis_nwbfile is None: + from spyglass.common import AnalysisNwbfile # noqa F401 + + self._analysis_nwbfile = AnalysisNwbfile + return self._analysis_nwbfile + def _remove_comments(self, definition): """Use regular expressions to remove comments and blank lines""" return re.sub( # First remove comments, then blank lines @@ -822,23 +841,17 @@ def delete_downstream_merge( restriction = True descendants = _unique_descendants(table, recurse_level) - merge_table_pairs = _master_table_pairs( + merge_table_dicts = _search_descendants( table_list=descendants, - restricted_parent=(table & restriction), + parent_table=table, + restriction=restriction, ) - # restrict the merge table based on uuids in part - # don't need part for del, but show on dry_run - merge_pairs = [ - (merge & part.fetch(RESERVED_PRIMARY_KEY, as_dict=True), part) - for merge, part in merge_table_pairs - ] - if dry_run: - return merge_pairs + return [d["view"] for d in merge_table_dicts] - for merge_table, _ in merge_pairs: - merge_table.delete(**kwargs) + for merge_dict in merge_table_dicts: + (merge_dict["master"] & merge_dict["keys"]).delete(**kwargs) def _warn_on_restriction(table: dj.Table, restriction: str = None): @@ -876,12 +889,9 @@ def _unique_descendants( """ if recurse_level == 0: - return [] + return {} - if attribute is None: - skip_attr_check = True - else: - skip_attr_check = False + 
skip_attr_check = True if attribute is None else False descendants = {} @@ -901,13 +911,14 @@ def recurse_descendants(sub_table, level): ) -def _master_table_pairs( +def _search_descendants( table_list: list, - restricted_parent: dj.expression.QueryExpression = True, + parent_table: dj.Table, + restriction: str = True, connection: dj.connection.Connection = None, ) -> list: """ - Given list of tables, return a list of master table pairs. + Given list of descendant tables, find merge tables linked to parent table. Returns a list of tuples, with master and part. Part will have restriction applied. If restriction yield empty list, skip. @@ -916,20 +927,23 @@ def _master_table_pairs( ---------- table_list : List[dj.Table] A list of datajoint tables. - restricted_parent : dj.expression.QueryExpression - Parent table restricted, to be joined with master and part. Default - True, no restriction. + parent_table : dj.Table + The parent table of the tables in table_list. + restiction : str, optional + Restriction applies to parent table. Default True, no restriction. connection : datajoint.connection.Connection A database connection. Default None, use connection from first table. Returns ------- - List[Tuple[dj.Table, dj.Table]] - A list of master table pairs. + List[dict] + A list of dictionaries, each with keys: master, view, and keys. + These are the master table, the restricted view of the master table, + and the primary keys of the restricted view. """ conn = connection or table_list[0].connection - master_table_pairs = [] + ret = [] unique_parts = [] # Adapted from Spyglass PR 535 @@ -939,24 +953,33 @@ def _master_table_pairs( continue master_name = get_master(table_name) - if not master_name: # then it's not a part table + if not master_name: # then not a part table continue master = dj.FreeTable(conn, master_name) - if RESERVED_PRIMARY_KEY not in master.heading.attributes.keys(): - continue # then it's not a merge table + # if table has no master, or master is not MergeTable, skip + if RESERVED_PRIMARY_KEY not in master.heading.names: + continue + + try: + restricted_master = ( + master & restriction + ) # .merge_restrict(restriction) + except DataJointError: + continue - restricted_join = restricted_parent * table - if not restricted_join: # No entries relevant to restriction in part + if not restricted_master: # No entries relevant to restriction in part continue unique_parts.append(table_name) - master_table_pairs.append( - ( - master, - table - & restricted_join.fetch(RESERVED_PRIMARY_KEY, as_dict=True), + ret.append( + dict( + master=master, + view=restricted_master, + keys=restricted_master.fetch( + RESERVED_PRIMARY_KEY, as_dict=True + ), ) ) - return master_table_pairs + return ret diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 490274fe0..a813ae4e3 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -2,6 +2,7 @@ from datajoint.table import logger as dj_logger from datajoint.utils import user_choice +from spyglass.utils.database_settings import SHARED_MODULES from spyglass.utils.dj_helper_fn import fetch_nwb from spyglass.utils.logging import logger @@ -318,3 +319,101 @@ def cautious_delete(self, force_permission: bool = False, *args, **kwargs): def cdel(self, *args, **kwargs): """Alias for cautious_delete.""" self.cautious_delete(*args, **kwargs) + + +def get_instanced(*tables): + ret = [] + if not isinstance(tables, tuple): + tables = [tables] + for table in tables: + if not isinstance(table, dj.user_tables.Table): + 
ret.append(table()) + else: + ret.append(table) + return ret[0] if len(ret) == 1 else ret + + +def find_tables_connecting( + parent_table: dj.user_tables.TableMeta, + child_table: dj.user_tables.TableMeta, + recurse_level: int = 4, + visited: set = None, +) -> list: + """ + Return list of tables connecting the parent and child for a valid join. + + Parameters + ---------- + parent_table : dj.user_tables.TableMeta + DataJoint table upstream in pipeline. + child_table : dj.user_tables.TableMeta + DataJoint table downstream in pipeline. + recurse_level : int, optional + Maximum number of recursion levels. Default is 4. + visited : set, optional + Set of visited tables (used internally for recursion). + + Returns + ------- + list + List of paths, with each path as a list FreeTables connecting the parent + and child for a valid join. + """ + parent_table, child_table = get_instanced(parent_table, child_table) + visited = visited or set() + child_is_merge = isinstance(child_table, Merge) + + if recurse_level < 1 or ( # if too much recursion + not child_is_merge # merge table ok + and ( # already visited, outside spyglass, or no connection + child_table.full_table_name in visited + or child_table.full_table_name.strip("`").split("_")[0] + not in SHARED_MODULES + or child_table.full_table_name not in parent_table.descendants() + ) + ): + return [] + + if child_table.full_table_name in parent_table.children(): + return [parent_table, child_table] + + if child_is_merge: + _ = child_table._ensure_dependencies_loaded() + ret = [] + for part in child_table.parts(as_objects=True): + connecting_path = find_tables_connecting( + parent_table, + part, + recurse_level=recurse_level, + visited=visited, + ) + visited.add(part.full_table_name) + if connecting_path: + ret.append(connecting_path + [child_table]) + + return ret + + for child in parent_table.children(as_objects=True): + connecting_path = find_tables_connecting( + child, + child_table, + recurse_level=recurse_level - 1, + visited=visited, + ) + visited.add(child.full_table_name) + if connecting_path: + return [parent_table] + connecting_path + + return [] + + +def join_all(connecting_tables, restriction=True): + if not isinstance(connecting_tables, list): + connecting_tables = [connecting_tables] + ret = [] + for table_list in connecting_tables: + join = table_list[0] & restriction + for table in table_list[1:]: + join = join * table + ret.append(join) + return ret[0] if len(ret) == 1 else ret From 4bf8d79fad5f32e4957ec131f94ccc2e04fefa59 Mon Sep 17 00:00:00 2001 From: cbroz1 Date: Fri, 26 Jan 2024 14:58:07 -0800 Subject: [PATCH 02/10] WIP: #791, pt 2 --- src/spyglass/utils/dj_merge_tables.py | 9 - src/spyglass/utils/dj_mixin.py | 379 +++++++++++++++++++------- 2 files changed, 274 insertions(+), 114 deletions(-) diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index 63f8fb948..5d1e7d235 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -52,7 +52,6 @@ def __init__(self): + f"\n\tExpected: {self.primary_key}" + f"\n\tActual : {part.primary_key}" ) - self._analysis_nwbfile = None self._source_class_dict = {} @property @@ -65,14 +64,6 @@ def source_class_dict(self) -> dict: } return self._source_class_dict - @property # CB: This is a property to avoid circular import - def analysis_nwbfile(self): - if self._analysis_nwbfile is None: - from spyglass.common import AnalysisNwbfile # noqa F401 - - self._analysis_nwbfile = AnalysisNwbfile - return 
self._analysis_nwbfile - def _remove_comments(self, definition): """Use regular expressions to remove comments and blank lines""" return re.sub( # First remove comments, then blank lines diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index a813ae4e3..46d536b44 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -1,9 +1,14 @@ +from collections.abc import Iterable +from typing import Dict, List, Union + import datajoint as dj from datajoint.table import logger as dj_logger from datajoint.utils import user_choice from spyglass.utils.database_settings import SHARED_MODULES from spyglass.utils.dj_helper_fn import fetch_nwb +from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK +from spyglass.utils.dj_merge_tables import Merge from spyglass.utils.logging import logger @@ -40,6 +45,9 @@ class SpyglassMixin: # pks for delete permission check, assumed to be on field _session_pk = None # Session primary key. Mixin is ambivalent to Session pk _member_pk = None # LabMember primary key. Mixin ambivalent table structure + _merge_cache = {} # Cache of merge tables downstream of self + _merge_cache_links = {} # Cache of merge links downstream of self + _session_connection_cache = None # Cache of path from Session to self # ------------------------------- fetch_nwb ------------------------------- @@ -147,6 +155,148 @@ def _merge_del_func(self) -> callable: self._merge_delete_func = delete_downstream_merge return self._merge_delete_func + @staticmethod + def _get_instanced(*tables) -> Union[dj.user_tables.Table, list]: + """Return instance of table(s) if not already instanced.""" + ret = [] + if not isinstance(tables, Iterable): + tables = tuple(tables) + for table in tables: + if not isinstance(table, dj.user_tables.Table): + ret.append(table()) + else: + ret.append(table) + return ret[0] if len(ret) == 1 else ret + + @staticmethod + def _link_repr(parent, child, width=120): + len_each = (width - 4) // 2 + p = parent.full_table_name[:len_each].ljust(len_each) + c = child.full_table_name[:len_each].ljust(len_each) + return f"{p} -> {c}" + + def _get_connection( + self, + child: dj.user_tables.TableMeta, + parent: dj.user_tables.TableMeta = None, + recurse_level: int = 4, + visited: set = None, + ) -> Union[List[dj.FreeTable], List[List[dj.FreeTable]]]: + """ + Return list of tables connecting the parent and child for a valid join. + + Parameters + ---------- + parent : dj.user_tables.TableMeta + DataJoint table upstream in pipeline. + child : dj.user_tables.TableMeta + DataJoint table downstream in pipeline. + recurse_level : int, optional + Maximum number of recursion levels. Default is 4. + visited : set, optional + Set of visited tables (used internally for recursion). + + Returns + ------- + List[dj.FreeTable] or List[List[dj.FreeTable]] + List of paths, with each path as a list FreeTables connecting the + parent and child for a valid join. 
+ """ + parent = parent or self + parent, child = self._get_instanced(parent, child) + visited = visited or set() + child_is_merge = child.full_table_name in self._merge_tables + + if recurse_level < 1 or ( # if too much recursion + not child_is_merge # merge table ok + and ( # already visited, outside spyglass, or no connection + child.full_table_name in visited + or child.full_table_name.strip("`").split("_")[0] + not in SHARED_MODULES + or child.full_table_name not in parent.descendants() + ) + ): + return [] + + if child.full_table_name in parent.children(): + logger.debug(f"1-{recurse_level}:" + self._link_repr(parent, child)) + if isinstance(child, dict) or isinstance(parent, dict): + __import__("pdb").set_trace() + return [parent, child] + + if child_is_merge: + ret = [] + parts = child.parts(as_objects=True) + if not parts: + logger.warning(f"Merge has no parts: {child.full_table_name}") + for part in child.parts(as_objects=True): + links = self._get_connection( + parent=parent, + child=part, + recurse_level=recurse_level, + visited=visited, + ) + visited.add(part.full_table_name) + if links: + logger.debug( + f"2-{recurse_level}:" + self._link_repr(parent, part) + ) + ret.append(links + [child]) + + return ret + + for subchild in parent.children(as_objects=True): + links = self._get_connection( + parent=subchild, + child=child, + recurse_level=recurse_level - 1, + visited=visited, + ) + visited.add(subchild.full_table_name) + if links: + logger.debug( + f"3-{recurse_level}:" + self._link_repr(subchild, child) + ) + if parent.full_table_name in [l.full_table_name for l in links]: + return links + else: + return [parent] + links + + return [] + + def _join_list( + self, + tables: Union[List[dj.FreeTable], List[List[dj.FreeTable]]], + restriction: str = None, + ) -> dj.expression.QueryExpression: + """Return join of all tables in list. Omits empty items.""" + restriction = restriction or self.restriction or True + + if not isinstance(tables[0], (list, tuple)): + tables = [tables] + ret = [] + for table_list in tables: + join = table_list[0] & restriction + for table in table_list[1:]: + join = join * table + if join: + ret.append(join) + return ret[0] if len(ret) == 1 else ret + + def _connection_repr(self, connection) -> str: + if not isinstance(connection[0], Iterable): + connection = [connection] + ret = [] + for table_list in connection: + connection_str = "" + for table in table_list: + if isinstance(table, str): + connection_str += table + " -> " + else: + connection_str += table.table_name + " -> " + ret.append(f"\n\tPath: {connection_str[:-4]}") + return ret + def _find_session_link( self, table: dj.user_tables.UserTable, @@ -185,7 +335,111 @@ def _find_session_link( return table * Session - def _get_exp_summary(self, sess_link: dj.expression.QueryExpression): + def _ensure_dependencies_loaded(self) -> None: + """Ensure connection dependencies loaded.""" + if not self.connection.dependencies._loaded: + self.connection.dependencies.load() + + @property + def _merge_tables(self) -> Dict[str, dj.FreeTable]: + """Dict of merge tables downstream of self. + + Cache of items in parents of self.descendants(as_objects=True) that + have a merge primary key. 
+ """ + if self._merge_cache: + return self._merge_cache + + def has_merge_pk(table): + ret = MERGE_PK in table.heading.names + if "output" in table.full_table_name and not ret: + logger.warning( + f"Skipping merge table without merge primary key: " + + f"{table.full_table_name}" + ) + return MERGE_PK in table.heading.names + + self._ensure_dependencies_loaded() + for desc in self.descendants(as_objects=True): + if "output" not in desc.full_table_name: + logger.debug( + f"Skipping non-output table: {desc.full_table_name}" + ) + if not has_merge_pk(desc): + continue + for parent in desc.parents(as_objects=True): + if ( + has_merge_pk(parent) + and parent.full_table_name not in self._merge_cache + ): + self._merge_cache[parent.full_table_name] = parent + logger.info(f"Found {len(self._merge_cache)} merge tables") + + return self._merge_cache + + @property + def _merge_links(self) -> Dict[str, List[dj.FreeTable]]: + """Dict of merge links downstream of self. + + For each merge table found in _merge_tables, find the path from self to + merge. If the path is valid, add it to the dict. Cahche prevents need + to recompute whenever delete_downstream_merge is called with a new + restriction. + """ + if self._merge_cache_links: + return self._merge_cache_links + for name, merge_table in self._merge_tables.items(): + connection = self._get_connection(child=merge_table) + if connection: + self._merge_cache_links[name] = connection + return self._merge_cache_links + + def delete_downstream_merge( + self, + restriction: str = None, + dry_run: bool = True, + disable_warning: bool = False, + **kwargs, + ) -> List[dj.expression.QueryExpression]: + """Delete downstream merge table entries associated with restricton. + + Requires caching of merge tables and links, which is slow on first call. + + Parameters + ---------- + restriction : str, optional + Restriction to apply to merge tables. Default None. Will attempt to + use table restriction if None. + dry_run : bool, optional + If True, return list of merge part entries to be deleted. Default + True. + **kwargs : Any + Passed to datajoint.table.Table.delete. + """ + restriction = restriction or self.restriction or True + + merge_join_dict = {} + for merge_name, merge_link in self._merge_links.items(): + logger.debug(self._connection_repr(merge_link)) + joined = self._join_list(merge_link, restriction=self.restriction) + if joined: + merge_join_dict[self._merge_tables[merge_name]] = joined + + if dry_run: + return merge_join_dict.values() + + ret = [] + for table, selection in merge_join_dict.items(): + keys = selection.fetch(MERGE_PK, as_dict=True) + ret.append(table & keys) + # (table & keys).delete(**kwargs) # TODO: Run delete here + return ret + + def ddm(self, *args, **kwargs): + """Alias for delete_downstream_merge.""" + return self.delete_downstream_merge(*args, **kwargs) + + def _get_exp_summary(self): """Get summary of experimenters for session(s), including NULL. Parameters @@ -201,6 +455,8 @@ def _get_exp_summary(self, sess_link: dj.expression.QueryExpression): Session = self._delete_deps[-1] format = dj.U(self._session_pk, self._member_pk) + + sess_link = self._join_list(self._session_connection) exp_missing = format & (sess_link - Session.Experimenter).proj( **{self._member_pk: "NULL"} ) @@ -209,6 +465,19 @@ def _get_exp_summary(self, sess_link: dj.expression.QueryExpression): ) return exp_missing + exp_present + @property + def _session_connection(self) -> dj.expression.QueryExpression: + """Path from Session table to self. 
+ + None is not yet cached, False if no connection found. + """ + if not self._session_connection_cache is None: + self._session_connection_cache = ( + self._find_connection(parent=self._delete_deps[-1], child=self) + or False + ) + return self._session_connection_cache + def _check_delete_permission(self) -> None: """Check user name against lab team assoc. w/ self * Session. @@ -229,18 +498,16 @@ def _check_delete_permission(self) -> None: if dj_user in LabMember().admin: # bypass permission check for admin return - sess_link = self._find_session_link(table=self) - if not sess_link: # Permit delete if not linked to a session - logger.warn( + if not self._session_connection: + logger.warn( # Permit delete if no session connection "Could not find lab team associated with " + f"{self.__class__.__name__}." + "\nBe careful not to delete others' data." ) return - sess_summary = self._get_exp_summary( - sess_link.restrict(self.restriction) - ) + sess_link = self.join(self._session_connection) + sess_summary = self._get_exp_summary(sess_link) experimenters = sess_summary.fetch(self._member_pk) if None in experimenters: raise PermissionError( @@ -319,101 +586,3 @@ def cautious_delete(self, force_permission: bool = False, *args, **kwargs): def cdel(self, *args, **kwargs): """Alias for cautious_delete.""" self.cautious_delete(*args, **kwargs) - - -def get_instanced(*tables): - ret = [] - if not isinstance(tables, tuple): - tables = [tables] - for table in tables: - if not isinstance(table, dj.user_tables.Table): - ret.append(table()) - else: - ret.append(table) - return ret[0] if len(ret) == 1 else ret - - -def find_tables_connecting( - parent_table: dj.user_tables.TableMeta, - child_table: dj.user_tables.TableMeta, - recurse_level: int = 4, - visited: set = None, -) -> list: - """ - Return list of tables connecting the parent and child for a valid join. - - Parameters - ---------- - parent_table : dj.user_tables.TableMeta - DataJoint table upstream in pipeline. - child_table : dj.user_tables.TableMeta - DataJoint table downstream in pipeline. - recurse_level : int, optional - Maximum number of recursion levels. Default is 4. - visited : set, optional - Set of visited tables (used internally for recursion). - - Returns - ------- - list - List of paths, with each path as a list FreeTables connecting the parent - and child for a valid join. 
- """ - parent_table, child_table = get_instanced(parent_table, child_table) - visited = visited or set() - child_is_merge = isinstance(child_table, Merge) - - if recurse_level < 1 or ( # if too much recursion - not child_is_merge # merge table ok - and ( # already visited, outside spyglass, or no connection - child_table.full_table_name in visited - or child_table.full_table_name.strip("`").split("_")[0] - not in SHARED_MODULES - or child_table.full_table_name not in parent_table.descendants() - ) - ): - return [] - - if child_table.full_table_name in parent_table.children(): - return [parent_table, child_table] - - if child_is_merge: - _ = child_table._ensure_dependencies_loaded() - ret = [] - for part in child_table.parts(as_objects=True): - connecting_path = find_tables_connecting( - parent_table, - part, - recurse_level=recurse_level, - visited=visited, - ) - visited.add(part.full_table_name) - if connecting_path: - ret.append(connecting_path + [child_table]) - - return ret - - for child in parent_table.children(as_objects=True): - connecting_path = find_tables_connecting( - child, - child_table, - recurse_level=recurse_level - 1, - visited=visited, - ) - visited.add(child.full_table_name) - if connecting_path: - return [parent_table] + connecting_path - - return [] - - -def join_all(connecting_tables, restriction=True): - if not isinstance(connecting_tables, list): - connecting_tables = [connecting_tables] - ret = [] - for table_list in connecting_tables: - join = table_list[0] & restriction - for table in table_list[1:]: - join = join * table - ret.append(join) - return ret[0] if len(ret) == 1 else ret From 8875699070077eda6b3589ecf5adeace6f098a00 Mon Sep 17 00:00:00 2001 From: cbroz1 Date: Fri, 26 Jan 2024 16:37:41 -0800 Subject: [PATCH 03/10] WIP: #791, needs testing --- src/spyglass/utils/dj_merge_tables.py | 154 +------------------------- src/spyglass/utils/dj_mixin.py | 127 +++++++++------------ 2 files changed, 57 insertions(+), 224 deletions(-) diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index 5d1e7d235..7917bb161 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -813,8 +813,6 @@ def delete_downstream_merge( dry_run: bool Default True. If true, return list of tuples, merge/part tables downstream of table input. Otherwise, delete merge/part table entries. - recurse_level: int - Default 2. Depth to recurse into table descendants. disable_warning: bool Default False. If True, don't warn about restrictions on table object. kwargs: dict @@ -825,152 +823,12 @@ def delete_downstream_merge( List[Tuple[dj.Table, dj.Table]] Entries in merge/part tables downstream of table input. 
""" - if not disable_warning: - _warn_on_restriction(table, restriction) + from spyglass.utils.dj_mixin import SpyglassMixin - if not restriction: - restriction = True + if not isinstance(table, SpyglassMixin): + raise ValueError("Input must be a Spyglass Table.") + table = table if isinstance(table, dj.Table) else table() - descendants = _unique_descendants(table, recurse_level) - merge_table_dicts = _search_descendants( - table_list=descendants, - parent_table=table, - restriction=restriction, + return table.delete_downstream_merge( + restriction=restriction, dry_run=dry_run, **kwargs ) - - if dry_run: - return [d["view"] for d in merge_table_dicts] - - for merge_dict in merge_table_dicts: - (merge_dict["master"] & merge_dict["keys"]).delete(**kwargs) - - -def _warn_on_restriction(table: dj.Table, restriction: str = None): - """Warn if restriction on table object differs from input restriction""" - if restriction is None and table.restriction: - logger.warn( - f"Warning: ignoring table restriction: {table().restriction}.\n\t" - + "Please pass restrictions as an arg" - ) - - -def _unique_descendants( - table: dj.Table, - recurse_level: int = 2, - return_names: bool = False, - attribute=None, -) -> list: - """Recurisively find unique descendants of a given table - - Parameters - ---------- - table: dj.Table - The node in the tree from which to find descendants. - recurse_level: int - The maximum level of descendants to find. - return_names: bool - If True, return names of descendants found. Else return Table objects. - attribute: str, optional - If provided, only return descendants that have this attribute. - - Returns - ------- - List[dj.Table, str] - List descendants found when recurisively called to recurse_level - """ - - if recurse_level == 0: - return {} - - skip_attr_check = True if attribute is None else False - - descendants = {} - - def recurse_descendants(sub_table, level): - for descendant in sub_table.descendants(as_objects=True): - if descendant.full_table_name not in descendants and ( - skip_attr_check or attribute in descendant.heading.attributes - ): - descendants[descendant.full_table_name] = descendant - if level > 1: - recurse_descendants(descendant, level - 1) - - recurse_descendants(table, recurse_level) - - return ( - list(descendants.keys()) if return_names else list(descendants.values()) - ) - - -def _search_descendants( - table_list: list, - parent_table: dj.Table, - restriction: str = True, - connection: dj.connection.Connection = None, -) -> list: - """ - Given list of descendant tables, find merge tables linked to parent table. - - Returns a list of tuples, with master and part. Part will have restriction - applied. If restriction yield empty list, skip. - - Parameters - ---------- - table_list : List[dj.Table] - A list of datajoint tables. - parent_table : dj.Table - The parent table of the tables in table_list. - restiction : str, optional - Restriction applies to parent table. Default True, no restriction. - connection : datajoint.connection.Connection - A database connection. Default None, use connection from first table. - - Returns - ------- - List[dict] - A list of dictionaries, each with keys: master, view, and keys. - These are the master table, the restricted view of the master table, - and the primary keys of the restricted view. 
- """ - conn = connection or table_list[0].connection - - ret = [] - unique_parts = [] - - # Adapted from Spyglass PR 535 - for table in table_list: - table_name = table.full_table_name - if table_name in unique_parts: # then repeat in list - continue - - master_name = get_master(table_name) - if not master_name: # then not a part table - continue - - master = dj.FreeTable(conn, master_name) - # if table has no master, or master is not MergeTable, skip - if RESERVED_PRIMARY_KEY not in master.heading.names: - continue - - try: - restricted_master = ( - master & restriction - ) # .merge_restrict(restriction) - except DataJointError: - continue - - if not restricted_master: # No entries relevant to restriction in part - continue - - unique_parts.append(table_name) - ret.append( - dict( - master=master, - view=restricted_master, - keys=restricted_master.fetch( - RESERVED_PRIMARY_KEY, as_dict=True - ), - ) - ) - - return ret diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 46d536b44..2fd80ff5e 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -3,8 +3,10 @@ import datajoint as dj from datajoint.table import logger as dj_logger -from datajoint.utils import user_choice +from datajoint.user_tables import Table, TableMeta +from datajoint.utils import get_master, user_choice +from spyglass.settings import test_mode from spyglass.utils.database_settings import SHARED_MODULES from spyglass.utils.dj_helper_fn import fetch_nwb from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK @@ -284,7 +286,9 @@ def _join_list( return ret[0] if len(ret) == 1 else ret def _connection_repr(self, connection) -> str: - if not isinstance(connection[0], Iterable): + if isinstance(connection[0], (Table, TableMeta)): + connection = [connection] + if not isinstance(connection[0], (list, tuple)): connection = [connection] ret = [] for table_list in connection: @@ -297,44 +301,6 @@ def _connection_repr(self, connection) -> str: ret.append(f"\n\tPath: {connection_str[:-4]}") return ret - def _find_session_link( - self, - table: dj.user_tables.UserTable, - search_limit: int = 2, - ) -> dj.expression.QueryExpression: - """Find Session table associated with table. - - Parameters - ---------- - table : datajoint.user_tables.UserTable - Table to search for Session ancestor. - Session : datajoint.user_tables.UserTable - Session table to search for. Passed as arg to prevent re-import. - search_limit : int, optional - Number of levels of children of target table to search. Default 2. - - Returns - ------- - datajoint.expression.QueryExpression or None - Join of table link with Session table if found, else None. 
- """ - Session = self._delete_deps[-1] - # TODO: check search_limit default is enough for any table in spyglass - if self._session_pk in table.primary_key: - # joinable with Session - return table * Session - - elif search_limit > 0: - for child in table.children(as_objects=True): - table = self._find_session_link(child, search_limit - 1) - if table: # table is link, will valid join to Session - return table - - elif not table or search_limit < 1: # if none found and limit reached - return # Err kept in parent func to centralize permission logic - - return table * Session - def _ensure_dependencies_loaded(self) -> None: """Ensure connection dependencies loaded.""" if not self.connection.dependencies._loaded: @@ -351,28 +317,17 @@ def _merge_tables(self) -> Dict[str, dj.FreeTable]: return self._merge_cache def has_merge_pk(table): - ret = MERGE_PK in table.heading.names - if "output" in table.full_table_name and not ret: - logger.warning( - f"Skipping merge table without merge primary key: " - + f"{table.full_table_name}" - ) return MERGE_PK in table.heading.names self._ensure_dependencies_loaded() for desc in self.descendants(as_objects=True): - if "output" not in desc.full_table_name: - logger.debug( - f"Skipping non-output table: {desc.full_table_name}" - ) if not has_merge_pk(desc): continue - for parent in desc.parents(as_objects=True): - if ( - has_merge_pk(parent) - and parent.full_table_name not in self._merge_cache - ): - self._merge_cache[parent.full_table_name] = parent + if not (master_name := get_master(desc.full_table_name)): + continue + master = dj.FreeTable(self.connection, master_name) + if has_merge_pk(master): + self._merge_cache[master_name] = master logger.info(f"Found {len(self._merge_cache)} merge tables") return self._merge_cache @@ -394,11 +349,21 @@ def _merge_links(self) -> Dict[str, List[dj.FreeTable]]: self._merge_cache_links[name] = connection return self._merge_cache_links + def _commit_merge_deletes(self, merge_join_dict, **kwargs): + ret = [] + for table, selection in merge_join_dict.items(): + keys = selection.fetch(MERGE_PK, as_dict=True) + ret.append(table & keys) # NEEDS TESTING WITH ACTUAL DELETE + # (table & keys).delete(**kwargs) # TODO: Run delete here + return ret + def delete_downstream_merge( self, restriction: str = None, dry_run: bool = True, + reload_cache: bool = False, disable_warning: bool = False, + return_parts: bool = True, **kwargs, ) -> List[dj.expression.QueryExpression]: """Delete downstream merge table entries associated with restricton. @@ -413,9 +378,20 @@ def delete_downstream_merge( dry_run : bool, optional If True, return list of merge part entries to be deleted. Default True. + reload_cache : bool, optional + If True, reload merge cache. Default False. + disable_warning : bool, optional + If True, do not warn if no merge tables found. Default False. + return_parts : bool, optional + If True, return list of merge part entries to be deleted. Default + True. If False, return dictionary of merge tables and their joins. **kwargs : Any Passed to datajoint.table.Table.delete. """ + if reload_cache: + self._merge_cache = {} + self._merge_cache_links = {} + restriction = restriction or self.restriction or True merge_join_dict = {} @@ -425,15 +401,15 @@ def delete_downstream_merge( if joined: merge_join_dict[self._merge_tables[merge_name]] = joined - if dry_run: - return merge_join_dict.values() + if not merge_join_dict and not disable_warning: + logger.warning( + f"No merge tables found downstream of {self.full_table_name}." 
+ + "\n\tIf this is unexpected, try running with `reload_cache`." + ) - ret = [] - for table, selection in merge_join_dict.items(): - keys = selection.fetch(MERGE_PK, as_dict=True) - ret.append(table & keys) - # (table & keys).delete(**kwargs) # TODO: Run delete here - return ret + if dry_run: + return merge_join_dict + self._commit_merge_deletes(merge_join_dict, **kwargs) def ddm(self, *args, **kwargs): """Alias for delete_downstream_merge.""" @@ -471,9 +447,9 @@ def _session_connection(self) -> dj.expression.QueryExpression: None is not yet cached, False if no connection found. """ - if not self._session_connection_cache is None: + if self._session_connection_cache is None: self._session_connection_cache = ( - self._find_connection(parent=self._delete_deps[-1], child=self) + self._get_connection(parent=self._delete_deps[-1], child=self) or False ) return self._session_connection_cache @@ -506,10 +482,10 @@ def _check_delete_permission(self) -> None: ) return - sess_link = self.join(self._session_connection) - sess_summary = self._get_exp_summary(sess_link) + sess_summary = self._get_exp_summary() experimenters = sess_summary.fetch(self._member_pk) if None in experimenters: + # TODO: Check if allow delete of remainder? raise PermissionError( "Please ensure all Sessions have an experimenter in " + f"SessionExperimenter:\n{sess_summary}" @@ -554,11 +530,10 @@ def cautious_delete(self, force_permission: bool = False, *args, **kwargs): if not force_permission: self._check_delete_permission() - merge_deletes = self._merge_del_func( - self, - restriction=self.restriction if self.restriction else None, + merge_deletes = self.delete_downstream_merge( dry_run=True, disable_warning=True, + return_parts=False, ) safemode = ( @@ -568,15 +543,15 @@ def cautious_delete(self, force_permission: bool = False, *args, **kwargs): ) if merge_deletes: - for table, _ in merge_deletes: - count, name = len(table), table.full_table_name + for table, content in merge_deletes.items(): + count, name = len(content), table.full_table_name dj_logger.info(f"Merge: Deleting {count} rows from {name}") if ( - not safemode + not test_mode + or not safemode or user_choice("Commit deletes?", default="no") == "yes" ): - for merge_table, _ in merge_deletes: - merge_table.delete({**kwargs, "safemode": False}) + self._commit_merge_deletes(merge_deletes, **kwargs) else: logger.info("Delete aborted.") return From 7c0dd4d6f7930175ba89e4d313dcbc0f6890951c Mon Sep 17 00:00:00 2001 From: cbroz1 Date: Mon, 29 Jan 2024 13:08:59 -0800 Subject: [PATCH 04/10] Faster tree search with networkx --- src/spyglass/common/common_usage.py | 22 ++ src/spyglass/utils/dj_mixin.py | 436 +++++++++++++++------------- 2 files changed, 255 insertions(+), 203 deletions(-) create mode 100644 src/spyglass/common/common_usage.py diff --git a/src/spyglass/common/common_usage.py b/src/spyglass/common/common_usage.py new file mode 100644 index 000000000..67fac5412 --- /dev/null +++ b/src/spyglass/common/common_usage.py @@ -0,0 +1,22 @@ +"""A schema to store the usage of advanced Spyglass features. + +Records show usage of features such as table chains, which will be used to +determine which features are used, how often, and by whom. This will help +plan future development of Spyglass. 
+""" +import datajoint as dj + +schema = dj.schema("common_usage") + + +@schema +class CautiousDelete(dj.Manual): + definition = """ + id: int auto_increment + --- + dj_user: varchar(64) + duration: float + origin: varchar(64) + restriction: varchar(64) + merge_deletes = null: blob + """ diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 2fd80ff5e..acd83bb9d 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -1,12 +1,13 @@ from collections.abc import Iterable +from time import time from typing import Dict, List, Union import datajoint as dj +import networkx as nx from datajoint.table import logger as dj_logger from datajoint.user_tables import Table, TableMeta from datajoint.utils import get_master, user_choice -from spyglass.settings import test_mode from spyglass.utils.database_settings import SHARED_MODULES from spyglass.utils.dj_helper_fn import fetch_nwb from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK @@ -37,19 +38,27 @@ class SpyglassMixin: raised. `force_permission` can be set to True to bypass permission check. cdel(*args, **kwargs) Alias for cautious_delete. + delte_downstream_merge(restriction=None, dry_run=True, reload_cache=False) + Delete downstream merge table entries associated with restricton. + Requires caching of merge tables and links, which is slow on first call. + `restriction` can be set to a string to restrict the delete. `dry_run` + can be set to False to commit the delete. `reload_cache` can be set to + True to reload the merge cache. + ddm(*args, **kwargs) + Alias for delete_downstream_merge. """ _nwb_table_dict = {} # Dict mapping NWBFile table to path attribute name. # _nwb_table = None # NWBFile table class, defined at the table level _nwb_table_resolved = None # NWBFiletable class, resolved here from above _delete_dependencies = [] # Session, LabMember, LabTeam, delay import - _merge_delete_func = None # delete_downstream_merge, delay import # pks for delete permission check, assumed to be on field _session_pk = None # Session primary key. Mixin is ambivalent to Session pk _member_pk = None # LabMember primary key. Mixin ambivalent table structure - _merge_cache = {} # Cache of merge tables downstream of self - _merge_cache_links = {} # Cache of merge links downstream of self + _merge_table_cache = {} # Cache of merge tables downstream of self + _merge_chains_cache = {} # Cache of table chains to merges _session_connection_cache = None # Cache of path from Session to self + _usage_table_cache = None # Temporary inclusion for usage tracking # ------------------------------- fetch_nwb ------------------------------- @@ -144,167 +153,13 @@ def _delete_deps(self) -> list: return self._delete_dependencies @property - def _merge_del_func(self) -> callable: - """Callable: delete_downstream_merge function. - - Used to delay import of func until needed, avoiding circular imports. 
- """ - if not self._merge_delete_func: - from spyglass.utils.dj_merge_tables import ( # noqa F401 - delete_downstream_merge, - ) - - self._merge_delete_func = delete_downstream_merge - return self._merge_delete_func - - @staticmethod - def _get_instanced(*tables) -> Union[dj.user_tables.Table, list]: - """Return instance of table(s) if not already instanced.""" - ret = [] - if not isinstance(tables, Iterable): - tables = tuple(tables) - for table in tables: - if not isinstance(table, dj.user_tables.Table): - ret.append(table()) - else: - ret.append(table) - return ret[0] if len(ret) == 1 else ret + def _test_mode(self) -> bool: + """Return True if test mode is enabled.""" + if not self._test_mode_cache: + from spyglass.settings import test_mode - @staticmethod - def _link_repr(parent, child, width=120): - len_each = (width - 4) // 2 - p = parent.full_table_name[:len_each].ljust(len_each) - c = child.full_table_name[:len_each].ljust(len_each) - return f"{p} -> {c}" - - def _get_connection( - self, - child: dj.user_tables.TableMeta, - parent: dj.user_tables.TableMeta = None, - recurse_level: int = 4, - visited: set = None, - ) -> Union[List[dj.FreeTable], List[List[dj.FreeTable]]]: - """ - Return list of tables connecting the parent and child for a valid join. - - Parameters - ---------- - parent : dj.user_tables.TableMeta - DataJoint table upstream in pipeline. - child : dj.user_tables.TableMeta - DataJoint table downstream in pipeline. - recurse_level : int, optional - Maximum number of recursion levels. Default is 4. - visited : set, optional - Set of visited tables (used internally for recursion). - - Returns - ------- - List[dj.FreeTable] or List[List[dj.FreeTable]] - List of paths, with each path as a list FreeTables connecting the - parent and child for a valid join. 
- """ - parent = parent or self - parent, child = self._get_instanced(parent, child) - visited = visited or set() - child_is_merge = child.full_table_name in self._merge_tables - - if recurse_level < 1 or ( # if too much recursion - not child_is_merge # merge table ok - and ( # already visited, outside spyglass, or no connection - child.full_table_name in visited - or child.full_table_name.strip("`").split("_")[0] - not in SHARED_MODULES - or child.full_table_name not in parent.descendants() - ) - ): - return [] - - if child.full_table_name in parent.children(): - logger.debug(f"1-{recurse_level}:" + self._link_repr(parent, child)) - if isinstance(child, dict) or isinstance(parent, dict): - __import__("pdb").set_trace() - return [parent, child] - - if child_is_merge: - ret = [] - parts = child.parts(as_objects=True) - if not parts: - logger.warning(f"Merge has no parts: {child.full_table_name}") - for part in child.parts(as_objects=True): - links = self._get_connection( - parent=parent, - child=part, - recurse_level=recurse_level, - visited=visited, - ) - visited.add(part.full_table_name) - if links: - logger.debug( - f"2-{recurse_level}:" + self._link_repr(parent, part) - ) - ret.append(links + [child]) - - return ret - - for subchild in parent.children(as_objects=True): - links = self._get_connection( - parent=subchild, - child=child, - recurse_level=recurse_level - 1, - visited=visited, - ) - visited.add(subchild.full_table_name) - if links: - logger.debug( - f"3-{recurse_level}:" + self._link_repr(subchild, child) - ) - if parent.full_table_name in [l.full_table_name for l in links]: - return links - else: - return [parent] + links - - return [] - - def _join_list( - self, - tables: Union[List[dj.FreeTable], List[List[dj.FreeTable]]], - restriction: str = None, - ) -> dj.expression.QueryExpression: - """Return join of all tables in list. Omits empty items.""" - restriction = restriction or self.restriction or True - - if not isinstance(tables[0], (list, tuple)): - tables = [tables] - ret = [] - for table_list in tables: - join = table_list[0] & restriction - for table in table_list[1:]: - join = join * table - if join: - ret.append(join) - return ret[0] if len(ret) == 1 else ret - - def _connection_repr(self, connection) -> str: - if isinstance(connection[0], (Table, TableMeta)): - connection = [connection] - if not isinstance(connection[0], (list, tuple)): - connection = [connection] - ret = [] - for table_list in connection: - connection_str = "" - for table in table_list: - if isinstance(table, str): - connection_str += table + " -> " - else: - connection_str += table.table_name + " -> " - ret.append(f"\n\tPath: {connection_str[:-4]}") - return ret - - def _ensure_dependencies_loaded(self) -> None: - """Ensure connection dependencies loaded.""" - if not self.connection.dependencies._loaded: - self.connection.dependencies.load() + self._test_mode_cache = test_mode + return self._test_mode_cache @property def _merge_tables(self) -> Dict[str, dj.FreeTable]: @@ -313,13 +168,13 @@ def _merge_tables(self) -> Dict[str, dj.FreeTable]: Cache of items in parents of self.descendants(as_objects=True) that have a merge primary key. 
""" - if self._merge_cache: - return self._merge_cache + if self._merge_table_cache: + return self._merge_table_cache def has_merge_pk(table): return MERGE_PK in table.heading.names - self._ensure_dependencies_loaded() + self.connection.dependencies.load() for desc in self.descendants(as_objects=True): if not has_merge_pk(desc): continue @@ -327,13 +182,16 @@ def has_merge_pk(table): continue master = dj.FreeTable(self.connection, master_name) if has_merge_pk(master): - self._merge_cache[master_name] = master - logger.info(f"Found {len(self._merge_cache)} merge tables") + self._merge_table_cache[master_name] = master + logger.info( + f"Building merge cache for {self.table_name}.\n\t" + + f"Found {len(self._merge_table_cache)} downstream merge tables" + ) - return self._merge_cache + return self._merge_table_cache @property - def _merge_links(self) -> Dict[str, List[dj.FreeTable]]: + def _merge_chains(self) -> Dict[str, List[dj.FreeTable]]: """Dict of merge links downstream of self. For each merge table found in _merge_tables, find the path from self to @@ -341,21 +199,23 @@ def _merge_links(self) -> Dict[str, List[dj.FreeTable]]: to recompute whenever delete_downstream_merge is called with a new restriction. """ - if self._merge_cache_links: - return self._merge_cache_links + if self._merge_chains_cache: + return self._merge_chains_cache + for name, merge_table in self._merge_tables.items(): - connection = self._get_connection(child=merge_table) - if connection: - self._merge_cache_links[name] = connection - return self._merge_cache_links + chains = TableChains(self, merge_table, connection=self.connection) + if len(chains): + self._merge_chains_cache[name] = chains + return self._merge_chains_cache def _commit_merge_deletes(self, merge_join_dict, **kwargs): - ret = [] - for table, selection in merge_join_dict.items(): - keys = selection.fetch(MERGE_PK, as_dict=True) - ret.append(table & keys) # NEEDS TESTING WITH ACTUAL DELETE - # (table & keys).delete(**kwargs) # TODO: Run delete here - return ret + """Commit merge deletes. + + Extraxted for use in cautious_delete and delete_downstream_merge.""" + for table_name, part_restr in merge_join_dict.items(): + table = self._merge_tables[table_name] + keys = [part.fetch(MERGE_PK, as_dict=True) for part in part_restr] + (table & keys).delete(**kwargs) def delete_downstream_merge( self, @@ -389,17 +249,16 @@ def delete_downstream_merge( Passed to datajoint.table.Table.delete. 
""" if reload_cache: - self._merge_cache = {} - self._merge_cache_links = {} + self._merge_table_cache = {} + self._merge_chains_cache = {} restriction = restriction or self.restriction or True merge_join_dict = {} - for merge_name, merge_link in self._merge_links.items(): - logger.debug(self._connection_repr(merge_link)) - joined = self._join_list(merge_link, restriction=self.restriction) - if joined: - merge_join_dict[self._merge_tables[merge_name]] = joined + for name, chain in self._merge_chains.items(): + join = chain.join(restriction) + if join: + merge_join_dict[name] = join if not merge_join_dict and not disable_warning: logger.warning( @@ -408,12 +267,30 @@ def delete_downstream_merge( ) if dry_run: - return merge_join_dict + return merge_join_dict.values() if return_parts else merge_join_dict + self._commit_merge_deletes(merge_join_dict, **kwargs) - def ddm(self, *args, **kwargs): + def ddm( + self, + restriction: str = None, + dry_run: bool = True, + reload_cache: bool = False, + disable_warning: bool = False, + return_parts: bool = True, + *args, + **kwargs, + ): """Alias for delete_downstream_merge.""" - return self.delete_downstream_merge(*args, **kwargs) + return self.delete_downstream_merge( + restriction=restriction, + dry_run=dry_run, + reload_cache=reload_cache, + disable_warning=disable_warning, + return_parts=return_parts, + *args, + **kwargs, + ) def _get_exp_summary(self): """Get summary of experimenters for session(s), including NULL. @@ -429,16 +306,15 @@ def _get_exp_summary(self): Summary of experimenters for session(s). """ Session = self._delete_deps[-1] + SesExp = Session.Experimenter + empty_pk = {self._member_pk: "NULL"} format = dj.U(self._session_pk, self._member_pk) + sess_link = self._session_connection.join(self.restriction) + + exp_missing = format & (sess_link - SesExp).proj(**empty_pk) + exp_present = format & (sess_link * SesExp - exp_missing).proj() - sess_link = self._join_list(self._session_connection) - exp_missing = format & (sess_link - Session.Experimenter).proj( - **{self._member_pk: "NULL"} - ) - exp_present = ( - format & (sess_link * Session.Experimenter - exp_missing).proj() - ) return exp_missing + exp_present @property @@ -448,9 +324,9 @@ def _session_connection(self) -> dj.expression.QueryExpression: None is not yet cached, False if no connection found. """ if self._session_connection_cache is None: + connection = TableChain(parent=self._delete_deps[-1], child=self) self._session_connection_cache = ( - self._get_connection(parent=self._delete_deps[-1], child=self) - or False + connection if connection.has_link else False ) return self._session_connection_cache @@ -508,6 +384,15 @@ def _check_delete_permission(self) -> None: ) logger.info(f"Queueing delete for session(s):\n{sess_summary}") + @property + def _usage_table(self): + """Temporary inclusion for usage tracking.""" + if not self._usage_table_cache: + from spyglass.common.common_usage import CautiousDelete + + self._usage_table_cache = CautiousDelete + return self._usage_table_cache + # Rename to `delete` when we're ready to use it # TODO: Intercept datajoint delete confirmation prompt for merge deletes def cautious_delete(self, force_permission: bool = False, *args, **kwargs): @@ -526,6 +411,12 @@ def cautious_delete(self, force_permission: bool = False, *args, **kwargs): *args, **kwargs : Any Passed to datajoint.table.Table.delete. 
""" + start = time() + usage_dict = dict( + dj_user=dj.config["database.user"], + origin=self.full_table_name, + restriction=self.restriction, + ) if not force_permission: self._check_delete_permission() @@ -547,17 +438,156 @@ def cautious_delete(self, force_permission: bool = False, *args, **kwargs): count, name = len(content), table.full_table_name dj_logger.info(f"Merge: Deleting {count} rows from {name}") if ( - not test_mode + not self._test_mode or not safemode or user_choice("Commit deletes?", default="no") == "yes" ): self._commit_merge_deletes(merge_deletes, **kwargs) else: logger.info("Delete aborted.") + self._usage_table.insert1( + dict(duration=time() - start, **usage_dict) + ) return super().delete(*args, **kwargs) # Additional confirm here + self._usage_table.insert1( + dict( + duration=time() - start, + merge_deletes=merge_deletes, + ) + ) + def cdel(self, *args, **kwargs): """Alias for cautious_delete.""" self.cautious_delete(*args, **kwargs) + + +class TableChains: + """Class for representing chains from parent to Merge table via parts.""" + + def __init__(self, parent, child, connection=None): + self.parent = parent + self.child = child + self.connection = connection or parent.connection + parts = child.parts(as_objects=True) + self.part_names = [part.full_table_name for part in parts] + self.chains = [TableChain(parent, part) for part in parts] + self.has_link = any([chain.has_link for chain in self.chains]) + + def __repr__(self): + return "\n".join([str(chain) for chain in self.chains]) + + def __len__(self): + return len([c for c in self.chains if c.has_link]) + + def join(self, restriction=None): + restriction = restriction or self.parent.restriction or True + joins = [] + for chain in self.chains: + if joined := chain.join(restriction): + joins.append(joined) + return joins + + +class TableChain: + """Class for representing a chain of tables. + + Note: Parent -> Merge should use TableChains instead. + """ + + def __init__(self, parent: Table, child: Table, connection=None): + self._connection = connection or parent.connection + if not self._connection.dependencies._loaded: + self._connection.dependencies.load() + + if ( # if child is a merge table + get_master(child.full_table_name) == "" + and MERGE_PK in child.heading.names + ): + logger.error("Child is a merge table. Use TableChains instead.") + + self._link_symbol = " -> " + self.parent = parent + self.child = child + self._repr = None + self._names = None # full table names of tables in chain + self._objects = None # free tables in chain + self._has_link = child.full_table_name in parent.descendants() + + def __str__(self): + """Return string representation of chain: parent -> child.""" + if not self._has_link: + return "No link" + return ( + f"Chain: " + + self.parent.table_name + + self._link_symbol + + self.child.table_name + ) + + def __repr__(self): + """Return full representation of chain: parent -> {links} -> child.""" + if self._repr: + return self._repr + self._repr = ( + "Chain: " + + self._link_symbol.join([t.table_name for t in self.objects]) + if self.names + else "No link" + ) + return self._repr + + def __len__(self): + """Return number of tables in chain.""" + return len(self.names) + + @property + def has_link(self) -> bool: + """Return True if parent is linked to child. + + Cached as hidden attribute _has_link to set False if nx.NetworkXNoPath + is raised by nx.shortest_path. 
+ """ + return self._has_link + + @property + def names(self) -> List[str]: + """Return list of full table names in chain. + + Uses networkx.shortest_path. + """ + if not self._has_link: + return None + if self._names: + return self._names + try: + self._names = nx.shortest_path( + self.parent.connection.dependencies, + self.parent.full_table_name, + self.child.full_table_name, + ) + return self._names + except nx.NetworkXNoPath: + self._has_link = False + return None + + @property + def objects(self) -> List[dj.FreeTable]: + """Return list of FreeTable objects for each table in chain.""" + if not self._objects: + self._objects = ( + [dj.FreeTable(self._connection, name) for name in self.names] + if self.names + else None + ) + return self._objects + + def join(self, restricton: str = None) -> dj.expression.QueryExpression: + """Return join of tables in chain with restriction applied to parent.""" + restriction = restricton or self.parent.restriction or True + join = self.objects[0] & restriction + for table in self.objects[1:]: + join = join * table + return join if join else None From 8385eb3d5dab2489611b5158e12daa5f861e0ae9 Mon Sep 17 00:00:00 2001 From: cbroz1 Date: Mon, 29 Jan 2024 13:19:18 -0800 Subject: [PATCH 05/10] Blackify --- src/spyglass/common/common_behav.py | 10 +++-- src/spyglass/common/common_device.py | 6 +-- src/spyglass/common/common_filter.py | 6 +-- src/spyglass/common/common_interval.py | 8 ++-- src/spyglass/common/common_lab.py | 1 + src/spyglass/common/common_usage.py | 1 + src/spyglass/decoding/v0/clusterless.py | 11 +++--- .../decoding/v0/dj_decoder_conversion.py | 37 ++++++++++--------- src/spyglass/decoding/v0/sorted_spikes.py | 1 + .../decoding/v0/visualization_2D_view.py | 2 +- src/spyglass/decoding/v1/clusterless.py | 6 +-- .../decoding/v1/dj_decoder_conversion.py | 1 - src/spyglass/decoding/v1/sorted_spikes.py | 6 +-- .../position/v1/position_dlc_orient.py | 18 ++++----- .../v1/position_dlc_pose_estimation.py | 20 +++++----- .../position/v1/position_dlc_position.py | 20 +++++----- src/spyglass/sharing/sharing_kachery.py | 6 +-- .../prepare_spikesortingview_data.py | 18 ++++----- src/spyglass/spikesorting/v1/recording.py | 6 +-- src/spyglass/utils/dj_helper_fn.py | 9 +++-- src/spyglass/utils/dj_merge_tables.py | 8 ++-- src/spyglass/utils/dj_mixin.py | 8 ++-- src/spyglass/utils/logging.py | 1 + src/spyglass/utils/nwb_helper_fn.py | 8 ++-- 24 files changed, 116 insertions(+), 102 deletions(-) diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py index ed9673ecb..d7d4759fb 100644 --- a/src/spyglass/common/common_behav.py +++ b/src/spyglass/common/common_behav.py @@ -191,10 +191,12 @@ def _get_column_names(rp, pos_id): INDEX_ADJUST = 1 # adjust 0-index to 1-index (e.g., xloc0 -> xloc1) n_pos_dims = rp.data.shape[1] column_names = [ - col # use existing columns if already numbered - if "1" in rp.description or "2" in rp.description - # else number them by id - else col + str(pos_id + INDEX_ADJUST) + ( + col # use existing columns if already numbered + if "1" in rp.description or "2" in rp.description + # else number them by id + else col + str(pos_id + INDEX_ADJUST) + ) for col in rp.description.split(", ") ] if len(column_names) != n_pos_dims: diff --git a/src/spyglass/common/common_device.py b/src/spyglass/common/common_device.py index 2dd03c822..96fa11d44 100644 --- a/src/spyglass/common/common_device.py +++ b/src/spyglass/common/common_device.py @@ -476,9 +476,9 @@ def __read_ndx_probe_data( { "probe_id": 
nwb_probe_obj.probe_type, "probe_type": nwb_probe_obj.probe_type, - "contact_side_numbering": "True" - if nwb_probe_obj.contact_side_numbering - else "False", + "contact_side_numbering": ( + "True" if nwb_probe_obj.contact_side_numbering else "False" + ), } ) # go through the shanks and add each one to the Shank table diff --git a/src/spyglass/common/common_filter.py b/src/spyglass/common/common_filter.py index 988266d0d..59870f266 100644 --- a/src/spyglass/common/common_filter.py +++ b/src/spyglass/common/common_filter.py @@ -500,9 +500,9 @@ def filter_data( for ii, (start, stop) in enumerate(indices): extracted_ts = timestamps[start:stop:decimation] - new_timestamps[ - ts_offset : ts_offset + len(extracted_ts) - ] = extracted_ts + new_timestamps[ts_offset : ts_offset + len(extracted_ts)] = ( + extracted_ts + ) ts_offset += len(extracted_ts) # finally ready to filter data! diff --git a/src/spyglass/common/common_interval.py b/src/spyglass/common/common_interval.py index d754261fc..24b143ad6 100644 --- a/src/spyglass/common/common_interval.py +++ b/src/spyglass/common/common_interval.py @@ -56,9 +56,11 @@ def insert_from_nwbfile(cls, nwbf, *, nwb_file_name): for _, epoch_data in epochs.iterrows(): epoch_dict = { "nwb_file_name": nwb_file_name, - "interval_list_name": epoch_data.tags[0] - if epoch_data.tags - else f"interval_{epoch_data[0]}", + "interval_list_name": ( + epoch_data.tags[0] + if epoch_data.tags + else f"interval_{epoch_data[0]}" + ), "valid_times": np.asarray( [[epoch_data.start_time, epoch_data.stop_time]] ), diff --git a/src/spyglass/common/common_lab.py b/src/spyglass/common/common_lab.py index 177fc4424..ca9d4359a 100644 --- a/src/spyglass/common/common_lab.py +++ b/src/spyglass/common/common_lab.py @@ -1,4 +1,5 @@ """Schema for institution, lab team/name/members. Session-independent.""" + import datajoint as dj from spyglass.utils import SpyglassMixin, logger diff --git a/src/spyglass/common/common_usage.py b/src/spyglass/common/common_usage.py index 67fac5412..716649574 100644 --- a/src/spyglass/common/common_usage.py +++ b/src/spyglass/common/common_usage.py @@ -4,6 +4,7 @@ determine which features are used, how often, and by whom. This will help plan future development of Spyglass. """ + import datajoint as dj schema = dj.schema("common_usage") diff --git a/src/spyglass/decoding/v0/clusterless.py b/src/spyglass/decoding/v0/clusterless.py index f6fd9df37..e5577225d 100644 --- a/src/spyglass/decoding/v0/clusterless.py +++ b/src/spyglass/decoding/v0/clusterless.py @@ -6,6 +6,7 @@ [1] Denovellis, E. L. et al. Hippocampal replay of experience at real-world speeds. eLife 10, e64505 (2021). 
""" + import os import shutil import uuid @@ -654,11 +655,11 @@ def make(self, key): key["nwb_file_name"] ) - key[ - "multiunit_firing_rate_object_id" - ] = nwb_analysis_file.add_nwb_object( - analysis_file_name=key["analysis_file_name"], - nwb_object=multiunit_firing_rate.reset_index(), + key["multiunit_firing_rate_object_id"] = ( + nwb_analysis_file.add_nwb_object( + analysis_file_name=key["analysis_file_name"], + nwb_object=multiunit_firing_rate.reset_index(), + ) ) nwb_analysis_file.add( diff --git a/src/spyglass/decoding/v0/dj_decoder_conversion.py b/src/spyglass/decoding/v0/dj_decoder_conversion.py index 73566f24a..edcb0d637 100644 --- a/src/spyglass/decoding/v0/dj_decoder_conversion.py +++ b/src/spyglass/decoding/v0/dj_decoder_conversion.py @@ -1,5 +1,6 @@ """Converts decoder classes into dictionaries and dictionaries into classes so that datajoint can store them in tables.""" + from spyglass.utils import logger try: @@ -116,17 +117,17 @@ def restore_classes(params: dict) -> dict: _convert_env_dict(env_params) for env_params in params["classifier_params"]["environments"] ] - params["classifier_params"][ - "discrete_transition_type" - ] = _convert_dict_to_class( - params["classifier_params"]["discrete_transition_type"], - discrete_state_transition_types, + params["classifier_params"]["discrete_transition_type"] = ( + _convert_dict_to_class( + params["classifier_params"]["discrete_transition_type"], + discrete_state_transition_types, + ) ) - params["classifier_params"][ - "initial_conditions_type" - ] = _convert_dict_to_class( - params["classifier_params"]["initial_conditions_type"], - initial_conditions_types, + params["classifier_params"]["initial_conditions_type"] = ( + _convert_dict_to_class( + params["classifier_params"]["initial_conditions_type"], + initial_conditions_types, + ) ) if params["classifier_params"].get("observation_models"): @@ -176,10 +177,10 @@ def convert_classes_to_dict(key: dict) -> dict: key["classifier_params"]["environments"] ) ] - key["classifier_params"][ - "continuous_transition_types" - ] = _convert_transitions_to_dict( - key["classifier_params"]["continuous_transition_types"] + key["classifier_params"]["continuous_transition_types"] = ( + _convert_transitions_to_dict( + key["classifier_params"]["continuous_transition_types"] + ) ) key["classifier_params"]["discrete_transition_type"] = _to_dict( key["classifier_params"]["discrete_transition_type"] @@ -194,10 +195,10 @@ def convert_classes_to_dict(key: dict) -> dict: ] try: - key["classifier_params"][ - "clusterless_algorithm_params" - ] = _convert_algorithm_params( - key["classifier_params"]["clusterless_algorithm_params"] + key["classifier_params"]["clusterless_algorithm_params"] = ( + _convert_algorithm_params( + key["classifier_params"]["clusterless_algorithm_params"] + ) ) except KeyError: pass diff --git a/src/spyglass/decoding/v0/sorted_spikes.py b/src/spyglass/decoding/v0/sorted_spikes.py index abe7ec207..acfc501cf 100644 --- a/src/spyglass/decoding/v0/sorted_spikes.py +++ b/src/spyglass/decoding/v0/sorted_spikes.py @@ -7,6 +7,7 @@ speeds. eLife 10, e64505 (2021). 
""" + import pprint import datajoint as dj diff --git a/src/spyglass/decoding/v0/visualization_2D_view.py b/src/spyglass/decoding/v0/visualization_2D_view.py index 52338ea78..14dcd204c 100644 --- a/src/spyglass/decoding/v0/visualization_2D_view.py +++ b/src/spyglass/decoding/v0/visualization_2D_view.py @@ -38,7 +38,7 @@ def create_static_track_animation( "xmin": np.min(ul_corners[0]), "xmax": np.max(ul_corners[0]) + track_rect_width, "ymin": np.min(ul_corners[1]), - "ymax": np.max(ul_corners[1]) + track_rect_height + "ymax": np.max(ul_corners[1]) + track_rect_height, # Speed: should this be displayed? # TODO: Better approach for accommodating further data streams } diff --git a/src/spyglass/decoding/v1/clusterless.py b/src/spyglass/decoding/v1/clusterless.py index 3b179d7ee..1751ff8eb 100644 --- a/src/spyglass/decoding/v1/clusterless.py +++ b/src/spyglass/decoding/v1/clusterless.py @@ -240,9 +240,9 @@ def make(self, key): vars(classifier).get("discrete_transition_coefficients_") is not None ): - results[ - "discrete_transition_coefficients" - ] = classifier.discrete_transition_coefficients_ + results["discrete_transition_coefficients"] = ( + classifier.discrete_transition_coefficients_ + ) # Insert results # in future use https://github.com/rly/ndx-xarray and analysis nwb file? diff --git a/src/spyglass/decoding/v1/dj_decoder_conversion.py b/src/spyglass/decoding/v1/dj_decoder_conversion.py index 2795f8be9..c52c95a72 100644 --- a/src/spyglass/decoding/v1/dj_decoder_conversion.py +++ b/src/spyglass/decoding/v1/dj_decoder_conversion.py @@ -1,7 +1,6 @@ """Converts decoder classes into dictionaries and dictionaries into classes so that datajoint can store them in tables.""" - import copy import datajoint as dj diff --git a/src/spyglass/decoding/v1/sorted_spikes.py b/src/spyglass/decoding/v1/sorted_spikes.py index 9f968d768..40041691a 100644 --- a/src/spyglass/decoding/v1/sorted_spikes.py +++ b/src/spyglass/decoding/v1/sorted_spikes.py @@ -232,9 +232,9 @@ def make(self, key): vars(classifier).get("discrete_transition_coefficients_") is not None ): - results[ - "discrete_transition_coefficients" - ] = classifier.discrete_transition_coefficients_ + results["discrete_transition_coefficients"] = ( + classifier.discrete_transition_coefficients_ + ) # Insert results # in future use https://github.com/rly/ndx-xarray and analysis nwb file? 
diff --git a/src/spyglass/position/v1/position_dlc_orient.py b/src/spyglass/position/v1/position_dlc_orient.py index 421e5330e..9b226d1a0 100644 --- a/src/spyglass/position/v1/position_dlc_orient.py +++ b/src/spyglass/position/v1/position_dlc_orient.py @@ -241,15 +241,15 @@ def interp_orientation(orientation, spans_to_interp, **kwargs): # TODO: add parameters to refine interpolation for ind, (span_start, span_stop) in enumerate(spans_to_interp): if (span_stop + 1) >= len(orientation): - orientation.loc[ - idx[span_start:span_stop], idx["orientation"] - ] = np.nan + orientation.loc[idx[span_start:span_stop], idx["orientation"]] = ( + np.nan + ) print(f"ind: {ind} has no endpoint with which to interpolate") continue if span_start < 1: - orientation.loc[ - idx[span_start:span_stop], idx["orientation"] - ] = np.nan + orientation.loc[idx[span_start:span_stop], idx["orientation"]] = ( + np.nan + ) print(f"ind: {ind} has no startpoint with which to interpolate") continue orient = [ @@ -263,7 +263,7 @@ def interp_orientation(orientation, spans_to_interp, **kwargs): xp=[start_time, stop_time], fp=[orient[0], orient[-1]], ) - orientation.loc[ - idx[start_time:stop_time], idx["orientation"] - ] = orientnew + orientation.loc[idx[start_time:stop_time], idx["orientation"]] = ( + orientnew + ) return orientation diff --git a/src/spyglass/position/v1/position_dlc_pose_estimation.py b/src/spyglass/position/v1/position_dlc_pose_estimation.py index 63603932c..500f888b5 100644 --- a/src/spyglass/position/v1/position_dlc_pose_estimation.py +++ b/src/spyglass/position/v1/position_dlc_pose_estimation.py @@ -309,17 +309,17 @@ def make(self, key): description="video_frame_ind", ) nwb_analysis_file = AnalysisNwbfile() - key[ - "dlc_pose_estimation_position_object_id" - ] = nwb_analysis_file.add_nwb_object( - analysis_file_name=key["analysis_file_name"], - nwb_object=position, + key["dlc_pose_estimation_position_object_id"] = ( + nwb_analysis_file.add_nwb_object( + analysis_file_name=key["analysis_file_name"], + nwb_object=position, + ) ) - key[ - "dlc_pose_estimation_likelihood_object_id" - ] = nwb_analysis_file.add_nwb_object( - analysis_file_name=key["analysis_file_name"], - nwb_object=likelihood, + key["dlc_pose_estimation_likelihood_object_id"] = ( + nwb_analysis_file.add_nwb_object( + analysis_file_name=key["analysis_file_name"], + nwb_object=likelihood, + ) ) nwb_analysis_file.add( nwb_file_name=key["nwb_file_name"], diff --git a/src/spyglass/position/v1/position_dlc_position.py b/src/spyglass/position/v1/position_dlc_position.py index 0e1ae4ef5..0916115e5 100644 --- a/src/spyglass/position/v1/position_dlc_position.py +++ b/src/spyglass/position/v1/position_dlc_position.py @@ -248,17 +248,17 @@ def make(self, key): comments="no comments", description="video_frame_ind", ) - key[ - "dlc_smooth_interp_position_object_id" - ] = nwb_analysis_file.add_nwb_object( - analysis_file_name=key["analysis_file_name"], - nwb_object=position, + key["dlc_smooth_interp_position_object_id"] = ( + nwb_analysis_file.add_nwb_object( + analysis_file_name=key["analysis_file_name"], + nwb_object=position, + ) ) - key[ - "dlc_smooth_interp_info_object_id" - ] = nwb_analysis_file.add_nwb_object( - analysis_file_name=key["analysis_file_name"], - nwb_object=video_frame_ind, + key["dlc_smooth_interp_info_object_id"] = ( + nwb_analysis_file.add_nwb_object( + analysis_file_name=key["analysis_file_name"], + nwb_object=video_frame_ind, + ) ) nwb_analysis_file.add( nwb_file_name=key["nwb_file_name"], diff --git 
a/src/spyglass/sharing/sharing_kachery.py b/src/spyglass/sharing/sharing_kachery.py index e3b9111ec..5aa4ebe56 100644 --- a/src/spyglass/sharing/sharing_kachery.py +++ b/src/spyglass/sharing/sharing_kachery.py @@ -105,9 +105,9 @@ def set_resource_url(key: dict): def reset_resource_url(): KacheryZone.reset_zone() if default_kachery_resource_url is not None: - os.environ[ - kachery_resource_url_envar - ] = default_kachery_resource_url + os.environ[kachery_resource_url_envar] = ( + default_kachery_resource_url + ) @schema diff --git a/src/spyglass/spikesorting/figurl_views/prepare_spikesortingview_data.py b/src/spyglass/spikesorting/figurl_views/prepare_spikesortingview_data.py index 46138b696..c43031225 100644 --- a/src/spyglass/spikesorting/figurl_views/prepare_spikesortingview_data.py +++ b/src/spyglass/spikesorting/figurl_views/prepare_spikesortingview_data.py @@ -102,16 +102,16 @@ def prepare_spikesortingview_data( channel_neighborhood_size=channel_neighborhood_size, ) if len(spike_train) >= 10: - unit_peak_channel_ids[ - str(unit_id) - ] = peak_channel_id + unit_peak_channel_ids[str(unit_id)] = ( + peak_channel_id + ) else: - fallback_unit_peak_channel_ids[ - str(unit_id) - ] = peak_channel_id - unit_channel_neighborhoods[ - str(unit_id) - ] = channel_neighborhood + fallback_unit_peak_channel_ids[str(unit_id)] = ( + peak_channel_id + ) + unit_channel_neighborhoods[str(unit_id)] = ( + channel_neighborhood + ) for unit_id in unit_ids: peak_channel_id = unit_peak_channel_ids.get(str(unit_id), None) if peak_channel_id is None: diff --git a/src/spyglass/spikesorting/v1/recording.py b/src/spyglass/spikesorting/v1/recording.py index d795c8fe3..996611d9a 100644 --- a/src/spyglass/spikesorting/v1/recording.py +++ b/src/spyglass/spikesorting/v1/recording.py @@ -318,9 +318,9 @@ def _get_recording_timestamps(recording): timestamps = np.zeros((total_frames,)) for i in range(recording.get_num_segments()): - timestamps[ - cumsum_frames[i] : cumsum_frames[i + 1] - ] = recording.get_times(segment_index=i) + timestamps[cumsum_frames[i] : cumsum_frames[i + 1]] = ( + recording.get_times(segment_index=i) + ) else: timestamps = recording.get_times() return timestamps diff --git a/src/spyglass/utils/dj_helper_fn.py b/src/spyglass/utils/dj_helper_fn.py index 390bb2add..4a0495778 100644 --- a/src/spyglass/utils/dj_helper_fn.py +++ b/src/spyglass/utils/dj_helper_fn.py @@ -1,4 +1,5 @@ """Helper functions for manipulating information from DataJoint fetch calls.""" + import inspect import os from typing import Type @@ -193,9 +194,11 @@ def get_child_tables(table): return [ dj.FreeTable( table.connection, - s - if not s.isdigit() - else next(iter(table.connection.dependencies.children(s))), + ( + s + if not s.isdigit() + else next(iter(table.connection.dependencies.children(s))) + ), ) for s in table.children() ] diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index 7917bb161..c0dec296f 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -601,9 +601,11 @@ def merge_get_part( ) parts = [ - getattr(cls, source)().restrict(restriction) - if restrict_part # Re-apply restriction or don't - else getattr(cls, source)() + ( + getattr(cls, source)().restrict(restriction) + if restrict_part # Re-apply restriction or don't + else getattr(cls, source)() + ) for source in sources ] if join_master: diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index acd83bb9d..cc636664d 100644 --- 
a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -106,9 +106,7 @@ def _nwb_table_tuple(self): self._nwb_table_resolved = ( AnalysisNwbfile if "-> AnalysisNwbfile" in self.definition - else Nwbfile - if "-> Nwbfile" in self.definition - else None + else Nwbfile if "-> Nwbfile" in self.definition else None ) if getattr(self, "_nwb_table_resolved", None) is None: @@ -195,7 +193,7 @@ def _merge_chains(self) -> Dict[str, List[dj.FreeTable]]: """Dict of merge links downstream of self. For each merge table found in _merge_tables, find the path from self to - merge. If the path is valid, add it to the dict. Cahche prevents need + merge. If the path is valid, add it to the dict. Cache prevents need to recompute whenever delete_downstream_merge is called with a new restriction. """ @@ -211,7 +209,7 @@ def _merge_chains(self) -> Dict[str, List[dj.FreeTable]]: def _commit_merge_deletes(self, merge_join_dict, **kwargs): """Commit merge deletes. - Extraxted for use in cautious_delete and delete_downstream_merge.""" + Extracted for use in cautious_delete and delete_downstream_merge.""" for table_name, part_restr in merge_join_dict.items(): table = self._merge_tables[table_name] keys = [part.fetch(MERGE_PK, as_dict=True) for part in part_restr] diff --git a/src/spyglass/utils/logging.py b/src/spyglass/utils/logging.py index e16706f45..1771a160f 100644 --- a/src/spyglass/utils/logging.py +++ b/src/spyglass/utils/logging.py @@ -1,4 +1,5 @@ """Logging configuration based on datajoint/logging.py""" + import logging import sys diff --git a/src/spyglass/utils/nwb_helper_fn.py b/src/spyglass/utils/nwb_helper_fn.py index 6b7947b2d..a5b184635 100644 --- a/src/spyglass/utils/nwb_helper_fn.py +++ b/src/spyglass/utils/nwb_helper_fn.py @@ -383,9 +383,11 @@ def get_electrode_indices(nwb_object, electrode_ids): # that if it's there and invalid_electrode_index if not. return [ - selected_elect_ids.index(elect_id) - if elect_id in selected_elect_ids - else invalid_electrode_index + ( + selected_elect_ids.index(elect_id) + if elect_id in selected_elect_ids + else invalid_electrode_index + ) for elect_id in electrode_ids ] From 24fa6ba8447a0572572ae993e0a3a734d421d351 Mon Sep 17 00:00:00 2001 From: cbroz1 Date: Mon, 29 Jan 2024 13:21:17 -0800 Subject: [PATCH 06/10] Blackify 2 --- config/dj_config.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/config/dj_config.py b/config/dj_config.py index 55fd8a9ad..6b25a5d57 100755 --- a/config/dj_config.py +++ b/config/dj_config.py @@ -15,9 +15,7 @@ def main(*args): save_method = ( "local" if filename == "dj_local_conf.json" - else "global" - if filename is None - else "custom" + else "global" if filename is None else "custom" ) config.save_dj_config( From 170fb1ef747aebca63f96d4bd1c674c347e430a9 Mon Sep 17 00:00:00 2001 From: cbroz1 Date: Mon, 29 Jan 2024 14:10:27 -0800 Subject: [PATCH 07/10] Update changelog/docs --- CHANGELOG.md | 10 +++++++--- docs/src/misc/merge_tables.md | 13 +++++++++++-- src/spyglass/utils/dj_mixin.py | 28 ++++++++++++++-------------- 3 files changed, 32 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bbf2ca515..0ddb62a59 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,19 +8,23 @@ - Clean up following pre-commit checks. #688 - Add Mixin class to centralize `fetch_nwb` functionality. #692, #734 - Refactor restriction use in `delete_downstream_merge` #703 -- Add `cautious_delete` to Mixin class, initial implementation. 
#711, #762 +- Add `cautious_delete` to Mixin class + - Initial implementation. #711, #762 + - More robust caching of join to downstream tables. #806 - Add `deprecation_factory` to facilitate table migration. #717 - Add Spyglass logger. #730 - IntervalList: Add secondary key `pipeline` #742 - Increase pytest coverage for `common`, `lfp`, and `utils`. #743 - Update docs to reflect new notebooks. #776 - Add overview of Spyglass to docs. #779 -- LFPV1: Fix error for multiple lfp settings on same data #775 + ### Pipelines - Spike sorting: Add SpikeSorting V1 pipeline. #651 -- LFP: Minor fixes to LFPBandV1 populator and `make`. #706, #795 +- LFP: + - Minor fixes to LFPBandV1 populator and `make`. #706, #795 + - LFPV1: Fix error for multiple lfp settings on same data #775 - Linearization: - Minor fixes to LinearizedPositionV1 pipeline #695 - Rename `position_linearization` -> `linearization`. #717 diff --git a/docs/src/misc/merge_tables.md b/docs/src/misc/merge_tables.md index 981ea40f7..c11e82670 100644 --- a/docs/src/misc/merge_tables.md +++ b/docs/src/misc/merge_tables.md @@ -15,8 +15,17 @@ deleting a part entry before the master. To circumvent this, you can add `force_parts=True` to the [`delete` function](https://datajoint.com/docs/core/datajoint-python/0.14/api/datajoint/__init__/#datajoint.table.Table.delete) call, but this will leave and orphaned primary key in the master. Instead, use -`spyglass.utils.dj_merge_tables.delete_downstream_merge` to delete master/part -pairs. +`(YourTable & restriction).delete_downstream_merge()` to delete master/part +pairs. If errors persist, identify and import the offending part table and +rerun `delete_downstream_merge` with `reload_cache=True`. This process will +be faster for subsequent calls if you reassign the your table after importing. + +```python +from spyglass.common import Nwbfile +nwbfile = Nwbfile() +(nwbfile & "nwb_file_name LIKE 'Name%'").delete_downstream_merge() +``` + ## What diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index cc636664d..9889b8a40 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -391,6 +391,18 @@ def _usage_table(self): self._usage_table_cache = CautiousDelete return self._usage_table_cache + def _log_use(self, start, merge_deletes=None): + """Log use of cautious_delete.""" + self._usage_table.insert1( + dict( + duration=time() - start, + dj_user=dj.config["database.user"], + origin=self.full_table_name, + restriction=self.restriction, + merge_deletes=merge_deletes, + ) + ) + # Rename to `delete` when we're ready to use it # TODO: Intercept datajoint delete confirmation prompt for merge deletes def cautious_delete(self, force_permission: bool = False, *args, **kwargs): @@ -410,11 +422,6 @@ def cautious_delete(self, force_permission: bool = False, *args, **kwargs): Passed to datajoint.table.Table.delete. 
""" start = time() - usage_dict = dict( - dj_user=dj.config["database.user"], - origin=self.full_table_name, - restriction=self.restriction, - ) if not force_permission: self._check_delete_permission() @@ -443,19 +450,12 @@ def cautious_delete(self, force_permission: bool = False, *args, **kwargs): self._commit_merge_deletes(merge_deletes, **kwargs) else: logger.info("Delete aborted.") - self._usage_table.insert1( - dict(duration=time() - start, **usage_dict) - ) + self._log_use(start) return super().delete(*args, **kwargs) # Additional confirm here - self._usage_table.insert1( - dict( - duration=time() - start, - merge_deletes=merge_deletes, - ) - ) + self._log_use(start=start, merge_deletes=merge_deletes) def cdel(self, *args, **kwargs): """Alias for cautious_delete.""" From e706c7d0bf20181bbed1a40cc83e7e9b7652c293 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Mon, 29 Jan 2024 16:38:18 -0600 Subject: [PATCH 08/10] Update notebooks --- CHANGELOG.md | 3 +- docs/src/misc/merge_tables.md | 8 +- notebooks/01_Insert_Data.ipynb | 63 ++++----- notebooks/03_Merge_Tables.ipynb | 169 +++++++++++++++++------- notebooks/py_scripts/01_Insert_Data.py | 45 ++++--- notebooks/py_scripts/03_Merge_Tables.py | 35 +++-- notebooks/py_scripts/11_Curation.py | 2 +- src/spyglass/common/common_usage.py | 2 +- src/spyglass/utils/dj_merge_tables.py | 17 +-- src/spyglass/utils/dj_mixin.py | 14 +- 10 files changed, 220 insertions(+), 138 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0ddb62a59..27779bfd3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,11 +18,10 @@ - Update docs to reflect new notebooks. #776 - Add overview of Spyglass to docs. #779 - ### Pipelines - Spike sorting: Add SpikeSorting V1 pipeline. #651 -- LFP: +- LFP: - Minor fixes to LFPBandV1 populator and `make`. #706, #795 - LFPV1: Fix error for multiple lfp settings on same data #775 - Linearization: diff --git a/docs/src/misc/merge_tables.md b/docs/src/misc/merge_tables.md index c11e82670..1cd4b000b 100644 --- a/docs/src/misc/merge_tables.md +++ b/docs/src/misc/merge_tables.md @@ -16,17 +16,17 @@ deleting a part entry before the master. To circumvent this, you can add [`delete` function](https://datajoint.com/docs/core/datajoint-python/0.14/api/datajoint/__init__/#datajoint.table.Table.delete) call, but this will leave and orphaned primary key in the master. Instead, use `(YourTable & restriction).delete_downstream_merge()` to delete master/part -pairs. If errors persist, identify and import the offending part table and -rerun `delete_downstream_merge` with `reload_cache=True`. This process will -be faster for subsequent calls if you reassign the your table after importing. +pairs. If errors persist, identify and import the offending part table and rerun +`delete_downstream_merge` with `reload_cache=True`. This process will be faster +for subsequent calls if you reassign the your table after importing. 
```python from spyglass.common import Nwbfile + nwbfile = Nwbfile() (nwbfile & "nwb_file_name LIKE 'Name%'").delete_downstream_merge() ``` - ## What A Merge Table is fundamentally a master table with one part for each divergent diff --git a/notebooks/01_Insert_Data.ipynb b/notebooks/01_Insert_Data.ipynb index f0d89cdfa..de31ea7c8 100644 --- a/notebooks/01_Insert_Data.ipynb +++ b/notebooks/01_Insert_Data.ipynb @@ -45,8 +45,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "[2023-10-05 11:48:12,292][INFO]: Connecting root@localhost:3306\n", - "[2023-10-05 11:48:12,302][INFO]: Connected root@localhost:3306\n" + "[2024-01-29 16:24:30,933][INFO]: Connecting root@localhost:3309\n", + "[2024-01-29 16:24:30,942][INFO]: Connected root@localhost:3309\n" ] } ], @@ -719,9 +719,9 @@ "\n", "- `minirec20230622.nwb`, .3 GB: minimal recording,\n", " [Link](https://ucsf.box.com/s/k3sgql6z475oia848q1rgms4zdh4rkjn)\n", - "- `mediumnwb20230802.nwb`, 32 GB: full-featured dataset, \n", - " [Link](https://ucsf.box.com/s/2qbhxghzpttfam4b7q7j8eg0qkut0opa) \n", - "- `montague20200802.nwb`, 8 GB: full experimental recording, \n", + "- `mediumnwb20230802.nwb`, 32 GB: full-featured dataset,\n", + " [Link](https://ucsf.box.com/s/2qbhxghzpttfam4b7q7j8eg0qkut0opa)\n", + "- `montague20200802.nwb`, 8 GB: full experimental recording,\n", " [Link](https://ucsf.box.com/s/26je2eytjpqepyznwpm92020ztjuaomb)\n", "- For those in the UCSF network, these and many others on `/stelmo/nwb/raw`\n", "\n", @@ -747,7 +747,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Spyglass will create a copy with this name." + "Spyglass will create a copy with this name.\n" ] }, { @@ -1072,7 +1072,6 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\n", "`spyglass.data_import.insert_sessions` helps take the many fields of data\n", "present in an NWB file and insert them into various tables across Spyglass. If\n", "the NWB file is properly composed, this includes...\n", @@ -1082,8 +1081,8 @@ "- neural activity (extracellular recording of multiple brain areas)\n", "- etc.\n", "\n", - "_Note:_ this may take time as Spyglass creates the copy. You may see a prompt \n", - "about inserting device information." + "_Note:_ this may take time as Spyglass creates the copy. You may see a prompt\n", + "about inserting device information.\n" ] }, { @@ -2053,21 +2052,20 @@ "metadata": {}, "source": [ "`IntervalList` has an additional secondary key `pipeline` which can describe the origin of the data.\n", - "Because it is a _secondary_ key, it is not required to uniquely identify an entry. \n", + "Because it is a _secondary_ key, it is not required to uniquely identify an entry.\n", "Current values for this key from spyglass pipelines are:\n", "\n", - "| pipeline | Source|\n", - "| --- | --- |\n", - "| position | sg.common.PositionSource |\n", - "| lfp_v0 | sg.common.LFP |\n", - "| lfp_v1 | sg.lfp.v1.LFPV1 |\n", - "| lfp_band | sg.common.LFPBand,
sg.lfp.analysis.v1.LFPBandV1 |\n", - "| lfp_artifact | sg.lfp.v1.LFPArtifactDetection |\n", - "| spikesorting_artifact_v0 | sg.spikesorting.ArtifactDetection |\n", - "| spikesorting_artifact_v1 | sg.spikesorting.v1.ArtifactDetection |\n", - "| spikesorting_recording_v0 | sg.spikesorting.SpikeSortingRecording |\n", - "| spikesorting_recording_v1 | sg.spikesorting.v1.SpikeSortingRecording |\n", - "\n" + "| pipeline | Source |\n", + "| ------------------------- | --------------------------------------------------- |\n", + "| position | sg.common.PositionSource |\n", + "| lfp_v0 | sg.common.LFP |\n", + "| lfp_v1 | sg.lfp.v1.LFPV1 |\n", + "| lfp_band | sg.common.LFPBand,
sg.lfp.analysis.v1.LFPBandV1 |\n", + "| lfp_artifact | sg.lfp.v1.LFPArtifactDetection |\n", + "| spikesorting_artifact_v0 | sg.spikesorting.ArtifactDetection |\n", + "| spikesorting_artifact_v1 | sg.spikesorting.v1.ArtifactDetection |\n", + "| spikesorting_recording_v0 | sg.spikesorting.SpikeSortingRecording |\n", + "| spikesorting_recording_v1 | sg.spikesorting.v1.SpikeSortingRecording |\n" ] }, { @@ -2086,9 +2084,9 @@ "with _cascading deletes_. For example, if we delete our `Session` entry, all\n", "associated downstream entries are also deleted (e.g. `Raw`, `IntervalList`).\n", "\n", - "_Note_: The deletion process can be complicated by \n", + "_Note_: The deletion process can be complicated by\n", "[Merge Tables](https://lorenfranklab.github.io/spyglass/0.4/misc/merge_tables/)\n", - "when the entry is referenced by a part table. To demo deletion in these cases, \n", + "when the entry is referenced by a part table. To demo deletion in these cases,\n", "run the hidden code below.\n", "\n", "
\n", @@ -2113,20 +2111,23 @@ "lfp.v1.LFPSelection.insert1(lfp_key, skip_duplicates=True)\n", "lfp.v1.LFPV1().populate(lfp_key)\n", "```\n", + "\n", "
\n", "
\n", "Deleting Merge Entries\n", "\n", "```python\n", - "from spyglass.utils.dj_merge_tables import delete_downstream_merge\n", + "nwbfile = sgc.Nwbfile()\n", "\n", - "delete_downstream_merge(\n", - " sgc.Nwbfile(),\n", - " restriction={\"nwb_file_name\": nwb_copy_file_name},\n", + "(nwbfile & {\"nwb_file_name\": nwb_copy_file_name}).delete_downstream_merge(\n", " dry_run=False, # True will show Merge Table entries that would be deleted\n", - ") \n", + ")\n", "```\n", - "
" + "\n", + "Please see the [next notebook](./03_Merge_Tables.ipynb) for a more detailed\n", + "explanation.\n", + "\n", + "\n" ] }, { @@ -2659,7 +2660,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Up Next" + "## Up Next\n" ] }, { diff --git a/notebooks/03_Merge_Tables.ipynb b/notebooks/03_Merge_Tables.ipynb index 04cc6ba13..2d76867d8 100644 --- a/notebooks/03_Merge_Tables.ipynb +++ b/notebooks/03_Merge_Tables.ipynb @@ -66,8 +66,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "[2023-10-12 11:15:17,864][INFO]: Connecting root@localhost:3306\n", - "[2023-10-12 11:15:17,873][INFO]: Connected root@localhost:3306\n" + "[2024-01-29 16:15:00,903][INFO]: Connecting root@localhost:3309\n", + "[2024-01-29 16:15:00,912][INFO]: Connected root@localhost:3309\n" ] } ], @@ -328,7 +328,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "['merge_delete', 'merge_delete_parent', 'merge_fetch', 'merge_get_parent', 'merge_get_part', 'merge_html', 'merge_populate', 'merge_restrict', 'merge_view']\n" + "['merge_delete', 'merge_delete_parent', 'merge_fetch', 'merge_get_parent', 'merge_get_parent_class', 'merge_get_part', 'merge_html', 'merge_populate', 'merge_restrict', 'merge_restrict_class', 'merge_view']\n" ] } ], @@ -386,7 +386,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -415,7 +415,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -497,7 +497,7 @@ " (Total: 1)" ] }, - "execution_count": 8, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -510,7 +510,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -521,11 +521,11 @@ " 'target_interval_list_name': '01_s1',\n", " 'filter_name': 'LFP 0-400 Hz',\n", " 'filter_sampling_rate': 30000,\n", - " 'analysis_file_name': 'minirec20230622_JOV02AWW09.nwb',\n", + " 'analysis_file_name': 'minirec20230622_R5DWQ6S53S.nwb',\n", " 'interval_list_name': 'lfp_test_01_s1_valid times',\n", - " 'lfp_object_id': '340b9a0b-626b-40ca-8b48-e033be72570a',\n", + " 'lfp_object_id': 'ffb893d1-a31e-41d3-aec7-8dc8936c8898',\n", " 'lfp_sampling_rate': 1000.0,\n", - " 'lfp': filtered data pynwb.ecephys.ElectricalSeries at 0x139910624563552\n", + " 'lfp': filtered data pynwb.ecephys.ElectricalSeries at 0x129602752674544\n", " Fields:\n", " comments: no comments\n", " conversion: 1.0\n", @@ -540,7 +540,7 @@ " unit: volts}]" ] }, - "execution_count": 9, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -552,7 +552,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -567,7 +567,7 @@ " 'filter_sampling_rate': 30000}" ] }, - "execution_count": 10, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -579,7 +579,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -588,7 +588,7 @@ "True" ] }, - "execution_count": 12, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -616,7 +616,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -718,7 +718,7 @@ " (Total: 1)" ] }, - "execution_count": 14, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -730,7 +730,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 12, "metadata": {}, "outputs": [ 
{ @@ -824,9 +824,9 @@ "01_s1\n", "LFP 0-400 Hz\n", "30000\n", - "minirec20230622_JOV02AWW09.nwb\n", + "minirec20230622_R5DWQ6S53S.nwb\n", "lfp_test_01_s1_valid times\n", - "340b9a0b-626b-40ca-8b48-e033be72570a\n", + "ffb893d1-a31e-41d3-aec7-8dc8936c8898\n", "1000.0 \n", " \n", " \n", @@ -837,11 +837,11 @@ "FreeTable(`lfp_v1`.`__l_f_p_v1`)\n", "*nwb_file_name *lfp_electrode *target_interv *filter_name *filter_sampli analysis_file_ interval_list_ lfp_object_id lfp_sampling_r\n", "+------------+ +------------+ +------------+ +------------+ +------------+ +------------+ +------------+ +------------+ +------------+\n", - "minirec2023062 test 01_s1 LFP 0-400 Hz 30000 minirec2023062 lfp_test_01_s1 340b9a0b-626b- 1000.0 \n", + "minirec2023062 test 01_s1 LFP 0-400 Hz 30000 minirec2023062 lfp_test_01_s1 ffb893d1-a31e- 1000.0 \n", " (Total: 1)" ] }, - "execution_count": 15, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -861,7 +861,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -870,7 +870,7 @@ "array([1000.])" ] }, - "execution_count": 16, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -890,7 +890,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -900,7 +900,7 @@ " array(['minirec20230622_.nwb'], dtype=object)]" ] }, - "execution_count": 19, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -912,7 +912,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -926,7 +926,7 @@ " 'filter_sampling_rate': 30000}" ] }, - "execution_count": 20, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -955,8 +955,8 @@ "2. use `merge_delete_parent` to delete from the parent sources, getting rid of\n", " the entries in the source table they came from.\n", "\n", - "3. use `delete_downstream_merge` to find Merge Tables downstream and get rid\n", - " full entries, avoiding orphaned master table entries.\n", + "3. use `delete_downstream_merge` to find Merge Tables downstream of any other\n", + " table and get rid full entries, avoiding orphaned master table entries.\n", "\n", "The two latter cases can be destructive, so we include an extra layer of\n", "protection with `dry_run`. 
When true (by default), these functions return\n", @@ -965,16 +965,100 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2024-01-29 16:15:23,054][INFO]: Deleting 1 rows from `lfp_merge`.`l_f_p_output__l_f_p_v1`\n", + "[2024-01-29 16:15:23,058][INFO]: Deleting 1 rows from `lfp_merge`.`l_f_p_output`\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2024-01-29 16:15:24,953][WARNING]: Deletes cancelled\n" + ] + } + ], + "source": [ + "LFPOutput.merge_delete(nwb_file_dict) # Delete from merge table" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[FreeTable(`lfp_v1`.`__l_f_p_v1`)\n", + " *nwb_file_name *lfp_electrode *target_interv *filter_name *filter_sampli analysis_file_ interval_list_ lfp_object_id lfp_sampling_r\n", + " +------------+ +------------+ +------------+ +------------+ +------------+ +------------+ +------------+ +------------+ +------------+\n", + " minirec2023062 test 01_s1 LFP 0-400 Hz 30000 minirec2023062 lfp_test_01_s1 ffb893d1-a31e- 1000.0 \n", + " (Total: 1)]" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "LFPOutput.merge_delete_parent(restriction=nwb_file_dict, dry_run=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`delete_downstream_merge` is available from any other table in the pipeline,\n", + "but it does take some time to find the links downstream. If you're using this,\n", + "you can save time by reassigning your table to a variable, which will preserve\n", + "a copy of the previous search.\n", + "\n", + "Because the copy is stored, this function may not see additional merge tables\n", + "you've imported. 
To refresh this copy, set `reload_cache=True`\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[16:15:37][INFO] Spyglass: Building merge cache for nwbfile.\n", + "\tFound 3 downstream merge tables\n" + ] + }, + { + "data": { + "text/plain": [ + "dict_values([[*nwb_file_name *analysis_file *lfp_electrode *target_interv *filter_name *filter_sampli *merge_id nwb_file_a analysis_f analysis_file_ analysis_p interval_list_ lfp_object_id lfp_sampling_r\n", + "+------------+ +------------+ +------------+ +------------+ +------------+ +------------+ +------------+ +--------+ +--------+ +------------+ +--------+ +------------+ +------------+ +------------+\n", + "minirec2023062 minirec2023062 test 01_s1 LFP 0-400 Hz 30000 c34f98c5-7de7- =BLOB= =BLOB= =BLOB= lfp_test_01_s1 ffb893d1-a31e- 1000.0 \n", + " (Total: 1)\n", + "]])" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "LFPOutput.merge_delete(nwb_file_dict) # Delete from merge table\n", - "LFPOutput.merge_delete_parent(restriction=nwb_file_dict, dry_run=True)\n", - "delete_downstream_merge(\n", - " table=LFPV1,\n", - " restriction=nwb_file_dict,\n", + "nwbfile = sgc.Nwbfile()\n", + "\n", + "(nwbfile & nwb_file_dict).delete_downstream_merge(\n", " dry_run=True,\n", + " reload_cache=False, # if still encountering errors, try setting this to True\n", ")" ] }, @@ -982,8 +1066,8 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "To delete all merge table entries associated with an NWB file, use\n", - "`delete_downstream_merge` with the `Nwbfile` table.\n" + "This function is run automatically whin you use `cautious_delete`, which\n", + "checks team permissions before deleting.\n" ] }, { @@ -992,12 +1076,7 @@ "metadata": {}, "outputs": [], "source": [ - "delete_downstream_merge(\n", - " table=sgc.Nwbfile,\n", - " restriction={\"nwb_file_name\": nwb_copy_file_name},\n", - " dry_run=True,\n", - " recurse_level=3, # for long pipelines with many tables\n", - ")" + "(nwbfile & nwb_file_dict).cautious_delete()" ] }, { diff --git a/notebooks/py_scripts/01_Insert_Data.py b/notebooks/py_scripts/01_Insert_Data.py index 908c93491..c1fec99a9 100644 --- a/notebooks/py_scripts/01_Insert_Data.py +++ b/notebooks/py_scripts/01_Insert_Data.py @@ -128,6 +128,7 @@ # - # Spyglass will create a copy with this name. +# nwb_copy_file_name @@ -155,9 +156,9 @@ # sgc.LabMember.LabMemberInfo.insert( - [ # Full name, Google email address, DataJoint username - ["Firstname Lastname", "example1@gmail.com", "example1"], - ["Firstname2 Lastname2", "example2@gmail.com", "example2"], + [ # Full name, Google email address, DataJoint username, admin + ["Firstname Lastname", "example1@gmail.com", "example1", 0], + ["Firstname2 Lastname2", "example2@gmail.com", "example2", 0], ], skip_duplicates=True, ) @@ -187,7 +188,6 @@ # ## Inserting from NWB # -# # `spyglass.data_import.insert_sessions` helps take the many fields of data # present in an NWB file and insert them into various tables across Spyglass. If # the NWB file is properly composed, this includes... @@ -199,6 +199,7 @@ # # _Note:_ this may take time as Spyglass creates the copy. You may see a prompt # about inserting device information. +# sgi.insert_sessions(nwb_file_name) @@ -306,18 +307,17 @@ # Because it is a _secondary_ key, it is not required to uniquely identify an entry. 
# Current values for this key from spyglass pipelines are: # -# | pipeline | Source| -# | --- | --- | -# | position | sg.common.PositionSource | -# | lfp_v0 | sg.common.LFP | -# | lfp_v1 | sg.lfp.v1.LFPV1 | -# | lfp_band | sg.common.LFPBand,
sg.lfp.analysis.v1.LFPBandV1 | -# | lfp_artifact | sg.lfp.v1.LFPArtifactDetection | -# | spikesorting_artifact_v0 | sg.spikesorting.ArtifactDetection | -# | spikesorting_artifact_v1 | sg.spikesorting.v1.ArtifactDetection | -# | spikesorting_recording_v0 | sg.spikesorting.SpikeSortingRecording | -# | spikesorting_recording_v1 | sg.spikesorting.v1.SpikeSortingRecording | -# +# | pipeline | Source | +# | ------------------------- | --------------------------------------------------- | +# | position | sg.common.PositionSource | +# | lfp_v0 | sg.common.LFP | +# | lfp_v1 | sg.lfp.v1.LFPV1 | +# | lfp_band | sg.common.LFPBand,
sg.lfp.analysis.v1.LFPBandV1 | +# | lfp_artifact | sg.lfp.v1.LFPArtifactDetection | +# | spikesorting_artifact_v0 | sg.spikesorting.ArtifactDetection | +# | spikesorting_artifact_v1 | sg.spikesorting.v1.ArtifactDetection | +# | spikesorting_recording_v0 | sg.spikesorting.SpikeSortingRecording | +# | spikesorting_recording_v1 | sg.spikesorting.v1.SpikeSortingRecording | # # ## Deleting data @@ -355,20 +355,24 @@ # lfp.v1.LFPSelection.insert1(lfp_key, skip_duplicates=True) # lfp.v1.LFPV1().populate(lfp_key) # ``` +# # #
# Deleting Merge Entries # # ```python -# from spyglass.utils.dj_merge_tables import delete_downstream_merge +# nwbfile = sgc.Nwbfile() # -# delete_downstream_merge( -# sgc.Nwbfile(), -# restriction={"nwb_file_name": nwb_copy_file_name}, +# (nwbfile & {"nwb_file_name": nwb_copy_file_name}).delete_downstream_merge( # dry_run=False, # True will show Merge Table entries that would be deleted # ) # ``` +# +# Please see the [next notebook](./03_Merge_Tables.ipynb) for a more detailed +# explanation. +# #
+# session_entry = sgc.Session & {"nwb_file_name": nwb_copy_file_name} session_entry @@ -418,6 +422,7 @@ # !ls $SPYGLASS_BASE_DIR/raw # ## Up Next +# # In the [next notebook](./02_Data_Sync.ipynb), we'll explore tools for syncing. # diff --git a/notebooks/py_scripts/03_Merge_Tables.py b/notebooks/py_scripts/03_Merge_Tables.py index c4c0abb48..33b8e9a0e 100644 --- a/notebooks/py_scripts/03_Merge_Tables.py +++ b/notebooks/py_scripts/03_Merge_Tables.py @@ -192,8 +192,8 @@ # 2. use `merge_delete_parent` to delete from the parent sources, getting rid of # the entries in the source table they came from. # -# 3. use `delete_downstream_merge` to find Merge Tables downstream and get rid -# full entries, avoiding orphaned master table entries. +# 3. use `delete_downstream_merge` to find Merge Tables downstream of any other +# table and get rid full entries, avoiding orphaned master table entries. # # The two latter cases can be destructive, so we include an extra layer of # protection with `dry_run`. When true (by default), these functions return @@ -201,23 +201,32 @@ # LFPOutput.merge_delete(nwb_file_dict) # Delete from merge table + LFPOutput.merge_delete_parent(restriction=nwb_file_dict, dry_run=True) -delete_downstream_merge( - table=LFPV1, - restriction=nwb_file_dict, - dry_run=True, -) -# To delete all merge table entries associated with an NWB file, use -# `delete_downstream_merge` with the `Nwbfile` table. +# `delete_downstream_merge` is available from any other table in the pipeline, +# but it does take some time to find the links downstream. If you're using this, +# you can save time by reassigning your table to a variable, which will preserve +# a copy of the previous search. # +# Because the copy is stored, this function may not see additional merge tables +# you've imported. To refresh this copy, set `reload_cache=True` +# + +# + +nwbfile = sgc.Nwbfile() -delete_downstream_merge( - table=sgc.Nwbfile, - restriction={"nwb_file_name": nwb_copy_file_name}, +(nwbfile & nwb_file_dict).delete_downstream_merge( dry_run=True, - recurse_level=3, # for long pipelines with many tables + reload_cache=False, # if still encountering errors, try setting this to True ) +# - + +# This function is run automatically whin you use `cautious_delete`, which +# checks team permissions before deleting. +# + +(nwbfile & nwb_file_dict).cautious_delete() # ## Up Next # diff --git a/notebooks/py_scripts/11_Curation.py b/notebooks/py_scripts/11_Curation.py index 8b75a9c76..25eb698ad 100644 --- a/notebooks/py_scripts/11_Curation.py +++ b/notebooks/py_scripts/11_Curation.py @@ -5,7 +5,7 @@ # extension: .py # format_name: light # format_version: '1.5' -# jupytext_version: 1.15.2 +# jupytext_version: 1.16.0 # kernelspec: # display_name: base # language: python diff --git a/src/spyglass/common/common_usage.py b/src/spyglass/common/common_usage.py index 716649574..8b110cbc2 100644 --- a/src/spyglass/common/common_usage.py +++ b/src/spyglass/common/common_usage.py @@ -1,7 +1,7 @@ """A schema to store the usage of advanced Spyglass features. Records show usage of features such as table chains, which will be used to -determine which features are used, how often, and by whom. This will help +determine which features are used, how often, and by whom. This will help plan future development of Spyglass. 
""" diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index c0dec296f..b748267ad 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -54,16 +54,6 @@ def __init__(self): ) self._source_class_dict = {} - @property - def source_class_dict(self) -> dict: - if not self._source_class_dict: - module = getmodule(self) - self._source_class_dict = { - part_name: getattr(module, part_name) - for part_name in self.parts(camel_case=True) - } - return self._source_class_dict - def _remove_comments(self, definition): """Use regular expressions to remove comments and blank lines""" return re.sub( # First remove comments, then blank lines @@ -511,7 +501,7 @@ def merge_delete_parent( def fetch_nwb( self, - restriction: str = True, + restriction: str = None, multi_source=False, disable_warning=False, *attrs, @@ -531,10 +521,7 @@ def fetch_nwb( """ if isinstance(self, dict): raise ValueError("Try replacing Merge.method with Merge().method") - if restriction is True and self.restriction: - if not disable_warning: - _warn_on_restriction(self, restriction) - restriction = self.restriction + restriction = restriction or self.restriction or True return self.merge_restrict_class(restriction).fetch_nwb() diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 9889b8a40..11e80b025 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -1,6 +1,5 @@ -from collections.abc import Iterable from time import time -from typing import Dict, List, Union +from typing import Dict, List import datajoint as dj import networkx as nx @@ -58,6 +57,7 @@ class SpyglassMixin: _merge_table_cache = {} # Cache of merge tables downstream of self _merge_chains_cache = {} # Cache of table chains to merges _session_connection_cache = None # Cache of path from Session to self + _test_mode_cache = None # Cache of test mode setting for delete _usage_table_cache = None # Temporary inclusion for usage tracking # ------------------------------- fetch_nwb ------------------------------- @@ -106,7 +106,9 @@ def _nwb_table_tuple(self): self._nwb_table_resolved = ( AnalysisNwbfile if "-> AnalysisNwbfile" in self.definition - else Nwbfile if "-> Nwbfile" in self.definition else None + else Nwbfile + if "-> Nwbfile" in self.definition + else None ) if getattr(self, "_nwb_table_resolved", None) is None: @@ -440,8 +442,8 @@ def cautious_delete(self, force_permission: bool = False, *args, **kwargs): if merge_deletes: for table, content in merge_deletes.items(): - count, name = len(content), table.full_table_name - dj_logger.info(f"Merge: Deleting {count} rows from {name}") + count = sum([len(part) for part in content]) + dj_logger.info(f"Merge: Deleting {count} rows from {table}") if ( not self._test_mode or not safemode @@ -519,7 +521,7 @@ def __str__(self): if not self._has_link: return "No link" return ( - f"Chain: " + "Chain: " + self.parent.table_name + self._link_symbol + self.child.table_name From 895e656457d203ed8eec62eaf4f4fdb74050d9a7 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 30 Jan 2024 15:51:46 -0600 Subject: [PATCH 09/10] Overwrite . 
Mixin add cached_property decorator --- CHANGELOG.md | 1 + src/spyglass/settings.py | 2 +- src/spyglass/utils/dj_mixin.py | 198 +++++++++++++-------------------- 3 files changed, 77 insertions(+), 124 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 27779bfd3..9b638945d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ - Add `cautious_delete` to Mixin class - Initial implementation. #711, #762 - More robust caching of join to downstream tables. #806 + - Overwrite datajoint `delete` method to use `cautious_delete`. #806 - Add `deprecation_factory` to facilitate table migration. #717 - Add Spyglass logger. #730 - IntervalList: Add secondary key `pipeline` #742 diff --git a/src/spyglass/settings.py b/src/spyglass/settings.py index 68fe1e528..007ec9160 100644 --- a/src/spyglass/settings.py +++ b/src/spyglass/settings.py @@ -7,7 +7,7 @@ import yaml from pymysql.err import OperationalError -from spyglass.utils import logger +from spyglass.utils.logging import logger class SpyglassConfig: diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index b170a3831..7f8415d01 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -1,16 +1,16 @@ +from functools import cached_property from time import time from typing import Dict, List import datajoint as dj import networkx as nx -from datajoint.table import logger as dj_logger -from datajoint.user_tables import Table, TableMeta +from datajoint.logging import logger as dj_logger +from datajoint.table import Table from datajoint.utils import get_master, user_choice -from spyglass.utils.database_settings import SHARED_MODULES +from spyglass.settings import test_mode from spyglass.utils.dj_helper_fn import fetch_nwb from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK -from spyglass.utils.dj_merge_tables import Merge from spyglass.utils.logging import logger @@ -47,41 +47,32 @@ class SpyglassMixin: Alias for delete_downstream_merge. """ - _nwb_table_dict = {} # Dict mapping NWBFile table to path attribute name. # _nwb_table = None # NWBFile table class, defined at the table level - _nwb_table_resolved = None # NWBFiletable class, resolved here from above - _delete_dependencies = [] # Session, LabMember, LabTeam, delay import + # pks for delete permission check, assumed to be on field _session_pk = None # Session primary key. Mixin is ambivalent to Session pk _member_pk = None # LabMember primary key. Mixin ambivalent table structure - _merge_table_cache = {} # Cache of merge tables downstream of self - _merge_chains_cache = {} # Cache of table chains to merges - _session_connection_cache = None # Cache of path from Session to self - _test_mode_cache = None # Cache of test mode setting for delete - _usage_table_cache = None # Temporary inclusion for usage tracking # ------------------------------- fetch_nwb ------------------------------- - @property + @cached_property def _table_dict(self): """Dict mapping NWBFile table to path attribute name. Used to delay import of NWBFile tables until needed, avoiding circular imports. 
""" - if not self._nwb_table_dict: - from spyglass.common.common_nwbfile import ( # noqa F401 - AnalysisNwbfile, - Nwbfile, - ) + from spyglass.common.common_nwbfile import ( + AnalysisNwbfile, + Nwbfile, + ) # noqa F401 - self._nwb_table_dict = { - AnalysisNwbfile: "analysis_file_abs_path", - Nwbfile: "nwb_file_abs_path", - } - return self._nwb_table_dict + return { + AnalysisNwbfile: "analysis_file_abs_path", + Nwbfile: "nwb_file_abs_path", + } - @property + @cached_property def _nwb_table_tuple(self): """NWBFile table class. @@ -89,35 +80,31 @@ def _nwb_table_tuple(self): Multiple copies for different purposes. - _nwb_table may be user-set. Don't overwrite. - - _nwb_table_resolved is set here from either _nwb_table or definition. - - _nwb_table_tuple is used to cache result of _nwb_table_resolved and - return the appropriate path_attr from _table_dict above. """ - if not self._nwb_table_resolved: - from spyglass.common.common_nwbfile import ( # noqa F401 - AnalysisNwbfile, - Nwbfile, + from spyglass.common.common_nwbfile import ( + AnalysisNwbfile, + Nwbfile, + ) # noqa F401 + + if hasattr(self, "_nwb_table"): + resolved = self._nwb_table + + else: + resolved = ( + AnalysisNwbfile + if "-> AnalysisNwbfile" in self.definition + else Nwbfile if "-> Nwbfile" in self.definition else None ) - if hasattr(self, "_nwb_table"): - self._nwb_table_resolved = self._nwb_table - - if not hasattr(self, "_nwb_table"): - self._nwb_table_resolved = ( - AnalysisNwbfile - if "-> AnalysisNwbfile" in self.definition - else Nwbfile if "-> Nwbfile" in self.definition else None - ) - - if getattr(self, "_nwb_table_resolved", None) is None: - raise NotImplementedError( - f"{self.__class__.__name__} does not have a " - "(Analysis)Nwbfile foreign key or _nwb_table attribute." - ) + if not resolved: + raise NotImplementedError( + f"{self.__class__.__name__} does not have a " + "(Analysis)Nwbfile foreign key or _nwb_table attribute." + ) return ( - self._nwb_table_resolved, - self._table_dict[self._nwb_table_resolved], + resolved, + self._table_dict[resolved], ) def fetch_nwb(self, *attrs, **kwargs): @@ -130,65 +117,50 @@ def fetch_nwb(self, *attrs, **kwargs): '-> AnalysisNwbfile' in its definition can use a _nwb_table attribute to specify which table to use. """ - nwb_table, path_attr = self._nwb_table_tuple - - return fetch_nwb(self, (nwb_table, path_attr), *attrs, **kwargs) + return fetch_nwb(self, self._nwb_table_tuple, *attrs, **kwargs) # -------------------------------- delete --------------------------------- - @property + @cached_property def _delete_deps(self) -> list: """List of tables required for delete permission check. Used to delay import of tables until needed, avoiding circular imports. + Each of these tables inheits SpyglassMixin. 
""" - if not self._delete_dependencies: - from spyglass.common import LabMember, LabTeam, Session # noqa F401 - - self._delete_dependencies = [LabMember, LabTeam, Session] - self._session_pk = Session.primary_key[0] - self._member_pk = LabMember.primary_key[0] - return self._delete_dependencies + from spyglass.common import LabMember, LabTeam, Session # noqa F401 - @property - def _test_mode(self) -> bool: - """Return True if test mode is enabled.""" - if not self._test_mode_cache: - from spyglass.settings import test_mode - - self._test_mode_cache = test_mode - return self._test_mode_cache + self._session_pk = Session.primary_key[0] + self._member_pk = LabMember.primary_key[0] + return [LabMember, LabTeam, Session] - @property + @cached_property def _merge_tables(self) -> Dict[str, dj.FreeTable]: """Dict of merge tables downstream of self. Cache of items in parents of self.descendants(as_objects=True) that have a merge primary key. """ - if self._merge_table_cache: - return self._merge_table_cache - - def has_merge_pk(table): - return MERGE_PK in table.heading.names self.connection.dependencies.load() + merge_tables = {} for desc in self.descendants(as_objects=True): - if not has_merge_pk(desc): - continue - if not (master_name := get_master(desc.full_table_name)): + if MERGE_PK not in desc.heading.names or not ( + master_name := get_master(desc.full_table_name) + ): continue master = dj.FreeTable(self.connection, master_name) - if has_merge_pk(master): - self._merge_table_cache[master_name] = master + if MERGE_PK in master.heading.names: + merge_tables[master_name] = master + logger.info( f"Building merge cache for {self.table_name}.\n\t" - + f"Found {len(self._merge_table_cache)} downstream merge tables" + + f"Found {len(merge_tables)} downstream merge tables" ) - return self._merge_table_cache + return merge_tables - @property + @cached_property def _merge_chains(self) -> Dict[str, List[dj.FreeTable]]: """Dict of merge links downstream of self. @@ -197,14 +169,12 @@ def _merge_chains(self) -> Dict[str, List[dj.FreeTable]]: to recompute whenever delete_downstream_merge is called with a new restriction. """ - if self._merge_chains_cache: - return self._merge_chains_cache - + merge_chains = {} for name, merge_table in self._merge_tables.items(): chains = TableChains(self, merge_table, connection=self.connection) if len(chains): - self._merge_chains_cache[name] = chains - return self._merge_chains_cache + merge_chains[name] = chains + return merge_chains def _commit_merge_deletes(self, merge_join_dict, **kwargs): """Commit merge deletes. @@ -247,8 +217,8 @@ def delete_downstream_merge( Passed to datajoint.table.Table.delete. """ if reload_cache: - self._merge_table_cache = {} - self._merge_chains_cache = {} + del self._merge_tables + del self._merge_chains restriction = restriction or self.restriction or True @@ -315,18 +285,14 @@ def _get_exp_summary(self): return exp_missing + exp_present - @property + @cached_property def _session_connection(self) -> dj.expression.QueryExpression: """Path from Session table to self. None is not yet cached, False if no connection found. 
""" - if self._session_connection_cache is None: - connection = TableChain(parent=self._delete_deps[-1], child=self) - self._session_connection_cache = ( - connection if connection.has_link else False - ) - return self._session_connection_cache + connection = TableChain(parent=self._delete_deps[-1], child=self) + return connection if connection.has_link else False def _check_delete_permission(self) -> None: """Check user name against lab team assoc. w/ self * Session. @@ -382,14 +348,12 @@ def _check_delete_permission(self) -> None: ) logger.info(f"Queueing delete for session(s):\n{sess_summary}") - @property + @cached_property def _usage_table(self): """Temporary inclusion for usage tracking.""" - if not self._usage_table_cache: - from spyglass.common.common_usage import CautiousDelete + from spyglass.common.common_usage import CautiousDelete - self._usage_table_cache = CautiousDelete - return self._usage_table_cache + return CautiousDelete def _log_use(self, start, merge_deletes=None): """Log use of cautious_delete.""" @@ -403,9 +367,8 @@ def _log_use(self, start, merge_deletes=None): ) ) - # Rename to `delete` when we're ready to use it # TODO: Intercept datajoint delete confirmation prompt for merge deletes - def cautious_delete(self, force_permission: bool = False, *args, **kwargs): + def delete(self, force_permission: bool = False, *args, **kwargs): """Delete table rows after checking user permission. Permission is granted to users listed as admin in LabMember table or to @@ -443,7 +406,7 @@ def cautious_delete(self, force_permission: bool = False, *args, **kwargs): count = sum([len(part) for part in content]) dj_logger.info(f"Merge: Deleting {count} rows from {table}") if ( - not self._test_mode + not test_mode or not safemode or user_choice("Commit deletes?", default="no") == "yes" ): @@ -509,9 +472,6 @@ def __init__(self, parent: Table, child: Table, connection=None): self._link_symbol = " -> " self.parent = parent self.child = child - self._repr = None - self._names = None # full table names of tables in chain - self._objects = None # free tables in chain self._has_link = child.full_table_name in parent.descendants() def __str__(self): @@ -527,15 +487,12 @@ def __str__(self): def __repr__(self): """Return full representation of chain: parent -> {links} -> child.""" - if self._repr: - return self._repr - self._repr = ( + return ( "Chain: " + self._link_symbol.join([t.table_name for t in self.objects]) if self.names else "No link" ) - return self._repr def __len__(self): """Return number of tables in chain.""" @@ -550,7 +507,7 @@ def has_link(self) -> bool: """ return self._has_link - @property + @cached_property def names(self) -> List[str]: """Return list of full table names in chain. 
@@ -558,29 +515,24 @@ def names(self) -> List[str]: """ if not self._has_link: return None - if self._names: - return self._names try: - self._names = nx.shortest_path( + return nx.shortest_path( self.parent.connection.dependencies, self.parent.full_table_name, self.child.full_table_name, ) - return self._names except nx.NetworkXNoPath: self._has_link = False return None - @property + @cached_property def objects(self) -> List[dj.FreeTable]: """Return list of FreeTable objects for each table in chain.""" - if not self._objects: - self._objects = ( - [dj.FreeTable(self._connection, name) for name in self.names] - if self.names - else None - ) - return self._objects + return ( + [dj.FreeTable(self._connection, name) for name in self.names] + if self.names + else None + ) def join(self, restricton: str = None) -> dj.expression.QueryExpression: """Return join of tables in chain with restriction applied to parent.""" From ca9e9d26913ec6e733d066242a1281fa3e4bb111 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 30 Jan 2024 16:45:40 -0600 Subject: [PATCH 10/10] Cleanup docstrings and type annotations --- src/spyglass/utils/dj_chains.py | 168 +++++++++++++++++++ src/spyglass/utils/dj_mixin.py | 279 ++++++++++---------------------- 2 files changed, 250 insertions(+), 197 deletions(-) create mode 100644 src/spyglass/utils/dj_chains.py diff --git a/src/spyglass/utils/dj_chains.py b/src/spyglass/utils/dj_chains.py new file mode 100644 index 000000000..b76132551 --- /dev/null +++ b/src/spyglass/utils/dj_chains.py @@ -0,0 +1,168 @@ +from functools import cached_property +from typing import List + +import datajoint as dj +import networkx as nx +from datajoint.expression import QueryExpression +from datajoint.table import Table +from datajoint.utils import get_master + +from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK +from spyglass.utils.logging import logger + + +class TableChains: + """Class for representing chains from parent to Merge table via parts. + + Functions as a plural version of TableChain, allowing a single `join` + call across all chains from parent -> Merge table. + """ + + def __init__(self, parent, child, connection=None): + self.parent = parent + self.child = child + self.connection = connection or parent.connection + parts = child.parts(as_objects=True) + self.part_names = [part.full_table_name for part in parts] + self.chains = [TableChain(parent, part) for part in parts] + self.has_link = any([chain.has_link for chain in self.chains]) + + def __repr__(self): + return "\n".join([str(chain) for chain in self.chains]) + + def __len__(self): + return len([c for c in self.chains if c.has_link]) + + def join(self, restriction=None) -> List[QueryExpression]: + """Return list of joins for each chain in self.chains.""" + restriction = restriction or self.parent.restriction or True + joins = [] + for chain in self.chains: + if joined := chain.join(restriction): + joins.append(joined) + return joins + + +class TableChain: + """Class for representing a chain of tables. + + A chain is a sequence of tables from parent to child identified by + networkx.shortest_path. Parent -> Merge should use TableChains instead to + handle multiple paths to the respective parts of the Merge table. + + Attributes + ---------- + parent : Table + Parent or origin of chain. + child : Table + Child or destination of chain. + _connection : datajoint.Connection, optional + Connection to database used to create FreeTable objects. Defaults to + parent.connection. 
+ _link_symbol : str + Symbol used to represent the link between parent and child. Hardcoded + to " -> ". + _has_link : bool + Cached attribute to store whether parent is linked to child. False if + child is not in parent.descendants or nx.NetworkXNoPath is raised by + nx.shortest_path. + names : List[str] + List of full table names in chain. Generated by networkx.shortest_path. + objects : List[dj.FreeTable] + List of FreeTable objects for each table in chain. + + Methods + ------- + __str__() + Return string representation of chain: parent -> child. + __repr__() + Return full representation of chain: parent -> {links} -> child. + __len__() + Return number of tables in chain. + join(restriction: str = None) + Return join of tables in chain with restriction applied to parent. + """ + + def __init__(self, parent: Table, child: Table, connection=None): + self._connection = connection or parent.connection + if not self._connection.dependencies._loaded: + self._connection.dependencies.load() + + if ( # if child is a merge table + get_master(child.full_table_name) == "" + and MERGE_PK in child.heading.names + ): + logger.error("Child is a merge table. Use TableChains instead.") + + self._link_symbol = " -> " + self.parent = parent + self.child = child + self._has_link = child.full_table_name in parent.descendants() + + def __str__(self): + """Return string representation of chain: parent -> child.""" + if not self._has_link: + return "No link" + return ( + "Chain: " + + self.parent.table_name + + self._link_symbol + + self.child.table_name + ) + + def __repr__(self): + """Return full representation of chain: parent -> {links} -> child.""" + return ( + "Chain: " + + self._link_symbol.join([t.table_name for t in self.objects]) + if self.names + else "No link" + ) + + def __len__(self): + """Return number of tables in chain.""" + return len(self.names) + + @property + def has_link(self) -> bool: + """Return True if parent is linked to child. + + Cached as hidden attribute _has_link to set False if nx.NetworkXNoPath + is raised by nx.shortest_path. + """ + return self._has_link + + @cached_property + def names(self) -> List[str]: + """Return list of full table names in chain. + + Uses networkx.shortest_path. 
+        """
+        if not self._has_link:
+            return None
+        try:
+            return nx.shortest_path(
+                self.parent.connection.dependencies,
+                self.parent.full_table_name,
+                self.child.full_table_name,
+            )
+        except nx.NetworkXNoPath:
+            self._has_link = False
+            return None
+
+    @cached_property
+    def objects(self) -> List[dj.FreeTable]:
+        """Return list of FreeTable objects for each table in chain."""
+        return (
+            [dj.FreeTable(self._connection, name) for name in self.names]
+            if self.names
+            else None
+        )
+
+    def join(self, restriction: str = None) -> dj.expression.QueryExpression:
+        """Return join of tables in chain with restriction applied to parent."""
+        restriction = restriction or self.parent.restriction or True
+        join = self.objects[0] & restriction
+        for table in self.objects[1:]:
+            join = join * table
+        return join if join else None
diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py
index 7f8415d01..03f0ec08b 100644
--- a/src/spyglass/utils/dj_mixin.py
+++ b/src/spyglass/utils/dj_mixin.py
@@ -1,14 +1,14 @@
 from functools import cached_property
 from time import time
-from typing import Dict, List
+from typing import Dict, List, Union

 import datajoint as dj
-import networkx as nx
+from datajoint.expression import QueryExpression
 from datajoint.logging import logger as dj_logger
 from datajoint.table import Table
 from datajoint.utils import get_master, user_choice

-from spyglass.settings import test_mode
+from spyglass.utils.dj_chains import TableChain, TableChains
 from spyglass.utils.dj_helper_fn import fetch_nwb
 from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK
 from spyglass.utils.logging import logger
@@ -27,6 +27,14 @@ class SpyglassMixin:
         Fetch NWBFile object from relevant table. Uses either a foreign key to
         a NWBFile table (including AnalysisNwbfile) or a _nwb_table attribute to
         determine which table to use.
+    delete_downstream_merge(restriction=None, dry_run=True, reload_cache=False)
+        Delete downstream merge table entries associated with restriction.
+        Requires caching of merge tables and links, which is slow on first call.
+        `restriction` can be set to a string to restrict the delete. `dry_run`
+        can be set to False to commit the delete. `reload_cache` can be set to
+        True to reload the merge cache.
+    ddm(*args, **kwargs)
+        Alias for delete_downstream_merge.
     cautious_delete(force_permission=False, *args, **kwargs)
         Check user permissions before deleting table rows. Permission is granted
         to users listed as admin in LabMember table or to users on a team with
@@ -37,64 +45,37 @@ class SpyglassMixin:
         raised. `force_permission` can be set to True to bypass permission check.
     cdel(*args, **kwargs)
         Alias for cautious_delete.
-    delte_downstream_merge(restriction=None, dry_run=True, reload_cache=False)
-        Delete downstream merge table entries associated with restricton.
-        Requires caching of merge tables and links, which is slow on first call.
-        `restriction` can be set to a string to restrict the delete. `dry_run`
-        can be set to False to commit the delete. `reload_cache` can be set to
-        True to reload the merge cache.
-    ddm(*args, **kwargs)
-        Alias for delete_downstream_merge.
     """

     # _nwb_table = None # NWBFile table class, defined at the table level

-    # pks for delete permission check, assumed to be on field
+    # pks for delete permission check, assumed to be one field for each
     _session_pk = None  # Session primary key. Mixin is ambivalent to Session pk
     _member_pk = None  # LabMember primary key.
Mixin ambivalent table structure # ------------------------------- fetch_nwb ------------------------------- @cached_property - def _table_dict(self): - """Dict mapping NWBFile table to path attribute name. + def _nwb_table_tuple(self) -> tuple: + """NWBFile table class. - Used to delay import of NWBFile tables until needed, avoiding circular - imports. - """ + Used to determine fetch_nwb behavior. Also used in Merge.fetch_nwb. + Implemented as a cached_property to avoid circular imports.""" from spyglass.common.common_nwbfile import ( AnalysisNwbfile, Nwbfile, ) # noqa F401 - return { + table_dict = { AnalysisNwbfile: "analysis_file_abs_path", Nwbfile: "nwb_file_abs_path", } - @cached_property - def _nwb_table_tuple(self): - """NWBFile table class. - - Used to determine fetch_nwb behavior. Also used in Merge.fetch_nwb. - Multiple copies for different purposes. - - - _nwb_table may be user-set. Don't overwrite. - """ - from spyglass.common.common_nwbfile import ( - AnalysisNwbfile, - Nwbfile, - ) # noqa F401 - - if hasattr(self, "_nwb_table"): - resolved = self._nwb_table - - else: - resolved = ( - AnalysisNwbfile - if "-> AnalysisNwbfile" in self.definition - else Nwbfile if "-> Nwbfile" in self.definition else None - ) + resolved = getattr(self, "_nwb_table", None) or ( + AnalysisNwbfile + if "-> AnalysisNwbfile" in self.definition + else Nwbfile if "-> Nwbfile" in self.definition else None + ) if not resolved: raise NotImplementedError( @@ -104,42 +85,27 @@ def _nwb_table_tuple(self): return ( resolved, - self._table_dict[resolved], + table_dict[resolved], ) def fetch_nwb(self, *attrs, **kwargs): """Fetch NWBFile object from relevant table. - Implementing class must have a foreign key to Nwbfile or - AnalysisNwbfile or a _nwb_table attribute. - - A class that does not have with either '-> Nwbfile' or - '-> AnalysisNwbfile' in its definition can use a _nwb_table attribute to - specify which table to use. + Implementing class must have a foreign key reference to Nwbfile or + AnalysisNwbfile (i.e., "-> (Analysis)Nwbfile" in definition) + or a _nwb_table attribute. If both are present, the attribute takes + precedence. """ return fetch_nwb(self, self._nwb_table_tuple, *attrs, **kwargs) - # -------------------------------- delete --------------------------------- - - @cached_property - def _delete_deps(self) -> list: - """List of tables required for delete permission check. - - Used to delay import of tables until needed, avoiding circular imports. - Each of these tables inheits SpyglassMixin. - """ - from spyglass.common import LabMember, LabTeam, Session # noqa F401 - - self._session_pk = Session.primary_key[0] - self._member_pk = LabMember.primary_key[0] - return [LabMember, LabTeam, Session] + # ------------------------ delete_downstream_merge ------------------------ @cached_property def _merge_tables(self) -> Dict[str, dj.FreeTable]: - """Dict of merge tables downstream of self. + """Dict of merge tables downstream of self: {full_table_name: FreeTable}. - Cache of items in parents of self.descendants(as_objects=True) that - have a merge primary key. + Cache of items in parents of self.descendants(as_objects=True). Both + descendant and parent must have the reserved primary key 'merge_id'. """ self.connection.dependencies.load() @@ -162,12 +128,14 @@ def _merge_tables(self) -> Dict[str, dj.FreeTable]: @cached_property def _merge_chains(self) -> Dict[str, List[dj.FreeTable]]: - """Dict of merge links downstream of self. 
+        """Dict of chains to merges downstream of self.
+
+        Format: {full_table_name: TableChains}.

         For each merge table found in _merge_tables, find the path from self to
-        merge. If the path is valid, add it to the dict. Cache prevents need
-        to recompute whenever delete_downstream_merge is called with a new
-        restriction.
+        merge via merge parts. If the path is valid, add it to the dict. Cache
+        prevents need to recompute whenever delete_downstream_merge is called
+        with a new restriction. To recompute, add `reload_cache=True` to the call.
         """
         merge_chains = {}
         for name, merge_table in self._merge_tables.items():
@@ -176,9 +144,17 @@ def _merge_chains(self) -> Dict[str, List[dj.FreeTable]]:
                 merge_chains[name] = chains
         return merge_chains

-    def _commit_merge_deletes(self, merge_join_dict, **kwargs):
+    def _commit_merge_deletes(
+        self, merge_join_dict: Dict[str, List[QueryExpression]], **kwargs
+    ) -> None:
         """Commit merge deletes.

+        Parameters
+        ----------
+        merge_join_dict : Dict[str, List[QueryExpression]]
+            Dictionary of merge tables and their joins. Uses 'merge_id' primary
+            key to restrict delete.
+
         Extracted for use in cautious_delete and delete_downstream_merge."""
         for table_name, part_restr in merge_join_dict.items():
             table = self._merge_tables[table_name]
@@ -193,7 +169,7 @@ def delete_downstream_merge(
         disable_warning: bool = False,
         return_parts: bool = True,
         **kwargs,
-    ) -> List[dj.expression.QueryExpression]:
+    ) -> Union[List[QueryExpression], Dict[str, List[QueryExpression]]]:
         """Delete downstream merge table entries associated with restriction.

         Requires caching of merge tables and links, which is slow on first call.
@@ -248,7 +224,7 @@ def ddm(
         return_parts: bool = True,
         *args,
         **kwargs,
-    ):
+    ) -> Union[List[QueryExpression], Dict[str, List[QueryExpression]]]:
         """Alias for delete_downstream_merge."""
         return self.delete_downstream_merge(
             restriction=restriction,
@@ -260,6 +236,23 @@ def ddm(
             **kwargs,
         )

+    # ---------------------------- cautious_delete ----------------------------
+
+    @cached_property
+    def _delete_deps(self) -> List[Table]:
+        """List of tables required for delete permission check.
+
+        LabMember, LabTeam, and Session are required for delete permission.
+
+        Used to delay import of tables until needed, avoiding circular imports.
+        Each of these tables inherits SpyglassMixin.
+        """
+        from spyglass.common import LabMember, LabTeam, Session  # noqa F401
+
+        self._session_pk = Session.primary_key[0]
+        self._member_pk = LabMember.primary_key[0]
+        return [LabMember, LabTeam, Session]
+
     def _get_exp_summary(self):
         """Get summary of experimenters for session(s), including NULL.

@@ -286,16 +279,22 @@ def _get_exp_summary(self):
         return exp_missing + exp_present

     @cached_property
-    def _session_connection(self) -> dj.expression.QueryExpression:
-        """Path from Session table to self.
-
-        None is not yet cached, False if no connection found.
-        """
+    def _session_connection(self) -> Union[TableChain, bool]:
+        """Path from Session table to self. False if no connection found."""
         connection = TableChain(parent=self._delete_deps[-1], child=self)
         return connection if connection.has_link else False

+    @cached_property
+    def _test_mode(self) -> bool:
+        """Return True if in test mode.
+
+        Avoids circular import. Prevents prompt on delete."""
+        from spyglass.settings import test_mode
+
+        return test_mode
+
     def _check_delete_permission(self) -> None:
-        """Check user name against lab team assoc. w/ self * Session.
+        """Check user name against lab team assoc. w/ self -> Session.
Returns ------- @@ -368,7 +367,7 @@ def _log_use(self, start, merge_deletes=None): ) # TODO: Intercept datajoint delete confirmation prompt for merge deletes - def delete(self, force_permission: bool = False, *args, **kwargs): + def cautious_delete(self, force_permission: bool = False, *args, **kwargs): """Delete table rows after checking user permission. Permission is granted to users listed as admin in LabMember table or to @@ -406,7 +405,7 @@ def delete(self, force_permission: bool = False, *args, **kwargs): count = sum([len(part) for part in content]) dj_logger.info(f"Merge: Deleting {count} rows from {table}") if ( - not test_mode + not self._test_mode or not safemode or user_choice("Commit deletes?", default="no") == "yes" ): @@ -424,120 +423,6 @@ def cdel(self, *args, **kwargs): """Alias for cautious_delete.""" self.cautious_delete(*args, **kwargs) - -class TableChains: - """Class for representing chains from parent to Merge table via parts.""" - - def __init__(self, parent, child, connection=None): - self.parent = parent - self.child = child - self.connection = connection or parent.connection - parts = child.parts(as_objects=True) - self.part_names = [part.full_table_name for part in parts] - self.chains = [TableChain(parent, part) for part in parts] - self.has_link = any([chain.has_link for chain in self.chains]) - - def __repr__(self): - return "\n".join([str(chain) for chain in self.chains]) - - def __len__(self): - return len([c for c in self.chains if c.has_link]) - - def join(self, restriction=None): - restriction = restriction or self.parent.restriction or True - joins = [] - for chain in self.chains: - if joined := chain.join(restriction): - joins.append(joined) - return joins - - -class TableChain: - """Class for representing a chain of tables. - - Note: Parent -> Merge should use TableChains instead. - """ - - def __init__(self, parent: Table, child: Table, connection=None): - self._connection = connection or parent.connection - if not self._connection.dependencies._loaded: - self._connection.dependencies.load() - - if ( # if child is a merge table - get_master(child.full_table_name) == "" - and MERGE_PK in child.heading.names - ): - logger.error("Child is a merge table. Use TableChains instead.") - - self._link_symbol = " -> " - self.parent = parent - self.child = child - self._has_link = child.full_table_name in parent.descendants() - - def __str__(self): - """Return string representation of chain: parent -> child.""" - if not self._has_link: - return "No link" - return ( - "Chain: " - + self.parent.table_name - + self._link_symbol - + self.child.table_name - ) - - def __repr__(self): - """Return full representation of chain: parent -> {links} -> child.""" - return ( - "Chain: " - + self._link_symbol.join([t.table_name for t in self.objects]) - if self.names - else "No link" - ) - - def __len__(self): - """Return number of tables in chain.""" - return len(self.names) - - @property - def has_link(self) -> bool: - """Return True if parent is linked to child. - - Cached as hidden attribute _has_link to set False if nx.NetworkXNoPath - is raised by nx.shortest_path. - """ - return self._has_link - - @cached_property - def names(self) -> List[str]: - """Return list of full table names in chain. - - Uses networkx.shortest_path. 
- """ - if not self._has_link: - return None - try: - return nx.shortest_path( - self.parent.connection.dependencies, - self.parent.full_table_name, - self.child.full_table_name, - ) - except nx.NetworkXNoPath: - self._has_link = False - return None - - @cached_property - def objects(self) -> List[dj.FreeTable]: - """Return list of FreeTable objects for each table in chain.""" - return ( - [dj.FreeTable(self._connection, name) for name in self.names] - if self.names - else None - ) - - def join(self, restricton: str = None) -> dj.expression.QueryExpression: - """Return join of tables in chain with restriction applied to parent.""" - restriction = restricton or self.parent.restriction or True - join = self.objects[0] & restriction - for table in self.objects[1:]: - join = join * table - return join if join else None + def delete(self, *args, **kwargs): + """Alias for cautious_delete, overwrites datajoint.table.Table.delete""" + self.cautious_delete(*args, **kwargs)
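
Note: the following is a minimal usage sketch of the mixin API added by this patch series, not part of the diff. It assumes the patches are applied and a configured Spyglass/DataJoint database; the table `MyAnalysis` and the `nwb_file_name` value are hypothetical placeholders standing in for any table that inherits SpyglassMixin.

    # Sketch only; requires a live database connection to actually run.
    # `MyAnalysis` is a hypothetical Spyglass table inheriting SpyglassMixin.
    restricted = MyAnalysis() & {"nwb_file_name": "example20240101_.nwb"}

    # fetch_nwb() resolves the table's (Analysis)Nwbfile foreign key, or the
    # optional _nwb_table attribute, and returns the referenced NWB objects.
    nwb_objects = restricted.fetch_nwb()

    # Dry run: list downstream merge-table part entries that would be removed.
    parts = restricted.delete_downstream_merge(dry_run=True)

    # Permission-checked delete; with the final patch, `delete` wraps
    # cautious_delete and also cleans up the downstream merge entries.
    restricted.delete()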