Skip to content

Commit

Permalink
feat: chunked_xr_dataset()
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed Mar 18, 2024
1 parent 09f2697 commit 037db39
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 53 deletions.
61 changes: 61 additions & 0 deletions dmlcloud/util/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from typing import Iterable
import numpy as np
import xarray as xr
import torch.distributed as dist


def shard_indices(
    n: int,
    rank: int,
    world_size: int,
    shuffle: bool = False,
    drop_remainder: bool = True,
    seed: int = 0,
) -> list[int]:
    """Return the subset of ``range(n)`` that belongs to ``rank``.

    The indices ``0 .. n-1`` are optionally shuffled, optionally truncated to a
    multiple of ``world_size`` and then dealt out round-robin: ``rank`` receives
    every ``world_size``-th index, starting at position ``rank``.

    Args:
        n: Total number of elements to distribute.
        rank: Index of the calling worker, ``0 <= rank < world_size``.
        world_size: Number of workers the indices are split across.
        shuffle: If true, permute the indices before sharding. The permutation
            is produced by an MT19937 generator seeded with ``seed``, so calls
            with the same ``n`` and ``seed`` shuffle identically.
        drop_remainder: If true, discard the trailing ``n % world_size`` indices
            (after shuffling) so that every rank gets the same number of items.
        seed: Seed for the shuffle; ignored when ``shuffle`` is false.

    Returns:
        The indices assigned to ``rank`` as a list of plain Python ``int``.

    Raises:
        ValueError: If ``world_size`` is smaller than 1 or ``rank`` is outside
            ``[0, world_size)``.
    """
    # Without these checks an out-of-range rank would silently slice a wrong
    # (possibly overlapping) shard, and world_size == 0 would fail obscurely.
    if world_size < 1:
        raise ValueError(f'world_size must be >= 1, got {world_size}')
    if not 0 <= rank < world_size:
        raise ValueError(f'rank must satisfy 0 <= rank < world_size ({world_size}), got {rank}')

    indices = np.arange(n)

    if shuffle:
        np.random.Generator(np.random.MT19937(seed)).shuffle(indices)

    if drop_remainder:
        indices = indices[: n - n % world_size]

    # tolist() also converts np.int64 to python's int
    return indices[rank::world_size].tolist()


def chunked_xr_dataset(
    ds: xr.Dataset | xr.DataArray,
    chunk_size: int,
    dim: str,
    shuffle: bool = False,
    drop_remainder: bool = True,
    seed: int = 0,
    rank: int | None = None,
    world_size: int | None = None,
    process_group: dist.ProcessGroup | None = None,
    load: bool = True,
) -> Iterable[xr.Dataset | xr.DataArray]:
    """Iterate over this rank's share of fixed-size chunks of ``ds`` along ``dim``.

    ``ds`` is split along ``dim`` into ``len(ds[dim]) // chunk_size`` consecutive
    chunks of exactly ``chunk_size`` elements; trailing elements that do not fill
    a complete chunk are never yielded. The chunk numbers are then distributed
    over the workers with ``shard_indices`` and the chunks belonging to ``rank``
    are yielded one at a time.

    This is a generator function: nothing, including argument validation and
    the rank / world-size lookup, runs before iteration starts.

    Args:
        ds: Dataset or data array to split.
        chunk_size: Number of elements along ``dim`` in every chunk; must be
            at least 1.
        dim: Name of the dimension to chunk along.
        shuffle: Forwarded to ``shard_indices``; shuffles the order of the
            chunks, not the elements inside a chunk.
        drop_remainder: Forwarded to ``shard_indices``; drops surplus chunks so
            the number of chunks is a multiple of ``world_size``.
        seed: Forwarded to ``shard_indices`` as the shuffle seed.
        rank: Rank of the calling worker. If ``None``,
            ``dist.get_rank(process_group)`` is used.
        world_size: Number of workers. If ``None``,
            ``dist.get_world_size(process_group)`` is used.
        process_group: Process group passed to the ``torch.distributed``
            lookups above; only used when ``rank`` or ``world_size`` is ``None``.
        load: If true, call ``.load()`` on each chunk before yielding it.

    Yields:
        The selected chunks, each obtained via ``ds.isel`` with a slice of
        length ``chunk_size`` along ``dim``.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
    """
    # chunk_size == 0 would divide by zero below, and a negative value would
    # produce a negative chunk count and silently yield nothing.
    if chunk_size < 1:
        raise ValueError(f'chunk_size must be >= 1, got {chunk_size}')

    num_total_elements = len(ds[dim])
    num_chunks = num_total_elements // chunk_size

    if rank is None:
        rank = dist.get_rank(process_group)
    if world_size is None:
        world_size = dist.get_world_size(process_group)

    chunk_indices = shard_indices(
        num_chunks,
        rank,
        world_size,
        shuffle=shuffle,
        drop_remainder=drop_remainder,
        seed=seed,
    )

    for chunk_idx in chunk_indices:
        start = chunk_idx * chunk_size
        end = start + chunk_size
        chunk = ds.isel({dim: slice(start, end)})
        if load:
            chunk.load()
        yield chunk
12 changes: 0 additions & 12 deletions dmlcloud/util/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,18 +80,6 @@ def print_worker(msg, barrier=True, flush=True):
dist.barrier()


