Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
tehrengruber committed Feb 4, 2025
1 parent 7bf42dd commit 3b5a5dd
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/transforms/fuse_as_fieldop.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def fuse_as_fieldop(
return new_node


def _arg_inline_predicate(node: itir.Expr, shifts):
def _arg_inline_predicate(node: itir.Expr, shifts: set[tuple[itir.OffsetLiteral, ...]]) -> bool:
if _is_tuple_expr_of_literals(node):
return True

Expand Down
8 changes: 6 additions & 2 deletions src/gt4py/next/iterator/transforms/inline_lifts.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,11 @@ def _transform_and_extract_lift_args(
new_args = []
for i, arg in enumerate(node.args):
if isinstance(arg, ir.SymRef):
# TODO(tehrengruber): Is it possible to reinfer the type if it is not inherited here?
sym = ir.Sym(id=arg.id)
assert sym not in extracted_args or extracted_args[sym] == arg
extracted_args[sym] = arg
new_args.append(arg) # TODO: type?
new_args.append(arg)
else:
new_symbol = _generate_unique_symbol(
desired_name=(inner_stencil, i),
Expand All @@ -92,7 +93,8 @@ def _transform_and_extract_lift_args(
)
assert new_symbol not in extracted_args
extracted_args[new_symbol] = arg
new_args.append(ir.SymRef(id=new_symbol.id)) # TODO: type?
# TODO(tehrengruber): Is it possible to reinfer the type if it is not inherited here?
new_args.append(ir.SymRef(id=new_symbol.id))

itir_node = im.lift(inner_stencil)(*new_args)
itir_node.location = node.location
Expand Down Expand Up @@ -159,6 +161,8 @@ def visit_FunCall(

if self.flags & self.Flag.PROPAGATE_SHIFT and _is_shift_lift(node):
shift = node.fun
# This transformation does not preserve the type (the position dims of the iterator
# change). Delete type to avoid errors.
shift.type = None
assert len(node.args) == 1
lift_call = node.args[0]
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/transforms/trace_shifts.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def fun(*args):
@classmethod
def trace_stencil(
cls, stencil: ir.Expr, *, num_args: Optional[int] = None, save_to_annex: bool = False
):
) -> list[set[tuple[ir.OffsetLiteral, ...]]]:
# If we get a lambda we can deduce the number of arguments.
if isinstance(stencil, ir.Lambda):
assert num_args is None or num_args == len(stencil.params)
Expand Down

0 comments on commit 3b5a5dd

Please sign in to comment.