Skip to content

Commit

Permalink
fix[next]: reshuffling for fields with non-zero domain start (#1845)
Browse files Browse the repository at this point in the history
  • Loading branch information
havogt authored Feb 3, 2025
1 parent b18bbf9 commit ac253b6
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
5 changes: 4 additions & 1 deletion src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,10 @@ def _reshuffling_premap(
conn_map[dim] = _identity_connectivity(new_domain, dim, cls=type(connectivity))

# Take data
take_indices = tuple(conn_map[dim].ndarray for dim in data.domain.dims)
take_indices = tuple(
conn_map[dim].ndarray - data.domain[dim].unit_range.start # shift to 0-based indexing
for dim in data.domain.dims
)
new_buffer = data._ndarray.__getitem__(take_indices)

return data.__class__.from_array(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,10 +364,11 @@ def test_reshuffling_premap():

ij_field = common._field(
np.asarray([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]),
domain=common.Domain(dims=(I, J), ranges=(UnitRange(0, 3), UnitRange(0, 3))),
domain=common.Domain(dims=(I, J), ranges=(UnitRange(1, 4), UnitRange(2, 5))),
)

max_ij_conn = common._connectivity(
np.fromfunction(lambda i, j: np.maximum(i, j), (3, 3), dtype=int),
np.asarray([[1, 2, 3], [2, 2, 3], [3, 3, 3]], dtype=int),
domain=common.Domain(
dims=ij_field.domain.dims,
ranges=ij_field.domain.ranges,
Expand All @@ -378,7 +379,7 @@ def test_reshuffling_premap():
result = ij_field.premap(max_ij_conn)
expected = common._field(
np.asarray([[0.0, 4.0, 8.0], [3.0, 4.0, 8.0], [6.0, 7.0, 8.0]]),
domain=common.Domain(dims=(I, J), ranges=(UnitRange(0, 3), UnitRange(0, 3))),
domain=common.Domain(dims=(I, J), ranges=(UnitRange(1, 4), UnitRange(2, 5))),
)

assert result.domain == expected.domain
Expand Down

0 comments on commit ac253b6

Please sign in to comment.