Skip to content

Commit

Permalink
Merge branch 'main' into decouple_inferences
Browse files Browse the repository at this point in the history
  • Loading branch information
SF-N authored Feb 7, 2025
2 parents c851f8f + 34b574a commit 74e4584
Show file tree
Hide file tree
Showing 16 changed files with 834 additions and 268 deletions.
14 changes: 12 additions & 2 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def domain(
)


def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> call:
def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> Callable:
"""
Create an `as_fieldop` call.
Expand All @@ -445,7 +445,9 @@ def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> cal
>>> str(as_fieldop(lambda_("it1", "it2")(plus(deref("it1"), deref("it2"))))("field1", "field2"))
'(⇑(λ(it1, it2) → ·it1 + ·it2))(field1, field2)'
"""
return call(
from gt4py.next.iterator.ir_utils import domain_utils

result = call(
call("as_fieldop")(
*(
(
Expand All @@ -458,6 +460,14 @@ def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> cal
)
)

def _populate_domain_annex_wrapper(*args, **kwargs):
node = result(*args, **kwargs)
if domain:
node.annex.domain = domain_utils.SymbolicDomain.from_expr(domain)
return node

return _populate_domain_annex_wrapper


def op_as_fieldop(
op: str | itir.SymRef | itir.Lambda | Callable, domain: Optional[itir.FunCall] = None
Expand Down
5 changes: 3 additions & 2 deletions src/gt4py/next/iterator/transforms/collapse_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def all(self) -> CollapseTuple.Transformation:
ignore_tuple_size: bool
enabled_transformations: Transformation = Transformation.all() # noqa: RUF009 [function-call-in-dataclass-default-argument]

PRESERVED_ANNEX_ATTRS = ("type",)
PRESERVED_ANNEX_ATTRS = ("type", "domain")

@classmethod
def apply(
Expand Down Expand Up @@ -236,6 +236,7 @@ def transform_collapse_make_tuple_tuple_get(
# tuple argument differs, just continue with the rest of the tree
return None

itir_type_inference.reinfer(first_expr) # type is needed so reinfer on-demand
assert self.ignore_tuple_size or isinstance(
first_expr.type, (ts.TupleType, ts.DeferredType)
)
Expand All @@ -255,7 +256,7 @@ def transform_collapse_tuple_get_make_tuple(
and cpm.is_call_to(node.args[1], "make_tuple")
):
# `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i`
assert type_info.is_integer(node.args[0].type)
assert not node.args[0].type or type_info.is_integer(node.args[0].type)
make_tuple_call = node.args[1]
idx = int(node.args[0].value)
assert idx < len(
Expand Down
5 changes: 5 additions & 0 deletions src/gt4py/next/iterator/transforms/constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@


class ConstantFolding(PreserveLocationVisitor, NodeTranslator):
PRESERVED_ANNEX_ATTRS = (
"type",
"domain",
)

@classmethod
def apply(cls, node: ir.Node) -> ir.Node:
return cls().visit(node)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]:
if result is not None:
assert (
result is not node
) # transformation should have returned None, since nothing changed
), f"Transformation {transformation.name.lower()} should have returned None, since nothing changed."
itir_type_inference.reinfer(result)
return result
return None
Loading

0 comments on commit 74e4584

Please sign in to comment.