Skip to content

Commit

Permalink
Merge pull request #165 from numba/handle_bool_op
Browse files Browse the repository at this point in the history
implement handling and/or expressions
  • Loading branch information
esc authored Feb 4, 2025
2 parents bae1812 + f2109c5 commit bf09ccf
Show file tree
Hide file tree
Showing 2 changed files with 630 additions and 3 deletions.
152 changes: 149 additions & 3 deletions numba_rvsdg/core/datastructures/ast_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ def __init__(
self.code = code
self.tree = unparse_code(code)
self.block_index: int = 1 # 0 is reserved for genesis block
self.bool_op_index = 0 # can have multiple of these per block
self.blocks = ASTCFG()
# Initialize first (genesis) block, assume it's named zero.
# (This also initializes the self.current_block attribute.)
Expand Down Expand Up @@ -316,10 +317,18 @@ def handle_ast_node(self, node: type[ast.AST] | ast.stmt) -> None:
elif isinstance(
node,
(
ast.Assign,
ast.AugAssign,
ast.Assign,
ast.Expr,
ast.Return,
),
):
# Node has an expression, must handle it.
node.value = self.handle_expression(node.value)
self.current_block.instructions.append(node)
elif isinstance(
node,
(
ast.Break,
ast.Continue,
ast.Pass,
Expand All @@ -335,6 +344,139 @@ def handle_ast_node(self, node: type[ast.AST] | ast.stmt) -> None:
else:
raise NotImplementedError(f"Node type {node} not implemented")

def handle_expression(self, node: Any) -> Any:
"""Recursively handle expression nodes and their subexpressions.
Returns the processed expression."""

if isinstance(node, ast.BoolOp):
# Handle or/and operations.
if len(node.values) > 2:
# In this case the bool operation has more than two operands,
# we need to deconstruct this into a binary tree of bool
# operations and recursively deal with those. The tail_node
# contains the tail of the operand list.
tail_node = ast.BoolOp(node.op, node.values[1:])
return self.handle_bool_op(
ast.BoolOp(
self.handle_expression(node.op),
[node.values[0], tail_node],
)
)
elif len(node.values) == 2:
# Base case, boolean operation has only two operands.
return self.handle_bool_op(
ast.BoolOp(
node.op,
[self.handle_expression(v) for v in node.values],
)
)
else:
raise NotImplementedError("unreachable")
elif isinstance(node, ast.Compare):
# Recursively handle left and right sides of comparison.
node.left = self.handle_expression(node.left)
node.comparators = [
self.handle_expression(c) for c in node.comparators
]
return node
elif isinstance(node, ast.BinOp):
# Handle binary operations (+, -, *, / etc).
node.left = self.handle_expression(node.left)
node.right = self.handle_expression(node.right)
return node
elif isinstance(node, ast.Call):
# Handle function calls.
node.args = [self.handle_expression(a) for a in node.args]
return node
else:
# Base case: literals, names, etc.
return node

def handle_bool_op(self, node: ast.BoolOp) -> ast.Name:
"""Handle boolean operations (and/or).
Returns an ast.Name representing the result variable."""

# Create a new temp variable to store the result.
self.bool_op_index += 1
result_var = f"__scfg_bool_op_{self.bool_op_index}__"

# Create an assignment to bin temp variable from above to the left most
# value in the expression.
left = self.handle_expression(node.values[0])
self.current_block.instructions.append(
ast.Assign(
targets=[ast.Name(id=result_var, ctx=ast.Store())],
value=left,
lineno=0,
)
)

# Handle the or operator.
if isinstance(node.op, ast.Or):
# Create blocks for the true and false paths.
false_block_index = self.block_index
merge_block_index = self.block_index + 1
self.block_index += 2

# Test and jump based on first value.
self.current_block.instructions.append(
ast.Name(id=result_var, ctx=ast.Load())
)
self.current_block.set_jump_targets(
merge_block_index, false_block_index
)

# False block evaluates second value.
self.add_block(false_block_index)
right = self.handle_expression(node.values[1])
self.current_block.instructions.append(
ast.Assign(
targets=[ast.Name(id=result_var, ctx=ast.Store())],
value=right,
lineno=0,
)
)
self.current_block.set_jump_targets(merge_block_index)

# Create merge block
self.add_block(merge_block_index)

# Handle the and operator.
elif isinstance(node.op, ast.And):
# Create blocks for the true and false paths.
true_block_index = self.block_index
merge_block_index = self.block_index + 1
self.block_index += 2

# Test and jump based on first value.
self.current_block.instructions.append(
ast.Name(id=result_var, ctx=ast.Load())
)
self.current_block.set_jump_targets(
true_block_index, merge_block_index
)

# True block evaluates second value.
self.add_block(true_block_index)
right = self.handle_expression(node.values[1])
self.current_block.instructions.append(
ast.Assign(
targets=[ast.Name(id=result_var, ctx=ast.Store())],
value=right,
lineno=0,
)
)
self.current_block.set_jump_targets(merge_block_index)

# Create merge block.
self.add_block(merge_block_index)

else:
raise NotImplementedError("unreachable")

# Return name node referencing our result variable.
return ast.Name(id=result_var, ctx=ast.Load())

def handle_function_def(self, node: ast.FunctionDef) -> None:
"""Handle a function definition."""
# Insert implicit return None, if the function isn't terminated. May
Expand All @@ -352,8 +494,10 @@ def handle_if(self, node: ast.If) -> None:
enif_index = self.block_index + 2
self.block_index += 3

# Desugar test expression if needed, may modify current_block.
test_name = self.handle_expression(node.test)
# Emit comparison value to current/header block.
self.current_block.instructions.append(node.test)
self.current_block.instructions.append(test_name)
# Setup jump targets for current/header block.
self.current_block.set_jump_targets(then_index, else_index)

Expand Down Expand Up @@ -392,8 +536,10 @@ def handle_while(self, node: ast.While) -> None:
# And create new header block
self.add_block(head_index)

# Desugar test expression if needed, may modify current_block.
test_name = self.handle_expression(node.test)
# Emit comparison expression into header.
self.current_block.instructions.append(node.test)
self.current_block.instructions.append(test_name)
# Set the jump targets to be the body and the else branch.
self.current_block.set_jump_targets(body_index, else_index)

Expand Down
Loading

0 comments on commit bf09ccf

Please sign in to comment.