Skip to content

Commit

Permalink
Fix partial inference for shift
Browse files Browse the repository at this point in the history
  • Loading branch information
tehrengruber committed Feb 4, 2025
1 parent 45d017e commit 7bf42dd
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 21 deletions.
5 changes: 3 additions & 2 deletions src/gt4py/next/iterator/transforms/inline_lifts.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def _transform_and_extract_lift_args(
sym = ir.Sym(id=arg.id)
assert sym not in extracted_args or extracted_args[sym] == arg
extracted_args[sym] = arg
new_args.append(arg)
new_args.append(arg) # TODO: type?
else:
new_symbol = _generate_unique_symbol(
desired_name=(inner_stencil, i),
Expand All @@ -92,7 +92,7 @@ def _transform_and_extract_lift_args(
)
assert new_symbol not in extracted_args
extracted_args[new_symbol] = arg
new_args.append(ir.SymRef(id=new_symbol.id))
new_args.append(ir.SymRef(id=new_symbol.id)) # TODO: type?

itir_node = im.lift(inner_stencil)(*new_args)
itir_node.location = node.location
Expand Down Expand Up @@ -159,6 +159,7 @@ def visit_FunCall(

if self.flags & self.Flag.PROPAGATE_SHIFT and _is_shift_lift(node):
shift = node.fun
shift.type = None
assert len(node.args) == 1
lift_call = node.args[0]
new_args = [
Expand Down
43 changes: 24 additions & 19 deletions src/gt4py/next/iterator/type_system/type_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,25 +385,30 @@ def apply_shift(
assert isinstance(it, it_ts.IteratorType)
if it.position_dims == "unknown": # nothing to do here
return it
new_position_dims = [*it.position_dims]
assert len(offset_literals) % 2 == 0
for offset_axis, _ in zip(offset_literals[:-1:2], offset_literals[1::2], strict=True):
assert isinstance(offset_axis, it_ts.OffsetLiteralType) and isinstance(
offset_axis.value, common.Dimension
)
type_ = offset_provider_type[offset_axis.value.value]
if isinstance(type_, common.Dimension):
pass
elif isinstance(type_, common.NeighborConnectivityType):
found = False
for i, dim in enumerate(new_position_dims):
if dim.value == type_.source_dim.value:
assert not found
new_position_dims[i] = type_.codomain
found = True
assert found
else:
raise NotImplementedError(f"{type_} is not a supported Connectivity type.")
new_position_dims: list[common.Dimension] | str
if offset_provider_type:
new_position_dims = [*it.position_dims]
assert len(offset_literals) % 2 == 0
for offset_axis, _ in zip(offset_literals[:-1:2], offset_literals[1::2], strict=True):
assert isinstance(offset_axis, it_ts.OffsetLiteralType) and isinstance(
offset_axis.value, common.Dimension
)
type_ = offset_provider_type[offset_axis.value.value]
if isinstance(type_, common.Dimension):
pass
elif isinstance(type_, common.NeighborConnectivityType):
found = False
for i, dim in enumerate(new_position_dims):
if dim.value == type_.source_dim.value:
assert not found
new_position_dims[i] = type_.codomain
found = True
assert found
else:
raise NotImplementedError(f"{type_} is not a supported Connectivity type.")
else:
# during re-inference we don't have an offset provider type
new_position_dims = "unknown"
return it_ts.IteratorType(
position_dims=new_position_dims,
defined_dims=it.defined_dims,
Expand Down

0 comments on commit 7bf42dd

Please sign in to comment.