Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement handling and/or expressions #165

Merged
merged 16 commits into from
Feb 4, 2025
Merged
152 changes: 149 additions & 3 deletions numba_rvsdg/core/datastructures/ast_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ def __init__(
self.code = code
self.tree = unparse_code(code)
self.block_index: int = 1 # 0 is reserved for genesis block
self.bool_op_index = 0 # can have multiple of these per block
self.blocks = ASTCFG()
# Initialize first (genesis) block, assume it's named zero.
# (This also initializes the self.current_block attribute.)
Expand Down Expand Up @@ -316,10 +317,18 @@ def handle_ast_node(self, node: type[ast.AST] | ast.stmt) -> None:
elif isinstance(
node,
(
ast.Assign,
ast.AugAssign,
ast.Assign,
ast.Expr,
ast.Return,
),
):
# Node has an expression, must handle it.
node.value = self.handle_expression(node.value)
self.current_block.instructions.append(node)
elif isinstance(
node,
(
ast.Break,
ast.Continue,
ast.Pass,
Expand All @@ -335,6 +344,139 @@ def handle_ast_node(self, node: type[ast.AST] | ast.stmt) -> None:
else:
raise NotImplementedError(f"Node type {node} not implemented")

def handle_expression(self, node: Any) -> Any:
"""Recursively handle expression nodes and their subexpressions.
Returns the processed expression."""

if isinstance(node, ast.BoolOp):
# Handle or/and operations.
if len(node.values) > 2:
# In this case the bool operation has more than two operands,
# we need to deconstruct this into a binary tree of bool
# operations and recursively deal with those. The tail_node
# contains the tail of the operand list.
tail_node = ast.BoolOp(node.op, node.values[1:])
return self.handle_bool_op(
ast.BoolOp(
self.handle_expression(node.op),
[node.values[0], tail_node],
)
)
elif len(node.values) == 2:
# Base case, boolean operation has only two operands.
return self.handle_bool_op(
ast.BoolOp(
node.op,
[self.handle_expression(v) for v in node.values],
)
)
else:
raise NotImplementedError("unreachable")
elif isinstance(node, ast.Compare):
# Recursively handle left and right sides of comparison.
node.left = self.handle_expression(node.left)
node.comparators = [
self.handle_expression(c) for c in node.comparators
]
return node
elif isinstance(node, ast.BinOp):
# Handle binary operations (+, -, *, / etc).
node.left = self.handle_expression(node.left)
node.right = self.handle_expression(node.right)
return node
elif isinstance(node, ast.Call):
# Handle function calls.
node.args = [self.handle_expression(a) for a in node.args]
return node
else:
# Base case: literals, names, etc.
return node

def handle_bool_op(self, node: ast.BoolOp) -> ast.Name:
"""Handle boolean operations (and/or).
Returns an ast.Name representing the result variable."""

# Create a new temp variable to store the result.
self.bool_op_index += 1
result_var = f"__scfg_bool_op_{self.bool_op_index}__"

# Create an assignment to bin temp variable from above to the left most
# value in the expression.
left = self.handle_expression(node.values[0])
self.current_block.instructions.append(
ast.Assign(
targets=[ast.Name(id=result_var, ctx=ast.Store())],
value=left,
lineno=0,
)
)

# Handle the or operator.
if isinstance(node.op, ast.Or):
# Create blocks for the true and false paths.
false_block_index = self.block_index
merge_block_index = self.block_index + 1
self.block_index += 2

# Test and jump based on first value.
self.current_block.instructions.append(
ast.Name(id=result_var, ctx=ast.Load())
)
self.current_block.set_jump_targets(
merge_block_index, false_block_index
)

# False block evaluates second value.
self.add_block(false_block_index)
right = self.handle_expression(node.values[1])
self.current_block.instructions.append(
ast.Assign(
targets=[ast.Name(id=result_var, ctx=ast.Store())],
value=right,
lineno=0,
)
)
self.current_block.set_jump_targets(merge_block_index)

# Create merge block
self.add_block(merge_block_index)

# Handle the and operator.
elif isinstance(node.op, ast.And):
# Create blocks for the true and false paths.
true_block_index = self.block_index
merge_block_index = self.block_index + 1
self.block_index += 2

# Test and jump based on first value.
self.current_block.instructions.append(
ast.Name(id=result_var, ctx=ast.Load())
)
self.current_block.set_jump_targets(
true_block_index, merge_block_index
)

# True block evaluates second value.
self.add_block(true_block_index)
right = self.handle_expression(node.values[1])
self.current_block.instructions.append(
ast.Assign(
targets=[ast.Name(id=result_var, ctx=ast.Store())],
value=right,
lineno=0,
)
)
self.current_block.set_jump_targets(merge_block_index)

# Create merge block.
self.add_block(merge_block_index)

else:
raise NotImplementedError("unreachable")

# Return name node referencing our result variable.
return ast.Name(id=result_var, ctx=ast.Load())

def handle_function_def(self, node: ast.FunctionDef) -> None:
"""Handle a function definition."""
# Insert implicit return None, if the function isn't terminated. May
Expand All @@ -352,8 +494,10 @@ def handle_if(self, node: ast.If) -> None:
enif_index = self.block_index + 2
self.block_index += 3

# Desugar test expression if needed, may modify current_block.
test_name = self.handle_expression(node.test)
# Emit comparison value to current/header block.
self.current_block.instructions.append(node.test)
self.current_block.instructions.append(test_name)
# Setup jump targets for current/header block.
self.current_block.set_jump_targets(then_index, else_index)

Expand Down Expand Up @@ -398,8 +542,10 @@ def handle_while(self, node: ast.While) -> None:
# And create new header block
self.add_block(head_index)

# Desugar test expression if needed, may modify current_block.
test_name = self.handle_expression(node.test)
# Emit comparison expression into header.
self.current_block.instructions.append(node.test)
self.current_block.instructions.append(test_name)
# Set the jump targets to be the body and the else branch.
self.current_block.set_jump_targets(body_index, else_index)

Expand Down
Loading
Loading