From 3509351c005db6f1894f2ef6ef8b16819ba6e8a6 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 21 Jan 2025 15:39:03 +0100 Subject: [PATCH 1/2] Fix missing replacement of loop variable --- dace/sdfg/state.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 30640306cd..47c180aff6 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2620,7 +2620,7 @@ def get_meta_read_memlets(self) -> List[mm.Memlet]: """ return [] - def replace_meta_accesses(self, replacements: dict) -> None: + def replace_meta_accesses(self, replacements: Dict[str, str]) -> None: """ Replace accesses to specific data containers in reads or writes performed by the control flow region itself in meta accesses, such as in condition checks for conditional blocks or in loop conditions for loops, etc. @@ -3331,6 +3331,8 @@ def get_meta_read_memlets(self) -> List[mm.Memlet]: return read_memlets def replace_meta_accesses(self, replacements): + if self.loop_variable in replacements: + self.loop_variable = replacements[self.loop_variable] replace_in_codeblock(self.loop_condition, replacements) if self.init_statement: replace_in_codeblock(self.init_statement, replacements) From 4fe276f509a87d751da9dabc6c9ef09e0d5d0c3b Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 21 Jan 2025 16:09:50 +0100 Subject: [PATCH 2/2] Fix prune symbols removing used symbols --- dace/transformation/passes/prune_symbols.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dace/transformation/passes/prune_symbols.py b/dace/transformation/passes/prune_symbols.py index a01d903a1d..c501a769ff 100644 --- a/dace/transformation/passes/prune_symbols.py +++ b/dace/transformation/passes/prune_symbols.py @@ -111,6 +111,8 @@ def used_symbols(self, sdfg: SDFG) -> Set[str]: if node.code_exit.language != dtypes.Language.Python: result |= symbolic.symbols_in_code(node.code_exit.as_string, sdfg.symbols.keys(), node.ignored_symbols) + else: + result |= block.used_symbols(all_symbols=True, with_contents=False) for e in sdfg.all_interstate_edges(): result |= e.data.free_symbols