diff --git a/CHANGELOG.md b/CHANGELOG.md index ea49e3d4e..6b1cf8ebc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Change Log +## [0.5.2] (Unreleased) + +- Refactor `TableChain` to include `_searched` attribute. #867 + ## [0.5.1] (March 7, 2024) ### Infrastructure diff --git a/src/spyglass/utils/dj_chains.py b/src/spyglass/utils/dj_chains.py index 4e05763fc..76ffeb107 100644 --- a/src/spyglass/utils/dj_chains.py +++ b/src/spyglass/utils/dj_chains.py @@ -123,12 +123,13 @@ class TableChain: _link_symbol : str Symbol used to represent the link between parent and child. Hardcoded to " -> ". - _has_link : bool + has_link : bool Cached attribute to store whether parent is linked to child. False if child is not in parent.descendants or nx.NetworkXNoPath is raised by nx.shortest_path. - _has_directed_link : bool - True if directed graph is used to find path. False if undirected graph. + link_type : str + 'directed' or 'undirected' based on whether path is found with directed + or undirected graph. None if no path is found. graph : nx.DiGraph Directed graph of parent's dependencies from datajoint.connection. names : List[str] @@ -175,18 +176,19 @@ def __init__(self, parent: Table, child: Table, connection=None): self._link_symbol = " -> " self.parent = parent self.child = child - self._has_link = True - self._has_directed_link = None + self.link_type = None + self._searched = False if child.full_table_name not in self.graph.nodes: logger.warning( "Can't find item in graph. Try importing: " + f"{child.full_table_name}" ) + self._searched = True def __str__(self): """Return string representation of chain: parent -> child.""" - if not self._has_link: + if not self.has_link: return "No link" return ( to_camel_case(self.parent.table_name) @@ -196,19 +198,22 @@ def __str__(self): def __repr__(self): """Return full representation of chain: parent -> {links} -> child.""" - return ( - "Chain: " - + self._link_symbol.join([t.table_name for t in self.objects]) - if self.names - else "No link" + if not self.has_link: + return "No link" + return "Chain: " + self._link_symbol.join( + [t.table_name for t in self.objects] ) def __len__(self): """Return number of tables in chain.""" + if not self.has_link: + return 0 return len(self.names) def __getitem__(self, index: Union[int, str]) -> dj.FreeTable: """Return FreeTable object at index.""" + if not self.has_link: + return None if isinstance(index, str): for i, name in enumerate(self.names): if index in name: @@ -219,10 +224,12 @@ def __getitem__(self, index: Union[int, str]) -> dj.FreeTable: def has_link(self) -> bool: """Return True if parent is linked to child. - Cached as hidden attribute _has_link to set False if nx.NetworkXNoPath - is raised by nx.shortest_path. + If not searched, search for path. If searched and no link is found, + return False. If searched and link is found, return True. """ - return self._has_link + if not self._searched: + _ = self.path + return self.link_type is not None def pk_link(self, src, trg, data) -> float: """Return 1 if data["primary"] else float("inf"). @@ -242,7 +249,7 @@ def find_path(self, directed=True) -> OrderedDict: If True, use directed graph. If False, use undirected graph. Defaults to True. Undirected permits paths to traverse from merge part-parent -> merge part -> merge table. Undirected excludes - PERIPHERAL_TABLES likne interval_list, nwbfile, etc. + PERIPHERAL_TABLES like interval_list, nwbfile, etc. Returns ------- @@ -265,6 +272,9 @@ def find_path(self, directed=True) -> OrderedDict: path = nx.shortest_path(self.graph, source, target) except nx.NetworkXNoPath: return None + except nx.NodeNotFound: + self._searched = True + return None ret = OrderedDict() prev_table = None @@ -283,27 +293,24 @@ def find_path(self, directed=True) -> OrderedDict: @cached_property def path(self) -> OrderedDict: """Return list of full table names in chain.""" - if not self._has_link: + if self._searched and not self.has_link: return None link = None if link := self.find_path(directed=True): - self._has_directed_link = True + self.link_type = "directed" elif link := self.find_path(directed=False): - self._has_directed_link = False + self.link_type = "undirected" + self._searched = True - if link: - return link - - self._has_link = False - return None + return link @cached_property def names(self) -> List[str]: """Return list of full table names in chain.""" - if self._has_link: - return list(self.path.keys()) - return None + if not self.has_link: + return None + return list(self.path.keys()) @cached_property def objects(self) -> List[dj.FreeTable]: @@ -311,9 +318,9 @@ def objects(self) -> List[dj.FreeTable]: Unused. Preserved for future debugging. """ - if self._has_link: - return [v["free_table"] for v in self.path.values()] - return None + if not self.has_link: + return None + return [v["free_table"] for v in self.path.values()] @cached_property def attr_maps(self) -> List[dict]: @@ -321,10 +328,9 @@ def attr_maps(self) -> List[dict]: Unused. Preserved for future debugging. """ - # - if self._has_link: - return [v["attr_map"] for v in self.path.values()] - return None + if not self.has_link: + return None + return [v["attr_map"] for v in self.path.values()] def join( self, restriction: str = None, reverse_order: bool = False @@ -339,7 +345,7 @@ def join( reverse_order : bool, optional If True, join tables in reverse order. Defaults to False. """ - if not self._has_link: + if not self.has_link: return None restriction = restriction or self.parent.restriction or True diff --git a/tests/utils/test_chains.py b/tests/utils/test_chains.py index bc88e7007..7ba4b1fa2 100644 --- a/tests/utils/test_chains.py +++ b/tests/utils/test_chains.py @@ -1,4 +1,5 @@ import pytest +from datajoint.utils import to_camel_case @pytest.fixture(scope="session") @@ -31,16 +32,13 @@ def test_invalid_chain(Nwbfile, pos_merge_tables, TableChain): def test_chain_str(chain): """Test that the str of a TableChain object is as expected.""" chain = chain - str_got = str(chain) - str_exp = ( - chain.parent.table_name + chain._link_symbol + chain.child.table_name - ) - assert str_got == str_exp, "Unexpected str of TableChain object." + parent = to_camel_case(chain.parent.table_name) + child = to_camel_case(chain.child.table_name) + str_got = str(chain) + str_exp = parent + chain._link_symbol + child -def test_chain_str_no_link(no_link_chain): - """Test that the str of a TableChain object with no link is as expected.""" - assert str(no_link_chain) == "No link", "Unexpected str of no link chain." + assert str_got == str_exp, "Unexpected str of TableChain object." def test_chain_repr(chain): @@ -66,3 +64,8 @@ def test_chain_getitem(chain): def test_nolink_join(no_link_chain): assert no_link_chain.join() is None, "Unexpected join of no link chain." + + +def test_chain_str_no_link(no_link_chain): + """Test that the str of a TableChain object with no link is as expected.""" + assert str(no_link_chain) == "No link", "Unexpected str of no link chain."