From 4f1cf3b920b3a4e7ce37abab22898906160aa7c8 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Tue, 22 Oct 2024 01:00:30 +0200 Subject: [PATCH] Since `indices` may be given as a `set` for `Range`s, and the iteration order of a set is not specified, we should be able to handle that too (e.g., by returning from `offset_by()` a map from indices to new values, instead of just a sequence). --- dace/subsets.py | 15 +++++++++------ tests/subsets_test.py | 10 +++++----- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/dace/subsets.py b/dace/subsets.py index aae20a40d6..e120b18946 100644 --- a/dace/subsets.py +++ b/dace/subsets.py @@ -432,12 +432,14 @@ def data_dims(self): return (sum(1 if (re - rb + 1) != 1 else 0 for rb, re, _ in self.ranges) + sum(1 if ts != 1 else 0 for ts in self.tile_sizes)) - def offset_by(self, off: Collection, negative: bool, indices: Collection): + def offset_by(self, off: Sequence, negative: bool, indices: Collection): assert all(i < len(self.ranges) for i in indices) assert all(i < len(off) for i in indices) mult = -1 if negative else 1 - return Range([(self.ranges[i][0] + mult * off[i], self.ranges[i][1] + mult * off[i], self.ranges[i][2]) - for i in indices]) + return { + i: (self.ranges[i][0] + mult * off[i], self.ranges[i][1] + mult * off[i], self.ranges[i][2]) + for i in indices + } def offset_new(self, other, negative, indices=None): if not isinstance(other, Subset): @@ -447,13 +449,14 @@ def offset_new(self, other, negative, indices=None): other = Indices([other for _ in self.ranges]) if indices is None: indices = set(range(len(self.ranges))) - return self.offset_by(other.min_element(), negative, indices) + new_ranges = self.offset_by(other.min_element(), negative, indices) + return Range([new_ranges[i] for i in sorted(indices)]) def offset(self, other, negative, indices=None): if indices is None: indices = set(range(len(self.ranges))) new_ranges = self.offset_new(other, negative, indices).ranges - for i, r in zip(indices, new_ranges): + for i, r in zip(sorted(indices), new_ranges): self.ranges[i] = r def dims(self): @@ -947,7 +950,7 @@ def strides(self): def absolute_strides(self, global_shape): return [1] * len(self.indices) - def offset_by(self, off: Collection, negative: bool): + def offset_by(self, off: Sequence, negative: bool): assert len(off) <= len(self.indices) mult = -1 if negative else 1 return Indices([self.indices[i] + mult * off for i, off in enumerate(off)]) diff --git a/tests/subsets_test.py b/tests/subsets_test.py index d86910b2f4..41d5ecc93e 100644 --- a/tests/subsets_test.py +++ b/tests/subsets_test.py @@ -17,28 +17,28 @@ def test_range_offset_same_shape(self): # No offset off = [0, 0] rExpect = r0 - self.assertEqual(rExpect, r0.offset_by(off, False, [0, 1])) + self.assertEqual({0: rExpect.ranges[0], 1: rExpect.ranges[1]}, r0.offset_by(off, False, [0, 1])) self.assertEqual(rExpect, r0.offset_new(make_a_range_with_min_elements(off), False, [0, 1])) - self.assertEqual(rExpect, r0.offset_by(off, True, [0, 1])) + self.assertEqual({0: rExpect.ranges[0], 1: rExpect.ranges[1]}, r0.offset_by(off, True, [0, 1])) self.assertEqual(rExpect, r0.offset_new(make_a_range_with_min_elements(off), True, [0, 1])) # Positive offset off = [5, 4] negative = False rExpect = subsets.Range([(10, 10 + n - 1, 1), (9, 9 + m - 1, 1)]) - self.assertEqual(rExpect, r0.offset_by(off, negative, [0, 1])) + self.assertEqual({0: rExpect.ranges[0], 1: rExpect.ranges[1]}, r0.offset_by(off, negative, [0, 1])) self.assertEqual(rExpect, r0.offset_new(make_a_range_with_min_elements(off), negative, [0, 1])) # Only partially rExpect = subsets.Range([(9, 9 + m - 1, 1)]) partInds = [1] - self.assertEqual(rExpect, r0.offset_by(off, negative, partInds)) + self.assertEqual({1: rExpect.ranges[0]}, r0.offset_by(off, negative, partInds)) self.assertEqual(rExpect, r0.offset_new(make_a_range_with_min_elements(off), negative, partInds)) # Negative offset off = [5, 4] negative = True rExpect = subsets.Range([(0, n - 1, 1), (1, 1 + m - 1, 1)]) - self.assertEqual(rExpect, r0.offset_by(off, negative, [0, 1])) + self.assertEqual({0: rExpect.ranges[0], 1: rExpect.ranges[1]}, r0.offset_by(off, negative, [0, 1])) self.assertEqual(rExpect, r0.offset_new(make_a_range_with_min_elements(off), negative, [0, 1])) def test_range_offset_partial_indices(self):