Skip to content

Commit

Permalink
Fix failing chains pytests (#867)
Browse files Browse the repository at this point in the history
* Fix failing chains pytests

* Update docstrings, changelog
  • Loading branch information
CBroz1 authored Mar 18, 2024
1 parent 10bf4ac commit ade48ea
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 42 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Change Log

## [0.5.2] (Unreleased)

- Refactor `TableChain` to include `_searched` attribute. #867

## [0.5.1] (March 7, 2024)

### Infrastructure
Expand Down
74 changes: 40 additions & 34 deletions src/spyglass/utils/dj_chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,13 @@ class TableChain:
_link_symbol : str
Symbol used to represent the link between parent and child. Hardcoded
to " -> ".
_has_link : bool
has_link : bool
Cached attribute to store whether parent is linked to child. False if
child is not in parent.descendants or nx.NetworkXNoPath is raised by
nx.shortest_path.
_has_directed_link : bool
True if directed graph is used to find path. False if undirected graph.
link_type : str
'directed' or 'undirected' based on whether path is found with directed
or undirected graph. None if no path is found.
graph : nx.DiGraph
Directed graph of parent's dependencies from datajoint.connection.
names : List[str]
Expand Down Expand Up @@ -175,18 +176,19 @@ def __init__(self, parent: Table, child: Table, connection=None):
self._link_symbol = " -> "
self.parent = parent
self.child = child
self._has_link = True
self._has_directed_link = None
self.link_type = None
self._searched = False

if child.full_table_name not in self.graph.nodes:
logger.warning(
"Can't find item in graph. Try importing: "
+ f"{child.full_table_name}"
)
self._searched = True

def __str__(self):
"""Return string representation of chain: parent -> child."""
if not self._has_link:
if not self.has_link:
return "No link"
return (
to_camel_case(self.parent.table_name)
Expand All @@ -196,19 +198,22 @@ def __str__(self):

def __repr__(self):
"""Return full representation of chain: parent -> {links} -> child."""
return (
"Chain: "
+ self._link_symbol.join([t.table_name for t in self.objects])
if self.names
else "No link"
if not self.has_link:
return "No link"
return "Chain: " + self._link_symbol.join(
[t.table_name for t in self.objects]
)

def __len__(self):
"""Return number of tables in chain."""
if not self.has_link:
return 0
return len(self.names)

def __getitem__(self, index: Union[int, str]) -> dj.FreeTable:
"""Return FreeTable object at index."""
if not self.has_link:
return None
if isinstance(index, str):
for i, name in enumerate(self.names):
if index in name:
Expand All @@ -219,10 +224,12 @@ def __getitem__(self, index: Union[int, str]) -> dj.FreeTable:
def has_link(self) -> bool:
"""Return True if parent is linked to child.
Cached as hidden attribute _has_link to set False if nx.NetworkXNoPath
is raised by nx.shortest_path.
If not searched, search for path. If searched and no link is found,
return False. If searched and link is found, return True.
"""
return self._has_link
if not self._searched:
_ = self.path
return self.link_type is not None

def pk_link(self, src, trg, data) -> float:
"""Return 1 if data["primary"] else float("inf").
Expand All @@ -242,7 +249,7 @@ def find_path(self, directed=True) -> OrderedDict:
If True, use directed graph. If False, use undirected graph.
Defaults to True. Undirected permits paths to traverse from merge
part-parent -> merge part -> merge table. Undirected excludes
PERIPHERAL_TABLES likne interval_list, nwbfile, etc.
PERIPHERAL_TABLES like interval_list, nwbfile, etc.
Returns
-------
Expand All @@ -265,6 +272,9 @@ def find_path(self, directed=True) -> OrderedDict:
path = nx.shortest_path(self.graph, source, target)
except nx.NetworkXNoPath:
return None
except nx.NodeNotFound:
self._searched = True
return None

ret = OrderedDict()
prev_table = None
Expand All @@ -283,48 +293,44 @@ def find_path(self, directed=True) -> OrderedDict:
@cached_property
def path(self) -> OrderedDict:
"""Return list of full table names in chain."""
if not self._has_link:
if self._searched and not self.has_link:
return None

link = None
if link := self.find_path(directed=True):
self._has_directed_link = True
self.link_type = "directed"
elif link := self.find_path(directed=False):
self._has_directed_link = False
self.link_type = "undirected"
self._searched = True

if link:
return link

self._has_link = False
return None
return link

@cached_property
def names(self) -> List[str]:
"""Return list of full table names in chain."""
if self._has_link:
return list(self.path.keys())
return None
if not self.has_link:
return None
return list(self.path.keys())

@cached_property
def objects(self) -> List[dj.FreeTable]:
"""Return list of FreeTable objects for each table in chain.
Unused. Preserved for future debugging.
"""
if self._has_link:
return [v["free_table"] for v in self.path.values()]
return None
if not self.has_link:
return None
return [v["free_table"] for v in self.path.values()]

@cached_property
def attr_maps(self) -> List[dict]:
"""Return list of attribute maps for each table in chain.
Unused. Preserved for future debugging.
"""
#
if self._has_link:
return [v["attr_map"] for v in self.path.values()]
return None
if not self.has_link:
return None
return [v["attr_map"] for v in self.path.values()]

def join(
self, restriction: str = None, reverse_order: bool = False
Expand All @@ -339,7 +345,7 @@ def join(
reverse_order : bool, optional
If True, join tables in reverse order. Defaults to False.
"""
if not self._has_link:
if not self.has_link:
return None

restriction = restriction or self.parent.restriction or True
Expand Down
19 changes: 11 additions & 8 deletions tests/utils/test_chains.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from datajoint.utils import to_camel_case


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -31,16 +32,13 @@ def test_invalid_chain(Nwbfile, pos_merge_tables, TableChain):
def test_chain_str(chain):
"""Test that the str of a TableChain object is as expected."""
chain = chain
str_got = str(chain)
str_exp = (
chain.parent.table_name + chain._link_symbol + chain.child.table_name
)
assert str_got == str_exp, "Unexpected str of TableChain object."
parent = to_camel_case(chain.parent.table_name)
child = to_camel_case(chain.child.table_name)

str_got = str(chain)
str_exp = parent + chain._link_symbol + child

def test_chain_str_no_link(no_link_chain):
"""Test that the str of a TableChain object with no link is as expected."""
assert str(no_link_chain) == "No link", "Unexpected str of no link chain."
assert str_got == str_exp, "Unexpected str of TableChain object."


def test_chain_repr(chain):
Expand All @@ -66,3 +64,8 @@ def test_chain_getitem(chain):

def test_nolink_join(no_link_chain):
assert no_link_chain.join() is None, "Unexpected join of no link chain."


def test_chain_str_no_link(no_link_chain):
"""Test that the str of a TableChain object with no link is as expected."""
assert str(no_link_chain) == "No link", "Unexpected str of no link chain."

0 comments on commit ade48ea

Please sign in to comment.