diff --git a/numba_rvsdg/core/datastructures/ast_transforms.py b/numba_rvsdg/core/datastructures/ast_transforms.py index e87ef39..e208f29 100644 --- a/numba_rvsdg/core/datastructures/ast_transforms.py +++ b/numba_rvsdg/core/datastructures/ast_transforms.py @@ -251,6 +251,7 @@ def __init__( self.code = code self.tree = unparse_code(code) self.block_index: int = 1 # 0 is reserved for genesis block + self.bool_op_index = 0 # can have multiple of these per block self.blocks = ASTCFG() # Initialize first (genesis) block, assume it's named zero. # (This also initializes the self.current_block attribute.) @@ -316,10 +317,18 @@ def handle_ast_node(self, node: type[ast.AST] | ast.stmt) -> None: elif isinstance( node, ( - ast.Assign, ast.AugAssign, + ast.Assign, ast.Expr, ast.Return, + ), + ): + # Node has an expression, must handle it. + node.value = self.handle_expression(node.value) + self.current_block.instructions.append(node) + elif isinstance( + node, + ( ast.Break, ast.Continue, ast.Pass, @@ -335,6 +344,139 @@ def handle_ast_node(self, node: type[ast.AST] | ast.stmt) -> None: else: raise NotImplementedError(f"Node type {node} not implemented") + def handle_expression(self, node: Any) -> Any: + """Recursively handle expression nodes and their subexpressions. + Returns the processed expression.""" + + if isinstance(node, ast.BoolOp): + # Handle or/and operations. + if len(node.values) > 2: + # In this case the bool operation has more than two operands, + # we need to deconstruct this into a binary tree of bool + # operations and recursively deal with those. The tail_node + # contains the tail of the operand list. + tail_node = ast.BoolOp(node.op, node.values[1:]) + return self.handle_bool_op( + ast.BoolOp( + self.handle_expression(node.op), + [node.values[0], tail_node], + ) + ) + elif len(node.values) == 2: + # Base case, boolean operation has only two operands. + return self.handle_bool_op( + ast.BoolOp( + node.op, + [self.handle_expression(v) for v in node.values], + ) + ) + else: + raise NotImplementedError("unreachable") + elif isinstance(node, ast.Compare): + # Recursively handle left and right sides of comparison. + node.left = self.handle_expression(node.left) + node.comparators = [ + self.handle_expression(c) for c in node.comparators + ] + return node + elif isinstance(node, ast.BinOp): + # Handle binary operations (+, -, *, / etc). + node.left = self.handle_expression(node.left) + node.right = self.handle_expression(node.right) + return node + elif isinstance(node, ast.Call): + # Handle function calls. + node.args = [self.handle_expression(a) for a in node.args] + return node + else: + # Base case: literals, names, etc. + return node + + def handle_bool_op(self, node: ast.BoolOp) -> ast.Name: + """Handle boolean operations (and/or). + Returns an ast.Name representing the result variable.""" + + # Create a new temp variable to store the result. + self.bool_op_index += 1 + result_var = f"__scfg_bool_op_{self.bool_op_index}__" + + # Create an assignment to bin temp variable from above to the left most + # value in the expression. + left = self.handle_expression(node.values[0]) + self.current_block.instructions.append( + ast.Assign( + targets=[ast.Name(id=result_var, ctx=ast.Store())], + value=left, + lineno=0, + ) + ) + + # Handle the or operator. + if isinstance(node.op, ast.Or): + # Create blocks for the true and false paths. + false_block_index = self.block_index + merge_block_index = self.block_index + 1 + self.block_index += 2 + + # Test and jump based on first value. + self.current_block.instructions.append( + ast.Name(id=result_var, ctx=ast.Load()) + ) + self.current_block.set_jump_targets( + merge_block_index, false_block_index + ) + + # False block evaluates second value. + self.add_block(false_block_index) + right = self.handle_expression(node.values[1]) + self.current_block.instructions.append( + ast.Assign( + targets=[ast.Name(id=result_var, ctx=ast.Store())], + value=right, + lineno=0, + ) + ) + self.current_block.set_jump_targets(merge_block_index) + + # Create merge block + self.add_block(merge_block_index) + + # Handle the and operator. + elif isinstance(node.op, ast.And): + # Create blocks for the true and false paths. + true_block_index = self.block_index + merge_block_index = self.block_index + 1 + self.block_index += 2 + + # Test and jump based on first value. + self.current_block.instructions.append( + ast.Name(id=result_var, ctx=ast.Load()) + ) + self.current_block.set_jump_targets( + true_block_index, merge_block_index + ) + + # True block evaluates second value. + self.add_block(true_block_index) + right = self.handle_expression(node.values[1]) + self.current_block.instructions.append( + ast.Assign( + targets=[ast.Name(id=result_var, ctx=ast.Store())], + value=right, + lineno=0, + ) + ) + self.current_block.set_jump_targets(merge_block_index) + + # Create merge block. + self.add_block(merge_block_index) + + else: + raise NotImplementedError("unreachable") + + # Return name node referencing our result variable. + return ast.Name(id=result_var, ctx=ast.Load()) + def handle_function_def(self, node: ast.FunctionDef) -> None: """Handle a function definition.""" # Insert implicit return None, if the function isn't terminated. May @@ -352,8 +494,10 @@ def handle_if(self, node: ast.If) -> None: enif_index = self.block_index + 2 self.block_index += 3 + # Desugar test expression if needed, may modify current_block. + test_name = self.handle_expression(node.test) # Emit comparison value to current/header block. - self.current_block.instructions.append(node.test) + self.current_block.instructions.append(test_name) # Setup jump targets for current/header block. self.current_block.set_jump_targets(then_index, else_index) @@ -398,8 +542,10 @@ def handle_while(self, node: ast.While) -> None: # And create new header block self.add_block(head_index) + # Desugar test expression if needed, may modify current_block. + test_name = self.handle_expression(node.test) # Emit comparison expression into header. - self.current_block.instructions.append(node.test) + self.current_block.instructions.append(test_name) # Set the jump targets to be the body and the else branch. self.current_block.set_jump_targets(body_index, else_index) diff --git a/numba_rvsdg/tests/test_ast_transforms.py b/numba_rvsdg/tests/test_ast_transforms.py index bfc7dc0..cf44b78 100644 --- a/numba_rvsdg/tests/test_ast_transforms.py +++ b/numba_rvsdg/tests/test_ast_transforms.py @@ -1220,6 +1220,7 @@ def function(x: int) -> int: } empty = {"7", "10", "12", "13", "16", "17", "20", "23", "26"} arguments = [(1,), (2,), (3,), (4,), (5,), (6,), (7,)] + self.compare( function, expected, @@ -1227,6 +1228,486 @@ def function(x: int) -> int: arguments=arguments, ) + def test_and(self): + def function(x: int, y: int) -> int: + # Returns last truthy value, or False if any others are falsy. + return x and y + + expected = { + "0": { + "instructions": [ + "__scfg_bool_op_1__ = x", + "__scfg_bool_op_1__", + ], + "jump_targets": ["1", "2"], + "name": "0", + }, + "1": { + "instructions": ["__scfg_bool_op_1__ = y"], + "jump_targets": ["2"], + "name": "1", + }, + "2": { + "instructions": ["return __scfg_bool_op_1__"], + "jump_targets": [], + "name": "2", + }, + } + + self.compare(function, expected, arguments=[(0, 0), (0, 1), (1, 0)]) + + def test_or(self): + def function(x: int, y: int) -> int: + # Returns first truthy value, or last value if all falsy. + return x or y + + expected = { + "0": { + "instructions": [ + "__scfg_bool_op_1__ = x", + "__scfg_bool_op_1__", + ], + "jump_targets": ["2", "1"], + "name": "0", + }, + "1": { + "instructions": ["__scfg_bool_op_1__ = y"], + "jump_targets": ["2"], + "name": "1", + }, + "2": { + "instructions": ["return __scfg_bool_op_1__"], + "jump_targets": [], + "name": "2", + }, + } + + self.compare( + function, + expected, + arguments=[ + (1, 0), # First value truthy. + (0, 2), # First value falsy, second truthy. + (0, 0), # All values falsy - returns last value. + ], + ) + + def test_multi_operand_and(self): + def function(x: int, y: int, z: int) -> bool: + return x and y and z + + expected = { + "0": { + "instructions": [ + "__scfg_bool_op_1__ = x", + "__scfg_bool_op_1__", + ], + "jump_targets": ["1", "2"], + "name": "0", + }, + "1": { + "instructions": [ + "__scfg_bool_op_2__ = y", + "__scfg_bool_op_2__", + ], + "jump_targets": ["3", "4"], + "name": "1", + }, + "2": { + "instructions": ["return __scfg_bool_op_1__"], + "jump_targets": [], + "name": "2", + }, + "3": { + "instructions": ["__scfg_bool_op_2__ = z"], + "jump_targets": ["4"], + "name": "3", + }, + "4": { + "instructions": ["__scfg_bool_op_1__ = __scfg_bool_op_2__"], + "jump_targets": ["2"], + "name": "4", + }, + } + + self.compare( + function, + expected, + arguments=[ + (0, 0, 0), # All false. + (0, 0, 1), # Only z true - short circuits at x. + (0, 1, 0), # Only y true - short circuits at x. + (0, 1, 1), # y and z true - short circuits at x. + (1, 0, 0), # Only x true - short circuits at y. + (1, 0, 1), # x and z true - short circuits at y. + (1, 1, 0), # x and y true - fails at z. + (1, 1, 1), # All true - complete evaluation. + ], + ) + + def test_nested_andor(self): + def function(x: int, y: int, a: int, b: int) -> bool: + return (x and y) or (a and b) + + expected = { + "0": { + "instructions": [ + "__scfg_bool_op_1__ = x", + "__scfg_bool_op_1__", + ], + "jump_targets": ["1", "2"], + "name": "0", + }, + "1": { + "instructions": ["__scfg_bool_op_1__ = y"], + "jump_targets": ["2"], + "name": "1", + }, + "2": { + "instructions": [ + "__scfg_bool_op_2__ = a", + "__scfg_bool_op_2__", + ], + "jump_targets": ["3", "4"], + "name": "2", + }, + "3": { + "instructions": ["__scfg_bool_op_2__ = b"], + "jump_targets": ["4"], + "name": "3", + }, + "4": { + "instructions": [ + "__scfg_bool_op_3__ = __scfg_bool_op_1__", + "__scfg_bool_op_3__", + ], + "jump_targets": ["6", "5"], + "name": "4", + }, + "5": { + "instructions": ["__scfg_bool_op_3__ = __scfg_bool_op_2__"], + "jump_targets": ["6"], + "name": "5", + }, + "6": { + "instructions": ["return __scfg_bool_op_3__"], + "jump_targets": [], + "name": "6", + }, + } + + self.compare( + function, + expected, + arguments=[ + (0, 0, 0, 0), # F F F F -> False + (0, 0, 0, 1), # F F F T -> False + (0, 0, 1, 0), # F F T F -> False + (0, 0, 1, 1), # F F T T -> True + (0, 1, 0, 0), # F T F F -> False + (0, 1, 0, 1), # F T F T -> False + (0, 1, 1, 0), # F T T F -> False + (0, 1, 1, 1), # F T T T -> True + (1, 0, 0, 0), # T F F F -> False + (1, 0, 0, 1), # T F F T -> False + (1, 0, 1, 0), # T F T F -> False + (1, 0, 1, 1), # T F T T -> True + (1, 1, 0, 0), # T T F F -> True + (1, 1, 0, 1), # T T F T -> True + (1, 1, 1, 0), # T T T F -> True + (1, 1, 1, 1), # T T T T -> True + ], + ) + + def test_if_with_bool_ops(self): + def function(x: int, y: int) -> int: + if x and y: + return 1 + return 0 + + expected = { + "0": { + "instructions": [ + "__scfg_bool_op_1__ = x", + "__scfg_bool_op_1__", + ], + "jump_targets": ["4", "5"], + "name": "0", + }, + "1": { + "instructions": ["return 1"], + "jump_targets": [], + "name": "1", + }, + "3": { + "instructions": ["return 0"], + "jump_targets": [], + "name": "3", + }, + "4": { + "instructions": ["__scfg_bool_op_1__ = y"], + "jump_targets": ["5"], + "name": "4", + }, + "5": { + "instructions": ["__scfg_bool_op_1__"], + "jump_targets": ["1", "3"], + "name": "5", + }, + } + + self.compare( + function, + expected, + empty={"2"}, + arguments=[ + (0, 0), # All false + (0, 1), # y true + (1, 0), # x true + (1, 1), # All true + ], + ) + + def test_elif_with_bool_ops(self): + def function(x: int, y: int) -> int: + if x and y: + return 1 + elif x or y: + return 2 + return 0 + + expected = { + "0": { + "instructions": [ + "__scfg_bool_op_1__ = x", + "__scfg_bool_op_1__", + ], + "jump_targets": ["4", "5"], + "name": "0", + }, + "1": { + "instructions": ["return 1"], + "jump_targets": [], + "name": "1", + }, + "10": { + "instructions": ["__scfg_bool_op_2__"], + "jump_targets": ["6", "3"], + "name": "10", + }, + "2": { + "instructions": [ + "__scfg_bool_op_2__ = x", + "__scfg_bool_op_2__", + ], + "jump_targets": ["10", "9"], + "name": "2", + }, + "3": { + "instructions": ["return 0"], + "jump_targets": [], + "name": "3", + }, + "4": { + "instructions": ["__scfg_bool_op_1__ = y"], + "jump_targets": ["5"], + "name": "4", + }, + "5": { + "instructions": ["__scfg_bool_op_1__"], + "jump_targets": ["1", "2"], + "name": "5", + }, + "6": { + "instructions": ["return 2"], + "jump_targets": [], + "name": "6", + }, + "9": { + "instructions": ["__scfg_bool_op_2__ = y"], + "jump_targets": ["10"], + "name": "9", + }, + } + + self.compare( + function, + expected, + empty={"8", "7"}, + arguments=[ + (0, 0), # All false + (0, 1), # y true + (1, 0), # x true + (1, 1), # All true + ], + ) + + def test_while_with_bool_ops(self): + def function(x: int, y: int) -> int: + count = 0 + while x and y: + count += 1 + x -= 1 + y -= 1 + return count + + expected = { + "0": { + "instructions": ["count = 0"], + "jump_targets": ["1"], + "name": "0", + }, + "1": { + "instructions": [ + "__scfg_bool_op_1__ = x", + "__scfg_bool_op_1__", + ], + "jump_targets": ["5", "6"], + "name": "1", + }, + "2": { + "instructions": ["count += 1", "x -= 1", "y -= 1"], + "jump_targets": ["1"], + "name": "2", + }, + "3": { + "instructions": ["return count"], + "jump_targets": [], + "name": "3", + }, + "5": { + "instructions": ["__scfg_bool_op_1__ = y"], + "jump_targets": ["6"], + "name": "5", + }, + "6": { + "instructions": ["__scfg_bool_op_1__"], + "jump_targets": ["2", "3"], + "name": "6", + }, + } + + self.compare( + function, + expected, + empty={"4"}, + arguments=[ + (0, 0), # Both false, loop doesn't run + (2, 1), # y becomes false first, runs once + (1, 2), # x becomes false first, runs once + (2, 2), # Both true, runs twice + ], + ) + + def test_bool_op_assignment(self): + def function(x: int, y: int) -> bool: + result = x and y + return result + + expected = { + "0": { + "instructions": [ + "__scfg_bool_op_1__ = x", + "__scfg_bool_op_1__", + ], + "jump_targets": ["1", "2"], + "name": "0", + }, + "1": { + "instructions": ["__scfg_bool_op_1__ = y"], + "jump_targets": ["2"], + "name": "1", + }, + "2": { + "instructions": [ + "result = __scfg_bool_op_1__", + "return result", + ], + "jump_targets": [], + "name": "2", + }, + } + + self.compare( + function, + expected, + arguments=[ + (0, 0), # Both false + (0, 1), # x false, y true + (1, 0), # x true, y false + (1, 1), # Both true + ], + ) + + def test_bool_base_expression(self): + def function(x: int) -> int: + result = [] + x or [result.append(i) for i in [0, 0, 0]] + return len(result) + + expected = { + "0": { + "instructions": [ + "result = []", + "__scfg_bool_op_1__ = x", + "__scfg_bool_op_1__", + ], + "jump_targets": ["2", "1"], + "name": "0", + }, + "1": { + "instructions": [ + "__scfg_bool_op_1__ = " + "[result.append(i) for i in [0, 0, 0]]" + ], + "jump_targets": ["2"], + "name": "1", + }, + "2": { + "instructions": ["__scfg_bool_op_1__", "return len(result)"], + "jump_targets": [], + "name": "2", + }, + } + + self.compare(function, expected, arguments=[(0,), (1,)]) + + def test_expression_in_function_call(self): + def function(x: int, y: int) -> int: + return int(x or y) + + expected = { + "0": { + "instructions": [ + "__scfg_bool_op_1__ = x", + "__scfg_bool_op_1__", + ], + "jump_targets": ["2", "1"], + "name": "0", + }, + "1": { + "instructions": ["__scfg_bool_op_1__ = y"], + "jump_targets": ["2"], + "name": "1", + }, + "2": { + "instructions": ["return int(__scfg_bool_op_1__)"], + "jump_targets": [], + "name": "2", + }, + } + + self.compare( + function, + expected, + arguments=[ + (0, 0), + (0, 1), + (1, 0), + (1, 1), + ], + ) + class TestEntryPoints(TestCase):