Skip to content

Commit

Permalink
Fixes for creating nested SDFGs with structs and inlining them (#1888)
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad authored Jan 22, 2025
1 parent 855fc27 commit 1a56352
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 17 deletions.
16 changes: 9 additions & 7 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3744,18 +3744,20 @@ def _is_inputnode(self, sdfg: SDFG, name: str):
for state in sdfg.states():
visited_state_data = set()
for node in state.nodes():
if isinstance(node, nodes.AccessNode) and node.data == name:
visited_state_data.add(node.data)
if (node.data not in visited_data and state.in_degree(node) == 0):
return True
if isinstance(node, nodes.AccessNode):
if node.data == name or ('.' in node.data and node.data.split('.')[0] == name):
visited_state_data.add(node.data)
if (node.data not in visited_data and state.in_degree(node) == 0):
return True
visited_data = visited_data.union(visited_state_data)

def _is_outputnode(self, sdfg: SDFG, name: str):
for state in sdfg.states():
for node in state.nodes():
if isinstance(node, nodes.AccessNode) and node.data == name:
if state.in_degree(node) > 0:
return True
if isinstance(node, nodes.AccessNode):
if node.data == name or ('.' in node.data and node.data.split('.')[0] == name):
if state.in_degree(node) > 0:
return True

def _get_sdfg(self, value: Any, args: Tuple[Any], kwargs: Dict[str, Any]) -> SDFG:
if isinstance(value, SDFG): # Already an SDFG
Expand Down
50 changes: 40 additions & 10 deletions dace/transformation/interstate/sdfg_nesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,11 @@ def apply(self, state: SDFGState, sdfg: SDFG):
pass
for node in nstate.sink_nodes():
if (isinstance(node, nodes.AccessNode) and node.data not in transients and node.data not in reshapes):
new_outgoing_edges[node] = outputs[node.data]
sink_accesses.add(node)
try:
new_outgoing_edges[node] = outputs[node.data]
sink_accesses.add(node)
except KeyError:
pass

# All constants (and associated transients) become constants of the parent
for cstname, (csttype, cstval) in nsdfg.constants_prop.items():
Expand Down Expand Up @@ -427,11 +430,26 @@ def apply(self, state: SDFGState, sdfg: SDFG):

orig_data: Dict[Union[nodes.AccessNode, MultiConnectorEdge], str] = {}
for node in nstate.nodes():
if isinstance(node, nodes.AccessNode) and node.data in repldict:
orig_data[node] = node.data
node.data = repldict[node.data]
if isinstance(node, nodes.AccessNode):
if '.' in node.data:
parts = node.data.split('.')
root_container = parts[0]
if root_container in repldict:
orig_data[node] = node.data
full_data = [repldict[root_container]] + parts[1:]
node.data = '.'.join(full_data)
elif node.data in repldict:
orig_data[node] = node.data
node.data = repldict[node.data]
for edge in nstate.edges():
if edge.data.data in repldict:
if edge.data.data is not None and '.' in edge.data.data:
parts = edge.data.data.split('.')
root_container = parts[0]
if root_container in repldict:
orig_data[edge] = edge.data.data
full_data = [repldict[root_container]] + parts[1:]
edge.data.data = '.'.join(full_data)
elif edge.data.data in repldict:
orig_data[edge] = edge.data.data
edge.data.data = repldict[edge.data.data]

Expand Down Expand Up @@ -557,27 +575,39 @@ def apply(self, state: SDFGState, sdfg: SDFG):
for edge in removed_in_edges:
# Find first access node that refers to this edge
try:
node = next(n for n in order if n.data == edge.data.data)
node = next(n for n in order
if n.data == edge.data.data or ('.' in n.data and n.data.split('.')[0] == edge.data.data))
except StopIteration:
continue
# raise NameError(f'Access node with data "{edge.data.data}" not found in'
# f' nested SDFG "{nsdfg.name}" while inlining '
# '(reconnecting inputs)')
state.add_edge(edge.src, edge.src_conn, node, edge.dst_conn, edge.data)
if node.data != edge.data.data:
anode = state.add_access(edge.data.data)
state.add_edge(edge.src, edge.src_conn, anode, edge.dst_conn, edge.data)
state.add_edge(anode, None, node, None, Memlet())
else:
state.add_edge(edge.src, edge.src_conn, node, edge.dst_conn, edge.data)
# Fission state if necessary
cc = utils.weakly_connected_component(state, node)
if not any(n in cc for n in subgraph.nodes()):
helpers.state_fission(cc)
for edge in removed_out_edges:
# Find last access node that refers to this edge
try:
node = next(n for n in reversed(order) if n.data == edge.data.data)
node = next(n for n in reversed(order)
if n.data == edge.data.data or ('.' in n.data and n.data.split('.')[0] == edge.data.data))
except StopIteration:
continue
# raise NameError(f'Access node with data "{edge.data.data}" not found in'
# f' nested SDFG "{nsdfg.name}" while inlining '
# '(reconnecting outputs)')
state.add_edge(node, edge.src_conn, edge.dst, edge.dst_conn, edge.data)
if node.data != edge.data.data:
anode = state.add_access(edge.data.data)
state.add_edge(node, None, anode, None, Memlet())
state.add_edge(anode, edge.src_conn, edge.dst, edge.dst_conn, edge.data)
else:
state.add_edge(node, edge.src_conn, edge.dst, edge.dst_conn, edge.data)
# Fission state if necessary
cc = utils.weakly_connected_component(state, node)
if not any(n in cc for n in subgraph.nodes()):
Expand Down

0 comments on commit 1a56352

Please sign in to comment.