Skip to content

Commit

Permalink
feature[next]: Runtime check args in is_call_to (#1796)
Browse files Browse the repository at this point in the history
Calling `is_call_to` with the arguments in the wrong order happens
easily. This PR adds a runtime check to avoid this.
  • Loading branch information
tehrengruber authored Jan 21, 2025
1 parent 8eae147 commit 7e566fc
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# SPDX-License-Identifier: BSD-3-Clause

from collections.abc import Iterable
from typing import TypeGuard
from typing import Any, TypeGuard

from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im
Expand Down Expand Up @@ -63,10 +63,14 @@ def is_let(node: itir.Node) -> TypeGuard[itir.FunCall]:
return isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda)


def is_call_to(node: itir.Node, fun: str | Iterable[str]) -> TypeGuard[itir.FunCall]:
def is_call_to(node: Any, fun: str | Iterable[str]) -> TypeGuard[itir.FunCall]:
"""
Match call expression to a given function.
If the `node` argument is not an `itir.Node` the function does not error, but just returns
`False`. This is useful in visitors, where sometimes we pass a list of nodes or a leaf
attribute which can be anything.
>>> from gt4py.next.iterator.ir_utils import ir_makers as im
>>> node = im.call("plus")(1, 2)
>>> is_call_to(node, "plus")
Expand All @@ -76,6 +80,7 @@ def is_call_to(node: itir.Node, fun: str | Iterable[str]) -> TypeGuard[itir.FunC
>>> is_call_to(node, ("plus", "minus"))
True
"""
assert not isinstance(fun, itir.Node) # to avoid accidentally passing the fun as first argument
if isinstance(fun, (list, tuple, set, Iterable)) and not isinstance(fun, str):
return any((is_call_to(node, f) for f in fun))
return (
Expand Down

0 comments on commit 7e566fc

Please sign in to comment.