Skip to content

Commit

Permalink
Use FixedPointTransformation
Browse files Browse the repository at this point in the history
  • Loading branch information
tehrengruber committed Feb 4, 2025
1 parent 545861a commit bc71808
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 104 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]:
if result is not None:
assert (
result is not node
) # transformation should have returned None, since nothing changed
), f"Transformation {transformation.name.lower()} should have returned None, since nothing changed."
itir_type_inference.reinfer(result)
return result
return None
233 changes: 136 additions & 97 deletions src/gt4py/next/iterator/transforms/fuse_as_fieldop.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations

import dataclasses
import enum
import functools
import operator
from typing import Optional

from gt4py import eve
Expand All @@ -19,6 +23,7 @@
ir_makers as im,
)
from gt4py.next.iterator.transforms import (
fixed_point_transformation,
inline_center_deref_lift_vars,
inline_lambdas,
inline_lifts,
Expand Down Expand Up @@ -265,8 +270,10 @@ def _make_tuple_element_inline_predicate(node: itir.Expr):
return False


@dataclasses.dataclass
class FuseAsFieldOp(eve.NodeTranslator):
@dataclasses.dataclass(frozen=True, kw_only=True)
class FuseAsFieldOp(
fixed_point_transformation.FixedPointTransformation, eve.PreserveLocationVisitor
):
"""
Merge multiple `as_fieldop` calls into one.
Expand All @@ -293,8 +300,23 @@ class FuseAsFieldOp(eve.NodeTranslator):
as_fieldop(λ(inp1, inp2, inp3) → ·inp1 × ·inp2 + ·inp3, c⟨ IDimₕ: [0, 1[ ⟩)(inp1, inp2, inp3)
""" # noqa: RUF002 # ignore ambiguous multiplication character

class Transformation(enum.Flag):
    """Flags naming the individual transformations of this pass.

    Being an `enum.Flag`, members can be combined with ``|`` to describe a set of
    transformations; `all()` yields the combination of every member.
    """

    # NOTE(review): the description attached to `FUSE_MAKE_TUPLE` below talks about inlining
    # a let-bound expression with list dtype, not about fusing `make_tuple` elements. It may
    # have been copied from a different transformation -- confirm and reword.
    #: Let `f_expr` be an expression with list dtype then
    #: `let(f, f_expr) -> as_fieldop(...)(f)` -> `as_fieldop(...)(f_expr)`
    FUSE_MAKE_TUPLE = enum.auto()
    #: `as_fieldop(...)(as_fieldop(...)(a, b), c)`
    #: -> as_fieldop(fused_stencil)(a, b, c)
    FUSE_AS_FIELDOP = enum.auto()
    #: Inline let variables using `inline_lambda(..., opcount_preserving=True)`.
    INLINE_LET_VARS_OPCOUNT_PRESERVING = enum.auto()

    @classmethod
    def all(cls) -> FuseAsFieldOp.Transformation:
        """Return the bitwise OR of every member declared on this flag class."""
        # The first parameter of a classmethod is the class itself, hence `cls`.
        return functools.reduce(operator.or_, cls.__members__.values())

PRESERVED_ANNEX_ATTRS = ("domain",)

enabled_transformations = Transformation.all()

uids: eve_utils.UIDGenerator

@classmethod
Expand All @@ -306,7 +328,10 @@ def apply(
uids: Optional[eve_utils.UIDGenerator] = None,
allow_undeclared_symbols=False,
within_set_at_expr: Optional[bool] = None,
enabled_transformations: Optional[Transformation] = None,
):
enabled_transformations = enabled_transformations or cls.enabled_transformations

node = type_inference.infer(
node,
offset_provider_type=offset_provider_type,
Expand All @@ -319,13 +344,9 @@ def apply(
if not uids:
uids = eve_utils.UIDGenerator()

return cls(uids=uids).visit(node, within_set_at_expr=within_set_at_expr)

def visit(self, node, **kwargs):
new_node = super().visit(node, **kwargs)
if isinstance(node, itir.Expr) and hasattr(node.annex, "domain"):
new_node.annex.domain = node.annex.domain
return new_node
return cls(uids=uids, enabled_transformations=enabled_transformations).visit(
node, within_set_at_expr=within_set_at_expr
)

def visit_SetAt(self, node: itir.SetAt, **kwargs):
return itir.SetAt(
Expand All @@ -334,14 +355,116 @@ def visit_SetAt(self, node: itir.SetAt, **kwargs):
target=node.target,
)

def visit_FunCall(self, node: itir.FunCall, **kwargs):
def transform_fuse_make_tuple(self, node: itir.Node, **kwargs):
"""
Fuse the eligible elements of a `make_tuple` call that share the same domain.

Returns `None` when `node` is not a call to `make_tuple`, or when no two eligible
elements have the same domain expression. Otherwise returns a new node in which
elements with a common domain are produced by a single `as_fieldop` that builds a tuple.
"""
if not cpm.is_call_to(node, "make_tuple"):
return None

# make sure every element has a type; field-typed elements must carry a symbolic domain
for arg in node.args:
type_inference.reinfer(arg)
assert not isinstance(arg.type, ts.FieldType) or (
hasattr(arg.annex, "domain")
and isinstance(arg.annex.domain, domain_utils.SymbolicDomain)
)

eligible_els = [_make_tuple_element_inline_predicate(arg) for arg in node.args]
field_args = [arg for i, arg in enumerate(node.args) if eligible_els[i]]
distinct_domains = set(arg.annex.domain.as_expr() for arg in field_args)
# fewer distinct domains than eligible elements means at least two of them share a domain
if len(distinct_domains) != len(field_args):
new_els: list[itir.Expr | None] = [None for _ in node.args]
# maps a domain expression to the (original tuple index, element) pairs on that domain
field_args_by_domain: dict[itir.FunCall, list[tuple[int, itir.Expr]]] = {}
for i, arg in enumerate(node.args):
if eligible_els[i]:
assert isinstance(arg.annex.domain, domain_utils.SymbolicDomain)
domain = arg.annex.domain.as_expr()
field_args_by_domain.setdefault(domain, [])
field_args_by_domain[domain].append((i, arg))
else:
new_els[i] = arg # keep as is

if len(field_args_by_domain) == 1 and all(eligible_els):
# if we only have a single domain covering all args we don't need to create an
# unnecessary let
((domain, inner_field_args),) = field_args_by_domain.items()
new_node = im.op_as_fieldop(lambda *args: im.make_tuple(*args), domain)(
*(arg for _, arg in inner_field_args)
)
new_node = self.visit(new_node, **{**kwargs, "recurse": False})
else:
# one let variable per domain that is shared by more than one element
let_vars = {}
for domain, inner_field_args in field_args_by_domain.items():
if len(inner_field_args) > 1:
var = self.uids.sequential_id(prefix="__fasfop")
fused_args = im.op_as_fieldop(lambda *args: im.make_tuple(*args), domain)(
*(arg for _, arg in inner_field_args)
)
# NOTE(review): `arg` is a leftover binding from an earlier loop, not the
# `fused_args` expression created just above; possibly `fused_args` was meant
# here -- confirm.
type_inference.reinfer(arg)
# don't recurse into nested args, but only consider newly created `as_fieldop`
# note: this will always inline (as we inline center accessed)
let_vars[var] = self.visit(fused_args, **{**kwargs, "recurse": False})
# replace each fused element by a `tuple_get` into the let variable
for outer_tuple_idx, (inner_tuple_idx, _) in enumerate(inner_field_args):
new_el = im.tuple_get(outer_tuple_idx, var)
new_el.annex.domain = domain_utils.SymbolicDomain.from_expr(domain)
new_els[inner_tuple_idx] = new_el
else:
# the only element on this domain: nothing to fuse it with, keep it unchanged
i, arg = inner_field_args[0]
new_els[i] = arg
assert not any(el is None for el in new_els)
assert let_vars
new_node = im.let(*let_vars.items())(im.make_tuple(*new_els))
new_node = inline_lambdas.inline_lambda(new_node, opcount_preserving=True)
return new_node
return None

def transform_fuse_as_fieldop(self, node: itir.Node, **kwargs):
    """
    Fuse eligible arguments of an applied `as_fieldop` into the call itself.

    Returns `None` when `node` is not an applied `as_fieldop` or when none of its
    arguments satisfies `_arg_inline_predicate`. Otherwise the result of
    `fuse_as_fieldop` is visited again (without recursing into children) and returned.
    """
    if not cpm.is_applied_as_fieldop(node):
        return None

    node = _canonicalize_as_fieldop(node)
    stencil = node.fun.args[0]  # type: ignore[attr-defined]  # ensured by cpm.is_applied_as_fieldop
    assert isinstance(stencil, itir.Lambda) or cpm.is_call_to(stencil, "scan")
    args: list[itir.Expr] = node.args
    shifts = trace_shifts.trace_stencil(stencil, num_args=len(args))

    # one flag per argument, computed from the argument and the shifts traced for it
    inline_mask = [
        _arg_inline_predicate(candidate, candidate_shifts)
        for candidate, candidate_shifts in zip(args, shifts, strict=True)
    ]
    if not any(inline_mask):
        return None

    fused = fuse_as_fieldop(node, inline_mask, uids=self.uids)
    return self.visit(fused, **{**kwargs, "recurse": False})

def transform_inline_let_vars_opcount_preserving(self, node: itir.Node, **kwargs):
    """
    Try to inline the variables of a let expression with `opcount_preserving=True`.

    When multiple `as_fieldop` calls that use the same argument are fused, that argument
    might end up being referenced only once. Inlining it here allows fusing to continue.

    Returns the re-visited node if `inline_lambda` produced a new node, `None` otherwise.
    """
    if not cpm.is_let(node):
        return None

    inlined = inline_lambdas.inline_lambda(node, opcount_preserving=True)
    if inlined is node:
        # `inline_lambda` handed back the very same node, i.e. nothing has been inlined
        return None
    return self.visit(inlined, **kwargs)

def generic_visit(self, node, **kwargs):
    """
    Visit the children of `node`, with two exceptions.

    For an applied `as_fieldop` only the arguments are visited and the call is rebuilt
    around them, so the stencil is never descended into. For any other node, children
    are left untouched when the `recurse` keyword argument is falsy (it defaults to `True`).
    """
    if cpm.is_applied_as_fieldop(node):
        # rebuild the `as_fieldop` from its original fun args, then apply it to the
        # visited arguments
        rebuilt_fun = im.as_fieldop(*node.fun.args)
        return rebuilt_fun(*self.generic_visit(node.args, **kwargs))

    # TODO(tehrengruber): This is a common pattern that should be absorbed in
    #  `FixedPointTransformation`.
    if not kwargs.get("recurse", True):
        return node
    return super().generic_visit(node, **kwargs)

def visit(self, node, **kwargs):
if not kwargs.get("within_set_at_expr"):
return node

# inline all fields with list dtype. This needs to happen before the children are visited
# such that the `as_fieldop` can be fused.
# TODO(tehrengruber): what should we do in case the field with list dtype is a let itself?
# This could duplicate other expressions which we did not intend to duplicate.
# TODO(tehrengruber): This should be moved into a `transform_` method, but
# `FixedPointTransformation` does not support pre-order transformations yet.
if cpm.is_let(node):
for arg in node.args:
type_inference.reinfer(arg)
Expand All @@ -353,93 +476,9 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs):
node = inline_lambdas.inline_lambda(node, eligible_params=eligible_els)
return self.visit(node, **kwargs)

if cpm.is_applied_as_fieldop(node): # don't descend in stencil
node = im.as_fieldop(*node.fun.args)(*self.generic_visit(node.args, **kwargs)) # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop
elif kwargs.get("recurse", True):
node = self.generic_visit(node, **kwargs)
node = super().visit(node, **kwargs)

if cpm.is_call_to(node, "make_tuple"):
for arg in node.args:
type_inference.reinfer(arg)
assert not isinstance(arg.type, ts.FieldType) or (
hasattr(arg.annex, "domain")
and isinstance(arg.annex.domain, domain_utils.SymbolicDomain)
)

eligible_els = [_make_tuple_element_inline_predicate(arg) for arg in node.args]
field_args = [arg for i, arg in enumerate(node.args) if eligible_els[i]]
distinct_domains = set(arg.annex.domain.as_expr() for arg in field_args)
if len(distinct_domains) != len(field_args):
new_els: list[itir.Expr | None] = [None for _ in node.args]
field_args_by_domain: dict[itir.FunCall, list[tuple[int, itir.Expr]]] = {}
for i, arg in enumerate(node.args):
if eligible_els[i]:
assert isinstance(arg.annex.domain, domain_utils.SymbolicDomain)
domain = arg.annex.domain.as_expr()
field_args_by_domain.setdefault(domain, [])
field_args_by_domain[domain].append((i, arg))
else:
new_els[i] = arg # keep as is

if len(field_args_by_domain) == 1 and all(eligible_els):
# if we only have a single domain covering all args we don't need to create an
# unnecessary let
((domain, inner_field_args),) = field_args_by_domain.items()
new_node = im.op_as_fieldop(lambda *args: im.make_tuple(*args), domain)(
*(arg for _, arg in inner_field_args)
)
new_node = self.visit(new_node, **{**kwargs, "recurse": False})
else:
let_vars = {}
for domain, inner_field_args in field_args_by_domain.items():
if len(inner_field_args) > 1:
var = self.uids.sequential_id(prefix="__fasfop")
fused_args = im.op_as_fieldop(
lambda *args: im.make_tuple(*args), domain
)(*(arg for _, arg in inner_field_args))
type_inference.reinfer(arg)
# don't recurse into nested args, but only consider newly created `as_fieldop`
# note: this will always inline (as we inline center accessed)
let_vars[var] = self.visit(fused_args, **{**kwargs, "recurse": False})
for outer_tuple_idx, (inner_tuple_idx, _) in enumerate(
inner_field_args
):
new_el = im.tuple_get(outer_tuple_idx, var)
new_el.annex.domain = domain_utils.SymbolicDomain.from_expr(domain)
new_els[inner_tuple_idx] = new_el
else:
i, arg = inner_field_args[0]
new_els[i] = arg
assert not any(el is None for el in new_els)
assert let_vars
new_node = im.let(*let_vars.items())(im.make_tuple(*new_els))
new_node = inline_lambdas.inline_lambda(new_node, opcount_preserving=True)
return new_node

if cpm.is_call_to(node.fun, "as_fieldop"):
node = _canonicalize_as_fieldop(node)

# when multiple `as_fieldop` calls are fused that use the same argument, this argument
# might become referenced once only. In order to be able to continue fusing such arguments
# try inlining here.
if cpm.is_let(node):
new_node = inline_lambdas.inline_lambda(node, opcount_preserving=True)
if new_node is not node: # nothing has been inlined
return self.visit(new_node, **kwargs)

if cpm.is_call_to(node.fun, "as_fieldop"):
stencil = node.fun.args[0]
assert isinstance(stencil, itir.Lambda) or cpm.is_call_to(stencil, "scan")
args: list[itir.Expr] = node.args
shifts = trace_shifts.trace_stencil(stencil, num_args=len(args))
if isinstance(node, itir.Expr) and hasattr(node.annex, "domain"):
node.annex.domain = node.annex.domain

eligible_els = [
_arg_inline_predicate(arg, arg_shifts)
for arg, arg_shifts in zip(args, shifts, strict=True)
]
if any(eligible_els):
return self.visit(
fuse_as_fieldop(node, eligible_els, uids=self.uids),
**{**kwargs, "recurse": False},
)
return node
Original file line number Diff line number Diff line change
Expand Up @@ -345,12 +345,12 @@ def test_inline_as_fieldop_with_list_dtype():
dims=[IDim], dtype=ts.ListType(element_type=ts.ScalarType(kind=ts.ScalarKind.INT32))
)
d = im.domain("cartesian_domain", {IDim: (0, 1)})
testee = im.as_fieldop(im.lambda_("inp")(im.call("reduce")(im.deref("inp"), 0)), d)(
im.as_fieldop("deref")(im.ref("inp", list_field_type))
)
expected = im.as_fieldop(im.lambda_("inp")(im.call("reduce")(im.deref("inp"), 0)), d)(
im.ref("inp", list_field_type)
)
testee = im.as_fieldop(
im.lambda_("inp")(im.call(im.call("reduce")("plus", 0))(im.deref("inp"))), d
)(im.as_fieldop("deref")(im.ref("inp", list_field_type)))
expected = im.as_fieldop(
im.lambda_("inp")(im.call(im.call("reduce")("plus", 0))(im.deref("inp"))), d
)(im.ref("inp", list_field_type))
actual = fuse_as_fieldop.FuseAsFieldOp.apply(
testee, offset_provider_type={}, allow_undeclared_symbols=True
)
Expand Down

0 comments on commit bc71808

Please sign in to comment.