diff --git a/pandag/pandag.py b/pandag/pandag.py index 269a630..577a5d5 100644 --- a/pandag/pandag.py +++ b/pandag/pandag.py @@ -5,6 +5,8 @@ from pandag.nodes import Node, Output from pandag import plot, graphml import more_itertools +import itertools +import logging class FakeDiGraph(nx.DiGraph): @@ -127,6 +129,14 @@ def create_graph(self, sub, parent=None, local_dict=None, global_dict=None): local_dict=local_dict, global_dict=global_dict) + def start_nodes(self): + """Return nodes which don't have incoming edges.""" + return [node for node in self.G.nodes if self.G.in_degree(node) == 0] + + def end_nodes(self): + """Return nodes which don't have outgoing edges.""" + return [node for node in self.G.nodes if self.G.out_degree(node) == 0] + def eval(self, df): """Evaluate a Pandas DataFrame with the graph. @@ -139,29 +149,56 @@ def eval(self, df): """ # generate a unique column name node_col = f'{self.uuid}_curr_node' - for src_node_id, dst_node_id in nx.edge_dfs(self.G): - if node_col not in df.columns: - # this is the first node we're visiting, add its id to the - # `node_col` column as the current node ID for each rows - df[node_col] = src_node_id - - # initialize the path column with the first node id as string, so - # we can later append new node IDs with a vectorized operation - if self.path_column is not None: - df[self.path_column] = str(src_node_id) - edge_data = self.G.get_edge_data(src_node_id, dst_node_id) - src_node = self.get_node(src_node_id) - dst_node = self.get_node(dst_node_id) - if isinstance(dst_node, Output): - dst_node.update(df, df[node_col] == src_node_id) - flt = (df[node_col] == src_node_id) & (src_node.eval(df, edge_data)) + # while not recommended, handle multiple start nodes (even multiple + # DAGs) in a general way, so we detect all start and end nodes + start_nodes_visited = set() + start_nodes = set(self.start_nodes()) + end_nodes = set(self.end_nodes()) + if len(start_nodes) > 1: + logging.warning(f"The DAG has {len(start_nodes)}, output might be non-deterministic!") + + # loop through all start->end permutations + for start, end in itertools.product(start_nodes, end_nodes): + if start in end_nodes or end in start_nodes: + # exclude backwards permutations (end -> start) + continue if self.path_column: - # store the path which touched these rows - df.loc[flt, self.path_column] = df[self.path_column] + f',{dst_node_id}' - # update the matching rows' curr_node column to the next node, - # we'll use this to select the source rows for running the - # edge pointing from this node to the next - df.loc[flt, node_col] = dst_node_id + # for the first run, we initialize the column with a string, + # after that we're just appending + if start_nodes_visited: + if start not in start_nodes_visited: + df[self.path_column] = df[self.path_column] + f',{start}' + else: + # initialize the path column with the first node id as string, so + # we can later append new node IDs with a vectorized operation + df[self.path_column] = str(start) + # (re)set the current node ID for each new start nodes + # if we've already visited this start node, leave the field alone, + # so we can make progress on sub-frames which are at a given node + # ID from a different path + if start not in start_nodes_visited: + df[node_col] = start + for path in nx.all_simple_edge_paths(self.G, start, end): + # this represents an edge between two nodes (src -> dst) in the path + for (src_node_id, dst_node_id) in path: + edge_data = self.G.get_edge_data(src_node_id, dst_node_id) + src_node = self.get_node(src_node_id) + + flt = (df[node_col] == src_node_id) + if isinstance(src_node, Output): + src_node.update(df, flt) + else: + flt &= (src_node.eval(df, edge_data)) + # update the matching rows' curr_node column to the next node, + # we'll use this to select the source rows for running the + # edge pointing from this node to the next + df.loc[flt, node_col] = dst_node_id + if self.path_column: + # store the path which touched these rows + df.loc[flt, self.path_column] = df[self.path_column] + f',{dst_node_id}' + + # record that we've already visited this start node + start_nodes_visited.add(start) # remove the temporary current node column df = df.drop(node_col, axis=1)