Skip to content

Commit

Permalink
bug[next]: Fix for GTIR partial type inference (#1840)
Browse files Browse the repository at this point in the history
Co-authored-by: Edoardo Paone <[email protected]>
Co-authored-by: Hannes Vogt <[email protected]>
  • Loading branch information
3 people authored Feb 4, 2025
1 parent ac253b6 commit c17b882
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/gt4py/next/iterator/type_system/type_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,19 @@ def can_deref(it: it_ts.IteratorType | ts.DeferredType) -> ts.ScalarType:


@_register_builtin_type_synthesizer
def if_(pred: ts.ScalarType, true_branch: ts.DataType, false_branch: ts.DataType) -> ts.DataType:
def if_(
pred: ts.ScalarType | ts.DeferredType, true_branch: ts.DataType, false_branch: ts.DataType
) -> ts.DataType:
if isinstance(true_branch, ts.TupleType) and isinstance(false_branch, ts.TupleType):
return tree_map(
collection_type=ts.TupleType,
result_collection_constructor=lambda elts: ts.TupleType(types=[*elts]),
)(functools.partial(if_, pred))(true_branch, false_branch)

assert not isinstance(true_branch, ts.TupleType) and not isinstance(false_branch, ts.TupleType)
assert isinstance(pred, ts.ScalarType) and pred.kind == ts.ScalarKind.BOOL
assert isinstance(pred, ts.DeferredType) or (
isinstance(pred, ts.ScalarType) and pred.kind == ts.ScalarKind.BOOL
)
# TODO(tehrengruber): Enable this or a similar check. In case the true- and false-branch are
# iterators defined on different positions this fails. For the GTFN backend we also don't
# want this, but for roundtrip it is totally fine.
Expand Down

0 comments on commit c17b882

Please sign in to comment.