diff --git a/src/gt4py/next/iterator/transforms/inline_lifts.py b/src/gt4py/next/iterator/transforms/inline_lifts.py index 8d36e5f004..aa3f1fe3dc 100644 --- a/src/gt4py/next/iterator/transforms/inline_lifts.py +++ b/src/gt4py/next/iterator/transforms/inline_lifts.py @@ -83,7 +83,7 @@ def _transform_and_extract_lift_args( sym = ir.Sym(id=arg.id) assert sym not in extracted_args or extracted_args[sym] == arg extracted_args[sym] = arg - new_args.append(arg) + new_args.append(arg) # TODO: type? else: new_symbol = _generate_unique_symbol( desired_name=(inner_stencil, i), @@ -92,7 +92,7 @@ def _transform_and_extract_lift_args( ) assert new_symbol not in extracted_args extracted_args[new_symbol] = arg - new_args.append(ir.SymRef(id=new_symbol.id)) + new_args.append(ir.SymRef(id=new_symbol.id)) # TODO: type? itir_node = im.lift(inner_stencil)(*new_args) itir_node.location = node.location @@ -159,6 +159,7 @@ def visit_FunCall( if self.flags & self.Flag.PROPAGATE_SHIFT and _is_shift_lift(node): shift = node.fun + shift.type = None assert len(node.args) == 1 lift_call = node.args[0] new_args = [ diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 19ab3ecdda..131b773dd2 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -385,25 +385,30 @@ def apply_shift( assert isinstance(it, it_ts.IteratorType) if it.position_dims == "unknown": # nothing to do here return it - new_position_dims = [*it.position_dims] - assert len(offset_literals) % 2 == 0 - for offset_axis, _ in zip(offset_literals[:-1:2], offset_literals[1::2], strict=True): - assert isinstance(offset_axis, it_ts.OffsetLiteralType) and isinstance( - offset_axis.value, common.Dimension - ) - type_ = offset_provider_type[offset_axis.value.value] - if isinstance(type_, common.Dimension): - pass - elif isinstance(type_, common.NeighborConnectivityType): - found = False - for i, dim in enumerate(new_position_dims): - if dim.value == type_.source_dim.value: - assert not found - new_position_dims[i] = type_.codomain - found = True - assert found - else: - raise NotImplementedError(f"{type_} is not a supported Connectivity type.") + new_position_dims: list[common.Dimension] | str + if offset_provider_type: + new_position_dims = [*it.position_dims] + assert len(offset_literals) % 2 == 0 + for offset_axis, _ in zip(offset_literals[:-1:2], offset_literals[1::2], strict=True): + assert isinstance(offset_axis, it_ts.OffsetLiteralType) and isinstance( + offset_axis.value, common.Dimension + ) + type_ = offset_provider_type[offset_axis.value.value] + if isinstance(type_, common.Dimension): + pass + elif isinstance(type_, common.NeighborConnectivityType): + found = False + for i, dim in enumerate(new_position_dims): + if dim.value == type_.source_dim.value: + assert not found + new_position_dims[i] = type_.codomain + found = True + assert found + else: + raise NotImplementedError(f"{type_} is not a supported Connectivity type.") + else: + # during re-inference we don't have an offset provider type + new_position_dims = "unknown" return it_ts.IteratorType( position_dims=new_position_dims, defined_dims=it.defined_dims,