From 001d0e690ac15d4e03e678615a8ef32dceaea53a Mon Sep 17 00:00:00 2001 From: esc Date: Tue, 28 Jan 2025 19:24:17 +0100 Subject: [PATCH 01/16] implement handling simple and/or expressions As title --- .../core/datastructures/ast_transforms.py | 119 +++++++++++++++++- numba_rvsdg/tests/test_ast_transforms.py | 65 ++++++++++ 2 files changed, 183 insertions(+), 1 deletion(-) diff --git a/numba_rvsdg/core/datastructures/ast_transforms.py b/numba_rvsdg/core/datastructures/ast_transforms.py index e87ef39..b92edde 100644 --- a/numba_rvsdg/core/datastructures/ast_transforms.py +++ b/numba_rvsdg/core/datastructures/ast_transforms.py @@ -317,9 +317,16 @@ def handle_ast_node(self, node: type[ast.AST] | ast.stmt) -> None: node, ( ast.Assign, - ast.AugAssign, ast.Expr, ast.Return, + ), + ): + node.value = self.handle_expression(node.value) + self.current_block.instructions.append(node) + elif isinstance( + node, + ( + ast.AugAssign, ast.Break, ast.Continue, ast.Pass, @@ -335,6 +342,116 @@ def handle_ast_node(self, node: type[ast.AST] | ast.stmt) -> None: else: raise NotImplementedError(f"Node type {node} not implemented") + def handle_expression(self, node: Any) -> Any: + """Recursively handle expression nodes and their subexpressions. + Returns the processed expression.""" + + if isinstance(node, ast.BoolOp): + # Handle or/and operations + return self.handle_bool_op(node) + elif isinstance(node, ast.Compare): + # Recursively handle left and right sides of comparison + node.left = self.handle_expression(node.left) + for i, comparator in enumerate(node.comparators): + node.comparators[i] = self.handle_expression(comparator) + return node + elif isinstance(node, ast.BinOp): + # Handle binary operations (+, -, *, / etc) + node.left = self.handle_expression(node.left) + node.right = self.handle_expression(node.right) + return node + elif isinstance(node, ast.Call): + # Handle function calls + for i, arg in enumerate(node.args): + node.args[i] = self.handle_expression(arg) + return node + else: + # Base case: literals, names, etc. + return node + + def handle_bool_op(self, node: ast.BoolOp) -> ast.Name: + """Handle boolean operations (and/or). + Returns an ast.Name representing the result variable.""" + + # Create a new temp variable to store the result + result_var = f"__scfg_bool_op_{self.block_index}__" + + # Generate code for first value + left = self.handle_expression(node.values[0]) + self.current_block.instructions.append( + ast.Assign( + targets=[ast.Name(id=result_var, ctx=ast.Store())], + value=left, + lineno=0, + ) + ) + + if isinstance(node.op, ast.Or): + # Create blocks for the true and false paths + false_block_index = self.block_index + merge_block_index = self.block_index + 1 + self.block_index += 2 + + # Test and jump based on first value + self.current_block.instructions.append( + ast.Name(id=result_var, ctx=ast.Load()) + ) + self.current_block.set_jump_targets( + merge_block_index, false_block_index + ) + + # False block evaluates second value + self.add_block(false_block_index) + right = self.handle_expression(node.values[1]) + self.current_block.instructions.append( + ast.Assign( + targets=[ast.Name(id=result_var, ctx=ast.Store())], + value=right, + lineno=0, + ) + ) + self.current_block.set_jump_targets(merge_block_index) + + # Create merge block + self.add_block(merge_block_index) + + elif isinstance(node.op, ast.And): + # Create blocks for the true and false paths + true_block_index = self.block_index + merge_block_index = self.block_index + 1 + self.block_index += 2 + + # Test and jump based on first value + self.current_block.instructions.append( + ast.Name(id=result_var, ctx=ast.Load()) + ) + self.current_block.set_jump_targets( + true_block_index, merge_block_index + ) + + # True block evaluates second value + self.add_block(true_block_index) + right = self.handle_expression(node.values[1]) + self.current_block.instructions.append( + ast.Assign( + targets=[ast.Name(id=result_var, ctx=ast.Store())], + value=right, + lineno=0, + ) + ) + self.current_block.set_jump_targets(merge_block_index) + + # Create merge block + self.add_block(merge_block_index) + + else: + raise NotImplementedError( + "Only 'or' operations currently supported" + ) + + # Return name node referencing our result variable + return ast.Name(id=result_var, ctx=ast.Load()) + def handle_function_def(self, node: ast.FunctionDef) -> None: """Handle a function definition.""" # Insert implicit return None, if the function isn't terminated. May diff --git a/numba_rvsdg/tests/test_ast_transforms.py b/numba_rvsdg/tests/test_ast_transforms.py index bfc7dc0..7b9ae56 100644 --- a/numba_rvsdg/tests/test_ast_transforms.py +++ b/numba_rvsdg/tests/test_ast_transforms.py @@ -1227,6 +1227,71 @@ def function(x: int) -> int: arguments=arguments, ) + def test_and(self): + def function(x: int, y: int) -> int: + return ( + x and y + ) # Returns last truthy value, or False if any others are falsy + + expected = { + "0": { + "instructions": [ + "__scfg_bool_op_1__ = x", + "__scfg_bool_op_1__", + ], + "jump_targets": ["1", "2"], + "name": "0", + }, + "1": { + "instructions": ["__scfg_bool_op_1__ = y"], + "jump_targets": ["2"], + "name": "1", + }, + "2": { + "instructions": ["return __scfg_bool_op_1__"], + "jump_targets": [], + "name": "2", + }, + } + self.compare(function, expected, arguments=[(0, 0), (0, 1), (1, 0)]) + + def test_or(self): + def function(x: int, y: int) -> int: + return ( + x or y + ) # Returns first truthy value, or last value if all falsy + + expected = { + "0": { + "instructions": [ + "__scfg_bool_op_1__ = x", + "__scfg_bool_op_1__", + ], + "jump_targets": ["2", "1"], + "name": "0", + }, + "1": { + "instructions": ["__scfg_bool_op_1__ = y"], + "jump_targets": ["2"], + "name": "1", + }, + "2": { + "instructions": ["return __scfg_bool_op_1__"], + "jump_targets": [], + "name": "2", + }, + } + + self.compare( + function, + expected, + arguments=[ + (1, 0), # First value truthy + (0, 2), # First value falsy, second truthy + (0, 0), # All values falsy - returns last value + ], + ) + class TestEntryPoints(TestCase): From eb040cd2362450b5059e698af89b37d27509b0dc Mon Sep 17 00:00:00 2001 From: esc Date: Tue, 28 Jan 2025 19:30:08 +0100 Subject: [PATCH 02/16] fix misleading error message As title --- numba_rvsdg/core/datastructures/ast_transforms.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/numba_rvsdg/core/datastructures/ast_transforms.py b/numba_rvsdg/core/datastructures/ast_transforms.py index b92edde..37009ab 100644 --- a/numba_rvsdg/core/datastructures/ast_transforms.py +++ b/numba_rvsdg/core/datastructures/ast_transforms.py @@ -445,9 +445,7 @@ def handle_bool_op(self, node: ast.BoolOp) -> ast.Name: self.add_block(merge_block_index) else: - raise NotImplementedError( - "Only 'or' operations currently supported" - ) + raise NotImplementedError("unreachable") # Return name node referencing our result variable return ast.Name(id=result_var, ctx=ast.Load()) From 029ef65d2d0b8b4424148405846a4921b9a6c748 Mon Sep 17 00:00:00 2001 From: esc Date: Thu, 30 Jan 2025 11:09:10 +0100 Subject: [PATCH 03/16] handle boolean operations with more than two operands As title --- .../core/datastructures/ast_transforms.py | 19 ++++++- numba_rvsdg/tests/test_ast_transforms.py | 52 +++++++++++++++++++ 2 files changed, 69 insertions(+), 2 deletions(-) diff --git a/numba_rvsdg/core/datastructures/ast_transforms.py b/numba_rvsdg/core/datastructures/ast_transforms.py index 37009ab..317fb72 100644 --- a/numba_rvsdg/core/datastructures/ast_transforms.py +++ b/numba_rvsdg/core/datastructures/ast_transforms.py @@ -251,6 +251,7 @@ def __init__( self.code = code self.tree = unparse_code(code) self.block_index: int = 1 # 0 is reserved for genesis block + self.bool_op_index = 0 # can have multiple of these per block self.blocks = ASTCFG() # Initialize first (genesis) block, assume it's named zero. # (This also initializes the self.current_block attribute.) @@ -348,7 +349,20 @@ def handle_expression(self, node: Any) -> Any: if isinstance(node, ast.BoolOp): # Handle or/and operations - return self.handle_bool_op(node) + if len(node.values) > 2: + # In this case the bool operation has more than two operands, + # we need to deconstruct this into a binary tree of bool + # operations and recursively deal with those. The tail_node + # contains the tail of the operand list. + tail_node = ast.BoolOp(node.op, node.values[1:]) + return self.handle_bool_op( + ast.BoolOp(node.op, [node.values[0], tail_node]) + ) + elif len(node.values) == 2: + # Base case, boolean operation has only two operands. + return self.handle_bool_op(node) + else: + raise NotImplementedError("unreachable") elif isinstance(node, ast.Compare): # Recursively handle left and right sides of comparison node.left = self.handle_expression(node.left) @@ -374,7 +388,8 @@ def handle_bool_op(self, node: ast.BoolOp) -> ast.Name: Returns an ast.Name representing the result variable.""" # Create a new temp variable to store the result - result_var = f"__scfg_bool_op_{self.block_index}__" + self.bool_op_index += 1 + result_var = f"__scfg_bool_op_{self.bool_op_index}__" # Generate code for first value left = self.handle_expression(node.values[0]) diff --git a/numba_rvsdg/tests/test_ast_transforms.py b/numba_rvsdg/tests/test_ast_transforms.py index 7b9ae56..8b5dcb3 100644 --- a/numba_rvsdg/tests/test_ast_transforms.py +++ b/numba_rvsdg/tests/test_ast_transforms.py @@ -1292,6 +1292,58 @@ def function(x: int, y: int) -> int: ], ) + def test_complex_and(self): + def function(x: int, y: int, z: int) -> bool: + return x and y and z + + expected = { + "0": { + "instructions": [ + "__scfg_bool_op_1__ = x", + "__scfg_bool_op_1__", + ], + "jump_targets": ["1", "2"], + "name": "0", + }, + "1": { + "instructions": [ + "__scfg_bool_op_2__ = y", + "__scfg_bool_op_2__", + ], + "jump_targets": ["3", "4"], + "name": "1", + }, + "2": { + "instructions": ["return __scfg_bool_op_1__"], + "jump_targets": [], + "name": "2", + }, + "3": { + "instructions": ["__scfg_bool_op_2__ = z"], + "jump_targets": ["4"], + "name": "3", + }, + "4": { + "instructions": ["__scfg_bool_op_1__ = __scfg_bool_op_2__"], + "jump_targets": ["2"], + "name": "4", + }, + } + self.compare( + function, + expected, + arguments=[ + (0, 0, 0), # All false + (0, 0, 1), # Only z true - short circuits at x + (0, 1, 0), # Only y true - short circuits at x + (0, 1, 1), # y and z true - short circuits at x + (1, 0, 0), # Only x true - short circuits at y + (1, 0, 1), # x and z true - short circuits at y + (1, 1, 0), # x and y true - fails at z + (1, 1, 1), # All true - complete evaluation + ], + ) + class TestEntryPoints(TestCase): From aed82a23c30055abb9c646a6b9607808e75b1c6a Mon Sep 17 00:00:00 2001 From: esc Date: Thu, 30 Jan 2025 11:14:39 +0100 Subject: [PATCH 04/16] rename test function to be more descriptive As title --- numba_rvsdg/tests/test_ast_transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numba_rvsdg/tests/test_ast_transforms.py b/numba_rvsdg/tests/test_ast_transforms.py index 8b5dcb3..0968b05 100644 --- a/numba_rvsdg/tests/test_ast_transforms.py +++ b/numba_rvsdg/tests/test_ast_transforms.py @@ -1292,7 +1292,7 @@ def function(x: int, y: int) -> int: ], ) - def test_complex_and(self): + def test_multi_operand_and(self): def function(x: int, y: int, z: int) -> bool: return x and y and z From 6d249157c1eb3e0b1ec17aac4bf78c19d60826c1 Mon Sep 17 00:00:00 2001 From: esc Date: Thu, 30 Jan 2025 11:40:57 +0100 Subject: [PATCH 05/16] implement handling nested andor As title --- .../core/datastructures/ast_transforms.py | 12 ++- numba_rvsdg/tests/test_ast_transforms.py | 73 +++++++++++++++++++ 2 files changed, 83 insertions(+), 2 deletions(-) diff --git a/numba_rvsdg/core/datastructures/ast_transforms.py b/numba_rvsdg/core/datastructures/ast_transforms.py index 317fb72..efde9c5 100644 --- a/numba_rvsdg/core/datastructures/ast_transforms.py +++ b/numba_rvsdg/core/datastructures/ast_transforms.py @@ -356,11 +356,19 @@ def handle_expression(self, node: Any) -> Any: # contains the tail of the operand list. tail_node = ast.BoolOp(node.op, node.values[1:]) return self.handle_bool_op( - ast.BoolOp(node.op, [node.values[0], tail_node]) + ast.BoolOp( + self.handle_expression(node.op), + [node.values[0], tail_node], + ) ) elif len(node.values) == 2: # Base case, boolean operation has only two operands. - return self.handle_bool_op(node) + return self.handle_bool_op( + ast.BoolOp( + node.op, + [self.handle_expression(v) for v in node.values], + ) + ) else: raise NotImplementedError("unreachable") elif isinstance(node, ast.Compare): diff --git a/numba_rvsdg/tests/test_ast_transforms.py b/numba_rvsdg/tests/test_ast_transforms.py index 0968b05..e4c421e 100644 --- a/numba_rvsdg/tests/test_ast_transforms.py +++ b/numba_rvsdg/tests/test_ast_transforms.py @@ -1344,6 +1344,79 @@ def function(x: int, y: int, z: int) -> bool: ], ) + def test_nested_andor(self): + def function(x: int, y: int, a: int, b: int) -> bool: + return (x and y) or (a and b) + + expected = { + "0": { + "instructions": [ + "__scfg_bool_op_1__ = x", + "__scfg_bool_op_1__", + ], + "jump_targets": ["1", "2"], + "name": "0", + }, + "1": { + "instructions": ["__scfg_bool_op_1__ = y"], + "jump_targets": ["2"], + "name": "1", + }, + "2": { + "instructions": [ + "__scfg_bool_op_2__ = a", + "__scfg_bool_op_2__", + ], + "jump_targets": ["3", "4"], + "name": "2", + }, + "3": { + "instructions": ["__scfg_bool_op_2__ = b"], + "jump_targets": ["4"], + "name": "3", + }, + "4": { + "instructions": [ + "__scfg_bool_op_3__ = __scfg_bool_op_1__", + "__scfg_bool_op_3__", + ], + "jump_targets": ["6", "5"], + "name": "4", + }, + "5": { + "instructions": ["__scfg_bool_op_3__ = __scfg_bool_op_2__"], + "jump_targets": ["6"], + "name": "5", + }, + "6": { + "instructions": ["return __scfg_bool_op_3__"], + "jump_targets": [], + "name": "6", + }, + } + self.compare( + function, + expected, + arguments=[ + (0, 0, 0, 0), # F F F F -> False + (0, 0, 0, 1), # F F F T -> False + (0, 0, 1, 0), # F F T F -> False + (0, 0, 1, 1), # F F T T -> True + (0, 1, 0, 0), # F T F F -> False + (0, 1, 0, 1), # F T F T -> False + (0, 1, 1, 0), # F T T F -> False + (0, 1, 1, 1), # F T T T -> True + (1, 0, 0, 0), # T F F F -> False + (1, 0, 0, 1), # T F F T -> False + (1, 0, 1, 0), # T F T F -> False + (1, 0, 1, 1), # T F T T -> True + (1, 1, 0, 0), # T T F F -> True + (1, 1, 0, 1), # T T F T -> True + (1, 1, 1, 0), # T T T F -> True + (1, 1, 1, 1), # T T T T -> True + ], + ) + class TestEntryPoints(TestCase): From cf5141e91178d8f6a008c1fef51d868c75fdc050 Mon Sep 17 00:00:00 2001 From: esc Date: Thu, 30 Jan 2025 13:49:11 +0100 Subject: [PATCH 06/16] handle and/or expressions within branching statements As title --- .../core/datastructures/ast_transforms.py | 4 +- numba_rvsdg/tests/test_ast_transforms.py | 121 ++++++++++++++++++ 2 files changed, 124 insertions(+), 1 deletion(-) diff --git a/numba_rvsdg/core/datastructures/ast_transforms.py b/numba_rvsdg/core/datastructures/ast_transforms.py index efde9c5..3e725fd 100644 --- a/numba_rvsdg/core/datastructures/ast_transforms.py +++ b/numba_rvsdg/core/datastructures/ast_transforms.py @@ -490,8 +490,10 @@ def handle_if(self, node: ast.If) -> None: enif_index = self.block_index + 2 self.block_index += 3 + # Desugar test expression if needed, may modify currect_block. + test_name = self.handle_expression(node.test) # Emit comparison value to current/header block. - self.current_block.instructions.append(node.test) + self.current_block.instructions.append(test_name) # Setup jump targets for current/header block. self.current_block.set_jump_targets(then_index, else_index) diff --git a/numba_rvsdg/tests/test_ast_transforms.py b/numba_rvsdg/tests/test_ast_transforms.py index e4c421e..763f01e 100644 --- a/numba_rvsdg/tests/test_ast_transforms.py +++ b/numba_rvsdg/tests/test_ast_transforms.py @@ -1417,6 +1417,127 @@ def function(x: int, y: int, a: int, b: int) -> bool: ], ) + def test_if_with_bool_ops(self): + def function(a: int, b: int) -> int: + if a and b: + return 1 + return 0 + + expected = { + "0": { + "instructions": [ + "__scfg_bool_op_1__ = a", + "__scfg_bool_op_1__", + ], + "jump_targets": ["4", "5"], + "name": "0", + }, + "1": { + "instructions": ["return 1"], + "jump_targets": [], + "name": "1", + }, + "3": { + "instructions": ["return 0"], + "jump_targets": [], + "name": "3", + }, + "4": { + "instructions": ["__scfg_bool_op_1__ = b"], + "jump_targets": ["5"], + "name": "4", + }, + "5": { + "instructions": ["__scfg_bool_op_1__"], + "jump_targets": ["1", "3"], + "name": "5", + }, + } + self.compare( + function, + expected, + empty={"2"}, + arguments=[ + (0, 0), # All false + (0, 1), # b true + (1, 0), # a true + (1, 1), # All true + ], + ) + + def test_elif_with_bool_ops(self): + def function(a: int, b: int) -> int: + if a and b: + return 1 + elif a or b: + return 2 + return 0 + + expected = { + "0": { + "instructions": [ + "__scfg_bool_op_1__ = a", + "__scfg_bool_op_1__", + ], + "jump_targets": ["4", "5"], + "name": "0", + }, + "1": { + "instructions": ["return 1"], + "jump_targets": [], + "name": "1", + }, + "10": { + "instructions": ["__scfg_bool_op_2__"], + "jump_targets": ["6", "3"], + "name": "10", + }, + "2": { + "instructions": [ + "__scfg_bool_op_2__ = a", + "__scfg_bool_op_2__", + ], + "jump_targets": ["10", "9"], + "name": "2", + }, + "3": { + "instructions": ["return 0"], + "jump_targets": [], + "name": "3", + }, + "4": { + "instructions": ["__scfg_bool_op_1__ = b"], + "jump_targets": ["5"], + "name": "4", + }, + "5": { + "instructions": ["__scfg_bool_op_1__"], + "jump_targets": ["1", "2"], + "name": "5", + }, + "6": { + "instructions": ["return 2"], + "jump_targets": [], + "name": "6", + }, + "9": { + "instructions": ["__scfg_bool_op_2__ = b"], + "jump_targets": ["10"], + "name": "9", + }, + } + self.compare( + function, + expected, + empty={"8", "7"}, + arguments=[ + (0, 0), # All false + (0, 1), # b true + (1, 0), # a true + (1, 1), # All true + ], + ) + class TestEntryPoints(TestCase): From 274d6a80c962a711a9aa382f2a255ed5263b6299 Mon Sep 17 00:00:00 2001 From: esc Date: Thu, 30 Jan 2025 14:16:11 +0100 Subject: [PATCH 07/16] implement boolean expressions in loops As title --- .../core/datastructures/ast_transforms.py | 4 +- numba_rvsdg/tests/test_ast_transforms.py | 56 +++++++++++++++++++ 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/numba_rvsdg/core/datastructures/ast_transforms.py b/numba_rvsdg/core/datastructures/ast_transforms.py index 3e725fd..65c6c3e 100644 --- a/numba_rvsdg/core/datastructures/ast_transforms.py +++ b/numba_rvsdg/core/datastructures/ast_transforms.py @@ -538,8 +538,10 @@ def handle_while(self, node: ast.While) -> None: # And create new header block self.add_block(head_index) + # Desugar test expression if needed, may modify currect_block. + test_name = self.handle_expression(node.test) # Emit comparison expression into header. - self.current_block.instructions.append(node.test) + self.current_block.instructions.append(test_name) # Set the jump targets to be the body and the else branch. self.current_block.set_jump_targets(body_index, else_index) diff --git a/numba_rvsdg/tests/test_ast_transforms.py b/numba_rvsdg/tests/test_ast_transforms.py index 763f01e..451612f 100644 --- a/numba_rvsdg/tests/test_ast_transforms.py +++ b/numba_rvsdg/tests/test_ast_transforms.py @@ -1538,6 +1538,62 @@ def function(a: int, b: int) -> int: ], ) + def test_while_bool_ops(self): + def function(x: int, y: int) -> int: + count = 0 + while x and y: + count += 1 + x -= 1 + y -= 1 + return count + + expected = { + "0": { + "instructions": ["count = 0"], + "jump_targets": ["1"], + "name": "0", + }, + "1": { + "instructions": [ + "__scfg_bool_op_1__ = x", + "__scfg_bool_op_1__", + ], + "jump_targets": ["5", "6"], + "name": "1", + }, + "2": { + "instructions": ["count += 1", "x -= 1", "y -= 1"], + "jump_targets": ["1"], + "name": "2", + }, + "3": { + "instructions": ["return count"], + "jump_targets": [], + "name": "3", + }, + "5": { + "instructions": ["__scfg_bool_op_1__ = y"], + "jump_targets": ["6"], + "name": "5", + }, + "6": { + "instructions": ["__scfg_bool_op_1__"], + "jump_targets": ["2", "3"], + "name": "6", + }, + } + self.compare( + function, + expected, + empty={"4"}, + arguments=[ + (0, 0), # Both false, loop doesn't run + (2, 1), # y becomes false first, runs once + (1, 2), # x becomes false first, runs once + (2, 2), # Both true, runs twice + ], + ) + class TestEntryPoints(TestCase): From 90c98789ac49d744a72e458dee1f420c3f8bb149 Mon Sep 17 00:00:00 2001 From: esc Date: Thu, 30 Jan 2025 16:18:45 +0100 Subject: [PATCH 08/16] beautify As title --- .../core/datastructures/ast_transforms.py | 39 +++++++++++-------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/numba_rvsdg/core/datastructures/ast_transforms.py b/numba_rvsdg/core/datastructures/ast_transforms.py index 65c6c3e..b42fdf5 100644 --- a/numba_rvsdg/core/datastructures/ast_transforms.py +++ b/numba_rvsdg/core/datastructures/ast_transforms.py @@ -317,17 +317,18 @@ def handle_ast_node(self, node: type[ast.AST] | ast.stmt) -> None: elif isinstance( node, ( + ast.AugAssign, ast.Assign, ast.Expr, ast.Return, ), ): + # Node has an expression, must handle it. node.value = self.handle_expression(node.value) self.current_block.instructions.append(node) elif isinstance( node, ( - ast.AugAssign, ast.Break, ast.Continue, ast.Pass, @@ -348,7 +349,7 @@ def handle_expression(self, node: Any) -> Any: Returns the processed expression.""" if isinstance(node, ast.BoolOp): - # Handle or/and operations + # Handle or/and operations. if len(node.values) > 2: # In this case the bool operation has more than two operands, # we need to deconstruct this into a binary tree of bool @@ -372,18 +373,19 @@ def handle_expression(self, node: Any) -> Any: else: raise NotImplementedError("unreachable") elif isinstance(node, ast.Compare): - # Recursively handle left and right sides of comparison + # Recursively handle left and right sides of comparison. node.left = self.handle_expression(node.left) for i, comparator in enumerate(node.comparators): node.comparators[i] = self.handle_expression(comparator) return node elif isinstance(node, ast.BinOp): - # Handle binary operations (+, -, *, / etc) - node.left = self.handle_expression(node.left) - node.right = self.handle_expression(node.right) + # Handle binary operations (+, -, *, / etc). + node.left, node.right = self.handle_expression( + node.left + ), self.handle_expression(node.right) return node elif isinstance(node, ast.Call): - # Handle function calls + # Handle function calls. for i, arg in enumerate(node.args): node.args[i] = self.handle_expression(arg) return node @@ -395,11 +397,12 @@ def handle_bool_op(self, node: ast.BoolOp) -> ast.Name: """Handle boolean operations (and/or). Returns an ast.Name representing the result variable.""" - # Create a new temp variable to store the result + # Create a new temp variable to store the result. self.bool_op_index += 1 result_var = f"__scfg_bool_op_{self.bool_op_index}__" - # Generate code for first value + # Create an assignment to bin temp variable from above to the left most + # value in the expression. left = self.handle_expression(node.values[0]) self.current_block.instructions.append( ast.Assign( @@ -409,13 +412,14 @@ def handle_bool_op(self, node: ast.BoolOp) -> ast.Name: ) ) + # Handle the or operator. if isinstance(node.op, ast.Or): - # Create blocks for the true and false paths + # Create blocks for the true and false paths. false_block_index = self.block_index merge_block_index = self.block_index + 1 self.block_index += 2 - # Test and jump based on first value + # Test and jump based on first value. self.current_block.instructions.append( ast.Name(id=result_var, ctx=ast.Load()) ) @@ -423,7 +427,7 @@ def handle_bool_op(self, node: ast.BoolOp) -> ast.Name: merge_block_index, false_block_index ) - # False block evaluates second value + # False block evaluates second value. self.add_block(false_block_index) right = self.handle_expression(node.values[1]) self.current_block.instructions.append( @@ -438,13 +442,14 @@ def handle_bool_op(self, node: ast.BoolOp) -> ast.Name: # Create merge block self.add_block(merge_block_index) + # Handle the and operator. elif isinstance(node.op, ast.And): - # Create blocks for the true and false paths + # Create blocks for the true and false paths. true_block_index = self.block_index merge_block_index = self.block_index + 1 self.block_index += 2 - # Test and jump based on first value + # Test and jump based on first value. self.current_block.instructions.append( ast.Name(id=result_var, ctx=ast.Load()) ) @@ -452,7 +457,7 @@ def handle_bool_op(self, node: ast.BoolOp) -> ast.Name: true_block_index, merge_block_index ) - # True block evaluates second value + # True block evaluates second value. self.add_block(true_block_index) right = self.handle_expression(node.values[1]) self.current_block.instructions.append( @@ -464,13 +469,13 @@ def handle_bool_op(self, node: ast.BoolOp) -> ast.Name: ) self.current_block.set_jump_targets(merge_block_index) - # Create merge block + # Create merge block. self.add_block(merge_block_index) else: raise NotImplementedError("unreachable") - # Return name node referencing our result variable + # Return name node referencing our result variable. return ast.Name(id=result_var, ctx=ast.Load()) def handle_function_def(self, node: ast.FunctionDef) -> None: From b5bb5a9a4843738c73382d59c5b79f71e153e16b Mon Sep 17 00:00:00 2001 From: esc Date: Thu, 30 Jan 2025 16:22:54 +0100 Subject: [PATCH 09/16] beautify As title --- numba_rvsdg/tests/test_ast_transforms.py | 34 +++++++++++------------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/numba_rvsdg/tests/test_ast_transforms.py b/numba_rvsdg/tests/test_ast_transforms.py index 451612f..f1ac6c8 100644 --- a/numba_rvsdg/tests/test_ast_transforms.py +++ b/numba_rvsdg/tests/test_ast_transforms.py @@ -1229,9 +1229,8 @@ def function(x: int) -> int: def test_and(self): def function(x: int, y: int) -> int: - return ( - x and y - ) # Returns last truthy value, or False if any others are falsy + # Returns last truthy value, or False if any others are falsy. + return x and y expected = { "0": { @@ -1257,9 +1256,8 @@ def function(x: int, y: int) -> int: def test_or(self): def function(x: int, y: int) -> int: - return ( - x or y - ) # Returns first truthy value, or last value if all falsy + # Returns first truthy value, or last value if all falsy. + return x or y expected = { "0": { @@ -1286,9 +1284,9 @@ def function(x: int, y: int) -> int: function, expected, arguments=[ - (1, 0), # First value truthy - (0, 2), # First value falsy, second truthy - (0, 0), # All values falsy - returns last value + (1, 0), # First value truthy. + (0, 2), # First value falsy, second truthy. + (0, 0), # All values falsy - returns last value. ], ) @@ -1333,14 +1331,14 @@ def function(x: int, y: int, z: int) -> bool: function, expected, arguments=[ - (0, 0, 0), # All false - (0, 0, 1), # Only z true - short circuits at x - (0, 1, 0), # Only y true - short circuits at x - (0, 1, 1), # y and z true - short circuits at x - (1, 0, 0), # Only x true - short circuits at y - (1, 0, 1), # x and z true - short circuits at y - (1, 1, 0), # x and y true - fails at z - (1, 1, 1), # All true - complete evaluation + (0, 0, 0), # All false. + (0, 0, 1), # Only z true - short circuits at x. + (0, 1, 0), # Only y true - short circuits at x. + (0, 1, 1), # y and z true - short circuits at x. + (1, 0, 0), # Only x true - short circuits at y. + (1, 0, 1), # x and z true - short circuits at y. + (1, 1, 0), # x and y true - fails at z. + (1, 1, 1), # All true - complete evaluation. ], ) @@ -1538,7 +1536,7 @@ def function(a: int, b: int) -> int: ], ) - def test_while_bool_ops(self): + def test_while_with_bool_ops(self): def function(x: int, y: int) -> int: count = 0 while x and y: From ba4e132c8ab187ddbb31ef7924345f9748eddbae Mon Sep 17 00:00:00 2001 From: esc Date: Thu, 30 Jan 2025 16:30:52 +0100 Subject: [PATCH 10/16] test assignment to a boolean operation As title --- numba_rvsdg/tests/test_ast_transforms.py | 39 ++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/numba_rvsdg/tests/test_ast_transforms.py b/numba_rvsdg/tests/test_ast_transforms.py index f1ac6c8..d6c96e8 100644 --- a/numba_rvsdg/tests/test_ast_transforms.py +++ b/numba_rvsdg/tests/test_ast_transforms.py @@ -1592,6 +1592,45 @@ def function(x: int, y: int) -> int: ], ) + def test_bool_op_assignment(self): + def function(x: int, y: int) -> bool: + result = x and y + return result + + expected = { + "0": { + "instructions": [ + "__scfg_bool_op_1__ = x", + "__scfg_bool_op_1__", + ], + "jump_targets": ["1", "2"], + "name": "0", + }, + "1": { + "instructions": ["__scfg_bool_op_1__ = y"], + "jump_targets": ["2"], + "name": "1", + }, + "2": { + "instructions": [ + "result = __scfg_bool_op_1__", + "return result", + ], + "jump_targets": [], + "name": "2", + }, + } + self.compare( + function, + expected, + arguments=[ + (0, 0), # Both false + (0, 1), # x false, y true + (1, 0), # x true, y false + (1, 1), # Both true + ], + ) + class TestEntryPoints(TestCase): From 524ae68746c54f4db6653f72742e8cf588fb8cdf Mon Sep 17 00:00:00 2001 From: esc Date: Thu, 30 Jan 2025 17:27:13 +0100 Subject: [PATCH 11/16] test base expressions too As title --- numba_rvsdg/tests/test_ast_transforms.py | 32 ++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/numba_rvsdg/tests/test_ast_transforms.py b/numba_rvsdg/tests/test_ast_transforms.py index d6c96e8..1a49071 100644 --- a/numba_rvsdg/tests/test_ast_transforms.py +++ b/numba_rvsdg/tests/test_ast_transforms.py @@ -1631,6 +1631,38 @@ def function(x: int, y: int) -> bool: ], ) + def test_bool_base_expression(self): + def function(x: int) -> int: + result = [] + x or [result.append(i) for i in [0, 0, 0]] + return len(result) + + expected = { + "0": { + "instructions": [ + "result = []", + "__scfg_bool_op_1__ = x", + "__scfg_bool_op_1__", + ], + "jump_targets": ["2", "1"], + "name": "0", + }, + "1": { + "instructions": [ + "__scfg_bool_op_1__ = " + "[result.append(i) for i in [0, 0, 0]]" + ], + "jump_targets": ["2"], + "name": "1", + }, + "2": { + "instructions": ["__scfg_bool_op_1__", "return len(result)"], + "jump_targets": [], + "name": "2", + }, + } + self.compare(function, expected, arguments=[(0,), (1,)]) + class TestEntryPoints(TestCase): From 9b9c78b05f44a8dc27237bc759d093d1de5094b9 Mon Sep 17 00:00:00 2001 From: esc Date: Fri, 31 Jan 2025 09:41:01 +0100 Subject: [PATCH 12/16] code cleanup and using better idioms As title --- numba_rvsdg/core/datastructures/ast_transforms.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/numba_rvsdg/core/datastructures/ast_transforms.py b/numba_rvsdg/core/datastructures/ast_transforms.py index b42fdf5..71e33fc 100644 --- a/numba_rvsdg/core/datastructures/ast_transforms.py +++ b/numba_rvsdg/core/datastructures/ast_transforms.py @@ -375,19 +375,18 @@ def handle_expression(self, node: Any) -> Any: elif isinstance(node, ast.Compare): # Recursively handle left and right sides of comparison. node.left = self.handle_expression(node.left) - for i, comparator in enumerate(node.comparators): - node.comparators[i] = self.handle_expression(comparator) + node.comparators = [ + self.handle_expression(c) for c in node.comparators + ] return node elif isinstance(node, ast.BinOp): # Handle binary operations (+, -, *, / etc). - node.left, node.right = self.handle_expression( - node.left - ), self.handle_expression(node.right) + node.left = self.handle_expression(node.left) + node.right = self.handle_expression(node.right) return node elif isinstance(node, ast.Call): # Handle function calls. - for i, arg in enumerate(node.args): - node.args[i] = self.handle_expression(arg) + node.args = [self.handle_expression(a) for a in node.args] return node else: # Base case: literals, names, etc. From e639f3904c3d777d23192e9e23345a06e87cd9c8 Mon Sep 17 00:00:00 2001 From: esc Date: Mon, 3 Feb 2025 15:48:29 +0100 Subject: [PATCH 13/16] test expressions in function calls too As title --- numba_rvsdg/tests/test_ast_transforms.py | 36 ++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/numba_rvsdg/tests/test_ast_transforms.py b/numba_rvsdg/tests/test_ast_transforms.py index 1a49071..dcfec40 100644 --- a/numba_rvsdg/tests/test_ast_transforms.py +++ b/numba_rvsdg/tests/test_ast_transforms.py @@ -1663,6 +1663,42 @@ def function(x: int) -> int: } self.compare(function, expected, arguments=[(0,), (1,)]) + def test_expression_in_function_call(self): + def function(a: int, b: int) -> int: + return int(a or b) + + expected = { + "0": { + "instructions": [ + "__scfg_bool_op_1__ = a", + "__scfg_bool_op_1__", + ], + "jump_targets": ["2", "1"], + "name": "0", + }, + "1": { + "instructions": ["__scfg_bool_op_1__ = b"], + "jump_targets": ["2"], + "name": "1", + }, + "2": { + "instructions": ["return int(__scfg_bool_op_1__)"], + "jump_targets": [], + "name": "2", + }, + } + + self.compare( + function, + expected, + arguments=[ + (0, 0), + (0, 1), + (1, 0), + (1, 1), + ], + ) + class TestEntryPoints(TestCase): From 3bec548f68d85cebdbe30a878fadc85f6d2a3def Mon Sep 17 00:00:00 2001 From: esc Date: Mon, 3 Feb 2025 15:50:04 +0100 Subject: [PATCH 14/16] adding empty lines for coherence As title --- numba_rvsdg/tests/test_ast_transforms.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/numba_rvsdg/tests/test_ast_transforms.py b/numba_rvsdg/tests/test_ast_transforms.py index dcfec40..3305b1e 100644 --- a/numba_rvsdg/tests/test_ast_transforms.py +++ b/numba_rvsdg/tests/test_ast_transforms.py @@ -1220,6 +1220,7 @@ def function(x: int) -> int: } empty = {"7", "10", "12", "13", "16", "17", "20", "23", "26"} arguments = [(1,), (2,), (3,), (4,), (5,), (6,), (7,)] + self.compare( function, expected, @@ -1252,6 +1253,7 @@ def function(x: int, y: int) -> int: "name": "2", }, } + self.compare(function, expected, arguments=[(0, 0), (0, 1), (1, 0)]) def test_or(self): @@ -1327,6 +1329,7 @@ def function(x: int, y: int, z: int) -> bool: "name": "4", }, } + self.compare( function, expected, @@ -1392,6 +1395,7 @@ def function(x: int, y: int, a: int, b: int) -> bool: "name": "6", }, } + self.compare( function, expected, @@ -1451,6 +1455,7 @@ def function(a: int, b: int) -> int: "name": "5", }, } + self.compare( function, expected, @@ -1524,6 +1529,7 @@ def function(a: int, b: int) -> int: "name": "9", }, } + self.compare( function, expected, @@ -1580,6 +1586,7 @@ def function(x: int, y: int) -> int: "name": "6", }, } + self.compare( function, expected, @@ -1620,6 +1627,7 @@ def function(x: int, y: int) -> bool: "name": "2", }, } + self.compare( function, expected, @@ -1661,6 +1669,7 @@ def function(x: int) -> int: "name": "2", }, } + self.compare(function, expected, arguments=[(0,), (1,)]) def test_expression_in_function_call(self): From d232c7291ff986634c065989c97f2ea6a49311c4 Mon Sep 17 00:00:00 2001 From: esc Date: Mon, 3 Feb 2025 15:52:32 +0100 Subject: [PATCH 15/16] rename test variables for coherence As title --- numba_rvsdg/tests/test_ast_transforms.py | 38 ++++++++++++------------ 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/numba_rvsdg/tests/test_ast_transforms.py b/numba_rvsdg/tests/test_ast_transforms.py index 3305b1e..cf44b78 100644 --- a/numba_rvsdg/tests/test_ast_transforms.py +++ b/numba_rvsdg/tests/test_ast_transforms.py @@ -1420,15 +1420,15 @@ def function(x: int, y: int, a: int, b: int) -> bool: ) def test_if_with_bool_ops(self): - def function(a: int, b: int) -> int: - if a and b: + def function(x: int, y: int) -> int: + if x and y: return 1 return 0 expected = { "0": { "instructions": [ - "__scfg_bool_op_1__ = a", + "__scfg_bool_op_1__ = x", "__scfg_bool_op_1__", ], "jump_targets": ["4", "5"], @@ -1445,7 +1445,7 @@ def function(a: int, b: int) -> int: "name": "3", }, "4": { - "instructions": ["__scfg_bool_op_1__ = b"], + "instructions": ["__scfg_bool_op_1__ = y"], "jump_targets": ["5"], "name": "4", }, @@ -1462,24 +1462,24 @@ def function(a: int, b: int) -> int: empty={"2"}, arguments=[ (0, 0), # All false - (0, 1), # b true - (1, 0), # a true + (0, 1), # y true + (1, 0), # x true (1, 1), # All true ], ) def test_elif_with_bool_ops(self): - def function(a: int, b: int) -> int: - if a and b: + def function(x: int, y: int) -> int: + if x and y: return 1 - elif a or b: + elif x or y: return 2 return 0 expected = { "0": { "instructions": [ - "__scfg_bool_op_1__ = a", + "__scfg_bool_op_1__ = x", "__scfg_bool_op_1__", ], "jump_targets": ["4", "5"], @@ -1497,7 +1497,7 @@ def function(a: int, b: int) -> int: }, "2": { "instructions": [ - "__scfg_bool_op_2__ = a", + "__scfg_bool_op_2__ = x", "__scfg_bool_op_2__", ], "jump_targets": ["10", "9"], @@ -1509,7 +1509,7 @@ def function(a: int, b: int) -> int: "name": "3", }, "4": { - "instructions": ["__scfg_bool_op_1__ = b"], + "instructions": ["__scfg_bool_op_1__ = y"], "jump_targets": ["5"], "name": "4", }, @@ -1524,7 +1524,7 @@ def function(a: int, b: int) -> int: "name": "6", }, "9": { - "instructions": ["__scfg_bool_op_2__ = b"], + "instructions": ["__scfg_bool_op_2__ = y"], "jump_targets": ["10"], "name": "9", }, @@ -1536,8 +1536,8 @@ def function(a: int, b: int) -> int: empty={"8", "7"}, arguments=[ (0, 0), # All false - (0, 1), # b true - (1, 0), # a true + (0, 1), # y true + (1, 0), # x true (1, 1), # All true ], ) @@ -1673,20 +1673,20 @@ def function(x: int) -> int: self.compare(function, expected, arguments=[(0,), (1,)]) def test_expression_in_function_call(self): - def function(a: int, b: int) -> int: - return int(a or b) + def function(x: int, y: int) -> int: + return int(x or y) expected = { "0": { "instructions": [ - "__scfg_bool_op_1__ = a", + "__scfg_bool_op_1__ = x", "__scfg_bool_op_1__", ], "jump_targets": ["2", "1"], "name": "0", }, "1": { - "instructions": ["__scfg_bool_op_1__ = b"], + "instructions": ["__scfg_bool_op_1__ = y"], "jump_targets": ["2"], "name": "1", }, From f2109c5f3765a598b5c3c8f2552bad549353c95a Mon Sep 17 00:00:00 2001 From: Emergency Self-Construct Date: Mon, 3 Feb 2025 23:45:13 +0100 Subject: [PATCH 16/16] Fix spelling mistakes introduced by human Co-authored-by: Siu Kwan Lam <1929845+sklam@users.noreply.github.com> --- numba_rvsdg/core/datastructures/ast_transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/numba_rvsdg/core/datastructures/ast_transforms.py b/numba_rvsdg/core/datastructures/ast_transforms.py index 71e33fc..e208f29 100644 --- a/numba_rvsdg/core/datastructures/ast_transforms.py +++ b/numba_rvsdg/core/datastructures/ast_transforms.py @@ -494,7 +494,7 @@ def handle_if(self, node: ast.If) -> None: enif_index = self.block_index + 2 self.block_index += 3 - # Desugar test expression if needed, may modify currect_block. + # Desugar test expression if needed, may modify current_block. test_name = self.handle_expression(node.test) # Emit comparison value to current/header block. self.current_block.instructions.append(test_name) @@ -542,7 +542,7 @@ def handle_while(self, node: ast.While) -> None: # And create new header block self.add_block(head_index) - # Desugar test expression if needed, may modify currect_block. + # Desugar test expression if needed, may modify current_block. test_name = self.handle_expression(node.test) # Emit comparison expression into header. self.current_block.instructions.append(test_name)