Skip to content

Commit

Permalink
Merge branch 'main' into more_datatypes
Browse files Browse the repository at this point in the history
  • Loading branch information
tehrengruber authored Jan 21, 2025
2 parents f988ea4 + 7e566fc commit 67d0763
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
5 changes: 0 additions & 5 deletions src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,6 @@ def visit_HorizontalExecution(
global_ctx: DaCeIRBuilder.GlobalContext,
iteration_ctx: DaCeIRBuilder.IterationContext,
symbol_collector: DaCeIRBuilder.SymbolCollector,
loop_order,
k_interval,
**kwargs: Any,
):
Expand Down Expand Up @@ -522,7 +521,6 @@ def visit_VerticalLoopSection(
self,
node: oir.VerticalLoopSection,
*,
loop_order,
iteration_ctx: DaCeIRBuilder.IterationContext,
global_ctx: DaCeIRBuilder.GlobalContext,
symbol_collector: DaCeIRBuilder.SymbolCollector,
Expand All @@ -546,7 +544,6 @@ def visit_VerticalLoopSection(
iteration_ctx=iteration_ctx,
global_ctx=global_ctx,
symbol_collector=symbol_collector,
loop_order=loop_order,
k_interval=node.interval,
**kwargs,
)
Expand Down Expand Up @@ -723,7 +720,6 @@ def _process_loop_item(
scope_nodes,
item: Loop,
*,
global_ctx: DaCeIRBuilder.GlobalContext,
iteration_ctx: DaCeIRBuilder.IterationContext,
symbol_collector: DaCeIRBuilder.SymbolCollector,
**kwargs: Any,
Expand Down Expand Up @@ -840,7 +836,6 @@ def visit_VerticalLoop(
sections = flatten_list(
self.generic_visit(
node.sections,
loop_order=node.loop_order,
global_ctx=global_ctx,
iteration_ctx=iteration_ctx,
symbol_collector=symbol_collector,
Expand Down
9 changes: 7 additions & 2 deletions src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# SPDX-License-Identifier: BSD-3-Clause

from collections.abc import Iterable
from typing import TypeGuard
from typing import Any, TypeGuard

from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im
Expand Down Expand Up @@ -63,10 +63,14 @@ def is_let(node: itir.Node) -> TypeGuard[itir.FunCall]:
return isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda)


def is_call_to(node: itir.Node, fun: str | Iterable[str]) -> TypeGuard[itir.FunCall]:
def is_call_to(node: Any, fun: str | Iterable[str]) -> TypeGuard[itir.FunCall]:
"""
Match call expression to a given function.
If the `node` argument is not an `itir.Node` the function does not error, but just returns
`False`. This is useful in visitors, where sometimes we pass a list of nodes or a leaf
attribute which can be anything.
>>> from gt4py.next.iterator.ir_utils import ir_makers as im
>>> node = im.call("plus")(1, 2)
>>> is_call_to(node, "plus")
Expand All @@ -76,6 +80,7 @@ def is_call_to(node: itir.Node, fun: str | Iterable[str]) -> TypeGuard[itir.FunC
>>> is_call_to(node, ("plus", "minus"))
True
"""
assert not isinstance(fun, itir.Node) # to avoid accidentally passing the fun as first argument
if isinstance(fun, (list, tuple, set, Iterable)) and not isinstance(fun, str):
return any((is_call_to(node, f) for f in fun))
return (
Expand Down

0 comments on commit 67d0763

Please sign in to comment.