Skip to content

Commit

Permalink
Fix graph processing by visiting each paths
Browse files Browse the repository at this point in the history
  • Loading branch information
bra-fsn committed Jun 6, 2022
1 parent 1b2371b commit 5d0a211
Showing 1 changed file with 59 additions and 22 deletions.
81 changes: 59 additions & 22 deletions pandag/pandag.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from pandag.nodes import Node, Output
from pandag import plot, graphml
import more_itertools
import itertools
import logging


class FakeDiGraph(nx.DiGraph):
Expand Down Expand Up @@ -127,6 +129,14 @@ def create_graph(self, sub, parent=None, local_dict=None, global_dict=None):
local_dict=local_dict,
global_dict=global_dict)

def start_nodes(self):
"""Return nodes which don't have incoming edges."""
return [node for node in self.G.nodes if self.G.in_degree(node) == 0]

def end_nodes(self):
"""Return nodes which don't have outgoing edges."""
return [node for node in self.G.nodes if self.G.out_degree(node) == 0]

def eval(self, df):
"""Evaluate a Pandas DataFrame with the graph.
Expand All @@ -139,29 +149,56 @@ def eval(self, df):
"""
# generate a unique column name
node_col = f'{self.uuid}_curr_node'
for src_node_id, dst_node_id in nx.edge_dfs(self.G):
if node_col not in df.columns:
# this is the first node we're visiting, add its id to the
# `node_col` column as the current node ID for each rows
df[node_col] = src_node_id

# initialize the path column with the first node id as string, so
# we can later append new node IDs with a vectorized operation
if self.path_column is not None:
df[self.path_column] = str(src_node_id)
edge_data = self.G.get_edge_data(src_node_id, dst_node_id)
src_node = self.get_node(src_node_id)
dst_node = self.get_node(dst_node_id)
if isinstance(dst_node, Output):
dst_node.update(df, df[node_col] == src_node_id)
flt = (df[node_col] == src_node_id) & (src_node.eval(df, edge_data))
# while not recommended, handle multiple start nodes (even multiple
# DAGs) in a general way, so we detect all start and end nodes
start_nodes_visited = set()
start_nodes = set(self.start_nodes())
end_nodes = set(self.end_nodes())
if len(start_nodes) > 1:
logging.warning(f"The DAG has {len(start_nodes)}, output might be non-deterministic!")

# loop through all start->end permutations
for start, end in itertools.product(start_nodes, end_nodes):
if start in end_nodes or end in start_nodes:
# exclude backwards permutations (end -> start)
continue
if self.path_column:
# store the path which touched these rows
df.loc[flt, self.path_column] = df[self.path_column] + f',{dst_node_id}'
# update the matching rows' curr_node column to the next node,
# we'll use this to select the source rows for running the
# edge pointing from this node to the next
df.loc[flt, node_col] = dst_node_id
# for the first run, we initialize the column with a string,
# after that we're just appending
if start_nodes_visited:
if start not in start_nodes_visited:
df[self.path_column] = df[self.path_column] + f',{start}'
else:
# initialize the path column with the first node id as string, so
# we can later append new node IDs with a vectorized operation
df[self.path_column] = str(start)
# (re)set the current node ID for each new start nodes
# if we've already visited this start node, leave the field alone,
# so we can make progress on sub-frames which are at a given node
# ID from a different path
if start not in start_nodes_visited:
df[node_col] = start
for path in nx.all_simple_edge_paths(self.G, start, end):
# this represents an edge between two nodes (src -> dst) in the path
for (src_node_id, dst_node_id) in path:
edge_data = self.G.get_edge_data(src_node_id, dst_node_id)
src_node = self.get_node(src_node_id)

flt = (df[node_col] == src_node_id)
if isinstance(src_node, Output):
src_node.update(df, flt)
else:
flt &= (src_node.eval(df, edge_data))
# update the matching rows' curr_node column to the next node,
# we'll use this to select the source rows for running the
# edge pointing from this node to the next
df.loc[flt, node_col] = dst_node_id
if self.path_column:
# store the path which touched these rows
df.loc[flt, self.path_column] = df[self.path_column] + f',{dst_node_id}'

# record that we've already visited this start node
start_nodes_visited.add(start)

# remove the temporary current node column
df = df.drop(node_col, axis=1)
Expand Down

0 comments on commit 5d0a211

Please sign in to comment.