def shard_indices(n, rank, size, shuffle=True, drop_remainder=False, seed=0):
    """Return the entries of ``np.arange(n)`` that belong to worker ``rank``.

    The indices are optionally shuffled with an MT19937 generator seeded with
    ``seed``, optionally truncated to a multiple of ``size``, and then dealt
    out round-robin (``rank`` gets every ``size``-th entry starting at
    position ``rank``).

    Args:
        n: Total number of elements to distribute.
        rank: Index of the calling worker, ``0 <= rank < size``.
        size: Number of workers.
        shuffle: Whether to permute the indices before sharding.
        drop_remainder: Whether to drop the trailing ``n % size`` indices so
            every worker receives the same number of items.
        seed: Seed for the shuffle; ignored when ``shuffle`` is false.

    Returns:
        A NumPy integer array with the indices assigned to ``rank``.

    Raises:
        ValueError: If ``size`` is smaller than 1 or ``rank`` is outside
            ``[0, size)``.
    """
    # Reject inputs that would otherwise slice a wrong shard without any error.
    if size < 1:
        raise ValueError(f'size must be >= 1, got {size}')
    if not 0 <= rank < size:
        raise ValueError(f'rank must satisfy 0 <= rank < size ({size}), got {rank}')

    indices = np.arange(n)

    if shuffle:
        np.random.Generator(np.random.MT19937(seed)).shuffle(indices)

    if drop_remainder:
        indices = indices[: n - n % size]

    return indices[rank::size]


def init_process_group_dummy():
"""
Initializes the process group with a single process.
Expand Down
99 changes: 99 additions & 0 deletions test/test_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import sys

import xarray as xr
import numpy as np
import pytest
from dmlcloud.util.data import shard_indices, chunked_xr_dataset
from numpy.testing import assert_array_equal


class TestSharding:
    """Unit tests for ``shard_indices`` covering types, splits and shuffling."""

    @staticmethod
    def _check(n, world_size, drop_remainder, expected_per_rank):
        """Assert the unshuffled shard of every rank against its expected list."""
        for rank, expected in enumerate(expected_per_rank):
            shard = shard_indices(n, rank, world_size, shuffle=False, drop_remainder=drop_remainder)
            assert shard == expected

    def test_types(self):
        # The result must be a plain list of plain ints, not numpy types.
        shard = shard_indices(10, 0, 2, shuffle=False, drop_remainder=False)
        assert isinstance(shard, list)
        for idx in shard:
            assert isinstance(idx, int)

    def test_even(self):
        self._check(10, 2, False, [[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]])

    def test_uneven(self):
        # Without dropping, the lower ranks receive the surplus elements.
        self._check(10, 3, False, [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]])
        self._check(11, 2, False, [[0, 2, 4, 6, 8, 10], [1, 3, 5, 7, 9]])

    def test_dropping(self):
        # With dropping, every rank ends up with the same number of elements.
        self._check(10, 3, True, [[0, 3, 6], [1, 4, 7], [2, 5, 8]])
        self._check(11, 2, True, [[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]])

    def test_shuffling(self):
        shard = shard_indices(10, 0, 2, shuffle=True, drop_remainder=False, seed=0)
        assert len(shard) == 5
        assert len(set(shard)) == 5  # no duplicates
        assert shard != sorted(shard)  # order was actually permuted
        assert all(0 <= idx <= 9 for idx in shard)


class TestChunking:
    """Unit tests for ``chunked_xr_dataset`` on a 100-element dataset, 3 ranks, chunks of 15."""

    def test_basic(self):
        dataset = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset()
        num_ranks = 3
        size = 15

        per_rank = [
            list(chunked_xr_dataset(dataset, size, 'x', world_size=num_ranks, rank=rank, shuffle=False))
            for rank in range(num_ranks)
        ]

        for chunks in per_rank:
            assert len(chunks) == 2

        assert isinstance(per_rank[0][0], xr.Dataset)

        for chunks in per_rank:
            for chunk in chunks:
                assert chunk.x.size == 15

        # Unshuffled chunks are dealt out round-robin: the chunk at `position`
        # of `rank` is global chunk number `position * num_ranks + rank`.
        for position in range(2):
            for rank in range(num_ranks):
                first = (position * num_ranks + rank) * size
                assert_array_equal(per_rank[rank][position]['var'], np.arange(first, first + size))

    def test_shuffled(self):
        dataset = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset()
        num_ranks = 3
        size = 15

        per_rank = [
            list(chunked_xr_dataset(dataset, size, 'x', world_size=num_ranks, rank=rank, shuffle=True, seed=0))
            for rank in range(num_ranks)
        ]

        for chunks in per_rank:
            assert len(chunks) == 2

        # Together the ranks must cover 0..89 exactly once, but not in order.
        all_chunks = [chunk for chunks in per_rank for chunk in chunks]
        values = xr.concat(all_chunks, dim='x')['var'].values.tolist()
        assert values != list(range(90))
        assert sorted(values) == list(range(90))

        # Shuffling reorders whole chunks; inside a chunk values stay contiguous.
        first_chunk = per_rank[0][0]['var'].values
        low, high = first_chunk[0], first_chunk[-1]
        assert first_chunk.tolist() == list(range(low, high + 1))


# Allow running this test module directly as a script; the pytest result
# becomes the process exit status.
if __name__ == '__main__':
    exit_status = pytest.main([__file__])
    sys.exit(exit_status)
2 changes: 0 additions & 2 deletions test/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import sys

sys.path.insert(0, './')

import pytest
import torch
from dmlcloud.metrics import MetricReducer, MetricTracker, Reduction
Expand Down
39 changes: 0 additions & 39 deletions test/test_sharding.py

This file was deleted.

0 comments on commit 037db39

Please sign in to comment.