diff --git a/dmlcloud/util/data.py b/dmlcloud/util/data.py
new file mode 100644
index 0000000..69fc51f
--- /dev/null
+++ b/dmlcloud/util/data.py
@@ -0,0 +1,61 @@
+from typing import Iterable
+import numpy as np
+import xarray as xr
+import torch.distributed as dist
+
+
+def shard_indices(
+    n: int,
+    rank: int,
+    world_size: int,
+    shuffle: bool=False,
+    drop_remainder: bool=True,
+    seed: int=0
+) -> list[int]:
+    indices = np.arange(n)
+
+    if shuffle:
+        np.random.Generator(np.random.MT19937(seed)).shuffle(indices)
+
+    if drop_remainder:
+        indices = indices[: n - n % world_size]
+
+    return indices[rank::world_size].tolist()  # this also converts np.int64 to python's int
+
+
+def chunked_xr_dataset(
+    ds: xr.Dataset | xr.DataArray,
+    chunk_size: int,
+    dim: str,
+    shuffle: bool=False,
+    drop_remainder: bool=True,
+    seed: int=0,
+    rank: int|None=None,
+    world_size: int|None=None,
+    process_group: dist.ProcessGroup|None=None,
+    load: bool = True,
+) -> Iterable[xr.Dataset | xr.DataArray]:
+    num_total_elements = len(ds[dim])
+    num_chunks = num_total_elements // chunk_size
+
+    if rank is None:
+        rank = dist.get_rank(process_group)
+    if world_size is None:
+        world_size = dist.get_world_size(process_group)
+
+    chunk_indices = shard_indices(
+        num_chunks,
+        rank,
+        world_size,
+        shuffle=shuffle,
+        drop_remainder=drop_remainder,
+        seed=seed
+    )
+
+    for chunk_idx in chunk_indices:
+        start = chunk_idx * chunk_size
+        end = start + chunk_size
+        chunk = ds.isel({dim: slice(start, end)})
+        if load:
+            chunk.load()
+        yield chunk
\ No newline at end of file
diff --git a/dmlcloud/util/distributed.py b/dmlcloud/util/distributed.py
index 13768fe..5a1bcbb 100644
--- a/dmlcloud/util/distributed.py
+++ b/dmlcloud/util/distributed.py
@@ -80,18 +80,6 @@ def print_worker(msg, barrier=True, flush=True):
         dist.barrier()


-def shard_indices(n, rank, size, shuffle=True, drop_remainder=False, seed=0):
-    indices = np.arange(n)
-
-    if shuffle:
-        np.random.Generator(np.random.MT19937(seed)).shuffle(indices)
-
-    if drop_remainder:
-        indices = indices[: n - n % size]
-
-    return indices[rank::size]
-
-
 def init_process_group_dummy():
     """
     Initializes the process group with a single process.
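Not part of the diff itself, but for context: a minimal sketch of how the new chunked_xr_dataset generator could be consumed on a single rank. The dataset, dimension names, and the chunk/world sizes below are hypothetical; passing rank and world_size explicitly avoids needing an initialized torch.distributed process group.

    import numpy as np
    import xarray as xr
    from dmlcloud.util.data import chunked_xr_dataset

    # Hypothetical dataset: 365 daily time steps of one variable on an 8x8 grid.
    ds = xr.DataArray(
        np.random.rand(365, 8, 8),
        dims=['time', 'lat', 'lon'],
        name='temperature',
    ).to_dataset()

    # Rank 0 of 4 iterates over its own disjoint, contiguous 30-step chunks.
    # Shuffling permutes which chunks a rank receives, not the order inside a chunk.
    for chunk in chunked_xr_dataset(ds, chunk_size=30, dim='time', shuffle=True, seed=42, rank=0, world_size=4):
        batch = chunk['temperature'].values  # numpy array of shape (30, 8, 8)
        # ... feed `batch` into the training step

With 365 time steps and chunk_size=30 there are 12 full chunks, so each of the 4 ranks sees 3 of them; the trailing 5 steps are always dropped.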
diff --git a/test/test_data.py b/test/test_data.py
new file mode 100644
index 0000000..e4fcc3f
--- /dev/null
+++ b/test/test_data.py
@@ -0,0 +1,99 @@
+import sys
+
+import xarray as xr
+import numpy as np
+import pytest
+from dmlcloud.util.data import shard_indices, chunked_xr_dataset
+from numpy.testing import assert_array_equal
+
+
+class TestSharding:
+
+    def test_types(self):
+        indices = shard_indices(10, 0, 2, shuffle=False, drop_remainder=False)
+        assert isinstance(indices, list)
+        assert all(isinstance(i, int) for i in indices)
+
+    def test_even(self):
+        assert shard_indices(10, 0, 2, shuffle=False, drop_remainder=False) == [0, 2, 4, 6, 8]
+        assert shard_indices(10, 1, 2, shuffle=False, drop_remainder=False) == [1, 3, 5, 7, 9]
+
+    def test_uneven(self):
+        assert shard_indices(10, 0, 3, shuffle=False, drop_remainder=False) == [0, 3, 6, 9]
+        assert shard_indices(10, 1, 3, shuffle=False, drop_remainder=False) == [1, 4, 7]
+        assert shard_indices(10, 2, 3, shuffle=False, drop_remainder=False) == [2, 5, 8]
+
+        assert shard_indices(11, 0, 2, shuffle=False, drop_remainder=False) == [0, 2, 4, 6, 8, 10]
+        assert shard_indices(11, 1, 2, shuffle=False, drop_remainder=False) == [1, 3, 5, 7, 9]
+
+    def test_dropping(self):
+        assert shard_indices(10, 0, 3, shuffle=False, drop_remainder=True) == [0, 3, 6]
+        assert shard_indices(10, 1, 3, shuffle=False, drop_remainder=True) == [1, 4, 7]
+        assert shard_indices(10, 2, 3, shuffle=False, drop_remainder=True) == [2, 5, 8]
+
+        assert shard_indices(11, 0, 2, shuffle=False, drop_remainder=True) == [0, 2, 4, 6, 8]
+        assert shard_indices(11, 1, 2, shuffle=False, drop_remainder=True) == [1, 3, 5, 7, 9]
+
+    def test_shuffling(self):
+        indices = shard_indices(10, 0, 2, shuffle=True, drop_remainder=False, seed=0)
+        assert len(indices) == 5
+        assert len(np.unique(indices)) == 5
+        assert indices != list(sorted(indices))
+        assert (np.array(indices) >= 0).all() and (np.array(indices) <= 9).all()
+
+
+class TestChunking:
+
+    def test_basic(self):
+        ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset()
+        world_size = 3
+        chunk_size = 15
+
+        chunks_1 = list(chunked_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=0, shuffle=False))
+        chunks_2 = list(chunked_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=1, shuffle=False))
+        chunks_3 = list(chunked_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=2, shuffle=False))
+
+        assert len(chunks_1) == 2
+        assert len(chunks_2) == 2
+        assert len(chunks_3) == 2
+
+        assert isinstance(chunks_1[0], xr.Dataset)
+
+        assert chunks_1[0].x.size == 15
+        assert chunks_1[1].x.size == 15
+        assert chunks_2[0].x.size == 15
+        assert chunks_2[1].x.size == 15
+        assert chunks_3[0].x.size == 15
+        assert chunks_3[1].x.size == 15
+
+        assert_array_equal(chunks_1[0]['var'], np.arange(0, 15))
+        assert_array_equal(chunks_2[0]['var'], np.arange(15, 30))
+        assert_array_equal(chunks_3[0]['var'], np.arange(30, 45))
+        assert_array_equal(chunks_1[1]['var'], np.arange(45, 60))
+        assert_array_equal(chunks_2[1]['var'], np.arange(60, 75))
+        assert_array_equal(chunks_3[1]['var'], np.arange(75, 90))
+
+
+    def test_shuffled(self):
+        ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset()
+        world_size = 3
+        chunk_size = 15
+
+        chunks_1 = list(chunked_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=0, shuffle=True, seed=0))
+        chunks_2 = list(chunked_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=1, shuffle=True, seed=0))
+        chunks_3 = list(chunked_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=2, shuffle=True, seed=0))
+
+        assert len(chunks_1) == 2
+        assert len(chunks_2) == 2
+        assert len(chunks_3) == 2
+
+        catted = xr.concat(chunks_1 + chunks_2 + chunks_3, dim='x')['var'].values
+        assert catted.tolist() != list(range(90))
+        assert list(sorted(catted.tolist())) == list(range(90))
+
+        chunk = chunks_1[0]['var'].values
+        assert chunk.tolist() == list(range(chunk[0], chunk[-1] + 1))
+
+
+if __name__ == '__main__':
+    sys.exit(pytest.main([__file__]))
diff --git a/test/test_metrics.py b/test/test_metrics.py
index 4f276e8..1deca76 100644
--- a/test/test_metrics.py
+++ b/test/test_metrics.py
@@ -1,7 +1,5 @@
 import sys

-sys.path.insert(0, './')
-
 import pytest
 import torch
 from dmlcloud.metrics import MetricReducer, MetricTracker, Reduction
diff --git a/test/test_sharding.py b/test/test_sharding.py
deleted file mode 100644
index 53a9006..0000000
--- a/test/test_sharding.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import sys
-
-import numpy as np
-import pytest
-from dmlcloud.util.distributed import shard_indices
-from numpy.testing import assert_array_equal
-
-
-class TestSharding:
-    def test_even(self):
-        assert_array_equal(shard_indices(10, 0, 2, shuffle=False, drop_remainder=False), [0, 2, 4, 6, 8])
-        assert_array_equal(shard_indices(10, 1, 2, shuffle=False, drop_remainder=False), [1, 3, 5, 7, 9])
-
-    def test_uneven(self):
-        assert_array_equal(shard_indices(10, 0, 3, shuffle=False, drop_remainder=False), [0, 3, 6, 9])
-        assert_array_equal(shard_indices(10, 1, 3, shuffle=False, drop_remainder=False), [1, 4, 7])
-        assert_array_equal(shard_indices(10, 2, 3, shuffle=False, drop_remainder=False), [2, 5, 8])
-
-        assert_array_equal(shard_indices(11, 0, 2, shuffle=False, drop_remainder=False), [0, 2, 4, 6, 8, 10])
-        assert_array_equal(shard_indices(11, 1, 2, shuffle=False, drop_remainder=False), [1, 3, 5, 7, 9])
-
-    def test_dropping(self):
-        assert_array_equal(shard_indices(10, 0, 3, shuffle=False, drop_remainder=True), [0, 3, 6])
-        assert_array_equal(shard_indices(10, 1, 3, shuffle=False, drop_remainder=True), [1, 4, 7])
-        assert_array_equal(shard_indices(10, 2, 3, shuffle=False, drop_remainder=True), [2, 5, 8])
-
-        assert_array_equal(shard_indices(11, 0, 2, shuffle=False, drop_remainder=True), [0, 2, 4, 6, 8])
-        assert_array_equal(shard_indices(11, 1, 2, shuffle=False, drop_remainder=True), [1, 3, 5, 7, 9])
-
-    def test_shuffling(self):
-        indices = shard_indices(10, 0, 2, shuffle=True, drop_remainder=False, seed=0)
-        assert len(indices) == 5
-        assert len(np.unique(indices)) == 5
-        assert list(indices) != list(sorted(indices))
-        assert (indices >= 0).all() and (indices <= 9).all()
-
-
-if __name__ == '__main__':
-    sys.exit(pytest.main([__file__]))
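The tests above exercise the generator directly; downstream training code would typically wrap it for a DataLoader. A sketch of one such wrapper (hypothetical, not part of this change) using torch.utils.data.IterableDataset, yielding one sample per step along the chunked dimension and assuming ds is an xr.Dataset:

    import torch
    from torch.utils.data import DataLoader, IterableDataset

    from dmlcloud.util.data import chunked_xr_dataset


    class ChunkedXrDataset(IterableDataset):
        # Hypothetical helper, not provided by dmlcloud: streams samples from
        # this rank's shard of contiguous chunks.
        def __init__(self, ds, chunk_size, dim, rank, world_size, seed=0):
            self.ds = ds
            self.chunk_size = chunk_size
            self.dim = dim
            self.rank = rank
            self.world_size = world_size
            self.seed = seed

        def __iter__(self):
            chunks = chunked_xr_dataset(
                self.ds, self.chunk_size, self.dim,
                shuffle=True, seed=self.seed,
                rank=self.rank, world_size=self.world_size,
            )
            for chunk in chunks:
                # Yield individual samples along the chunked dimension as tensors.
                for i in range(chunk.sizes[self.dim]):
                    sample = chunk.isel({self.dim: i})
                    yield {name: torch.as_tensor(da.values) for name, da in sample.data_vars.items()}

    # e.g. DataLoader(ChunkedXrDataset(ds, 30, 'time', rank, world_size), batch_size=16)

Because every rank draws from a disjoint set of chunks, no extra DistributedSampler is needed; re-seeding per epoch (e.g. seed + epoch) would reshuffle the chunk assignment between epochs.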