Skip to content

Commit

Permalink
fix topological sort
Browse files Browse the repository at this point in the history
  • Loading branch information
dimitri-yatsenko committed Sep 15, 2024
1 parent a470d66 commit 4be8e39
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 7 deletions.
64 changes: 61 additions & 3 deletions datajoint/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,65 @@
import networkx as nx
import itertools
import re
from collections import defaultdict
from .errors import DataJointError


def topo_sort(graph):
"""
topological sort of a dependency graph that keeps part tables together with their masters
:return: list of table names in topological order
"""
graph = nx.DiGraph(graph) # make a copy

# collapse alias nodes
alias_nodes = [node for node in graph if node.isdigit()]
for node in alias_nodes:
direct_edge = (
next(x for x in graph.in_edges(node))[0],
next(x for x in graph.out_edges(node))[1],
)
graph.add_edge(*direct_edge)
graph.remove_nodes_from(alias_nodes)

# Add parts' dependencies to their masters' dependencies
# to ensure correct topological ordering of the masters.
part_pattern = re.compile(r"(?P<master>`\w+`.`#?\w+)__\w+`")
for part in graph:
# print part tables and their master
match = part_pattern.match(part)
if match:
master = match["master"] + "`"
for edge in graph.in_edges(part):
if edge[0] != master:
graph.add_edge(edge[0], master)

sorted_nodes = list(nx.algorithms.topological_sort(graph))

# bring parts up to their masters
pos = len(sorted_nodes)
while pos > 0:
pos -= 1
part = sorted_nodes[pos]
match = part_pattern.match(part)
if match:
master = match["master"] + "`"
print(part, master)
try:
j = sorted_nodes.index(master)
except ValueError:
# master not found
continue
if pos > j + 1:
print(pos, j)
# move the part to its master
del sorted_nodes[pos]
sorted_nodes.insert(j + 1, part)
pos += 1

return sorted_nodes


class Dependencies(nx.DiGraph):
"""
The graph of dependencies (foreign keys) between loaded tables.
Expand Down Expand Up @@ -107,8 +163,8 @@ def load(self, force=True):
self._loaded = True

def topo_sort(self):
""":return: list of nodes in lexcigraphical topological order"""
return list(nx.algorithms.dag.lexicographical_topological_sort(self))
""":return: list of tables names in topological order"""
return topo_sort(self)

def parents(self, table_name, primary=None):
"""
Expand Down Expand Up @@ -146,7 +202,9 @@ def descendants(self, full_table_name):
:return: all dependent tables sorted in topological order. Self is included.
"""
self.load(force=False)
nodes = self.subgraph(nx.algorithms.dag.descendants(self, full_table_name)).copy()
nodes = self.subgraph(
nx.algorithms.dag.descendants(self, full_table_name)
).copy()
return [full_table_name] + nodes.topo_sort()

def ancestors(self, full_table_name):
Expand Down
18 changes: 14 additions & 4 deletions datajoint/diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import inspect
from .table import Table
from .dependencies import topo_sort
from .user_tables import Manual, Imported, Computed, Lookup, Part
from .errors import DataJointError
from .table import lookup_class_name
Expand Down Expand Up @@ -38,6 +39,7 @@ class _AliasNode:


def _get_tier(table_name):
"""given the table name, return"""
if not table_name.startswith("`"):
return _AliasNode
else:
Expand Down Expand Up @@ -70,19 +72,22 @@ def __init__(self, *args, **kwargs):

class Diagram(nx.DiGraph):
"""
Entity relationship diagram.
Schema diagram showing tables and foreign keys between in the form of a directed
acyclic graph (DAG). The diagram is derived from the connection.dependencies object.
Usage:
>>> diag = Diagram(source)
source can be a base table object, a base table class, a schema, or a module that has a schema.
source can be a table object, a table class, a schema, or a module that has a schema.
>>> diag.draw()
draws the diagram using pyplot
diag1 + diag2 - combines the two diagrams.
diag1 - diag2 - differente between diagrams
diag1 * diag2 - intersction of diagrams
diag + n - expands n levels of successors
diag - n - expands n levels of predecessors
Thus dj.Diagram(schema.Table)+1-1 defines the diagram of immediate ancestors and descendants of schema.Table
Expand All @@ -91,7 +96,8 @@ class Diagram(nx.DiGraph):
Only those tables that are loaded in the connection object are displayed
"""

def __init__(self, source, context=None):
def __init__(self, source=None, context=None):

if isinstance(source, Diagram):
# copy constructor
self.nodes_to_show = set(source.nodes_to_show)
Expand Down Expand Up @@ -152,7 +158,7 @@ def from_sequence(cls, sequence):

def add_parts(self):
"""
Adds to the diagram the part tables of tables already included in the diagram
Adds to the diagram the part tables of all master tables already in the diagram
:return:
"""

Expand Down Expand Up @@ -244,6 +250,10 @@ def __mul__(self, arg):
self.nodes_to_show.intersection_update(arg.nodes_to_show)
return self

def topo_sort(self):
"""return nodes in lexicographical topological order"""
return topo_sort(self)

def _make_graph(self):
"""
Make the self.graph - a graph object ready for drawing
Expand Down

0 comments on commit 4be8e39

Please sign in to comment.