Skip to content

Commit

Permalink
feat: keep chunked sharding separate from xarray
Browse files: browse the repository at this point in the history
  • Loading branch information
sehoffmann committed Apr 2, 2024
1 parent 624b8ee commit a9d9615
Showing 1 changed file with 38 additions and 14 deletions.
52 changes: 38 additions & 14 deletions dmlcloud/util/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,31 @@ def shard_indices(
return indices[rank::world_size].tolist() # this also converts np.int64 to python's int


def chunk_and_shard_indices(
    num_elements: int,
    chunk_size: int,
    rank: int,
    world_size: int,
    chunk_overlap: int = 0,
    even_shards: bool = True,
    equal_chunks: bool = True,
    shuffle: bool = False,
    seed: int = 0,
):
    """Partition ``num_elements`` into chunks and return this rank's chunk bounds.

    The chunk indices ``0 .. num_chunks - 1`` are distributed across ranks by
    ``shard_indices``; every index assigned to ``rank`` is then turned into a
    half-open ``(start, end)`` pair.

    Args:
        num_elements: Total number of elements to be chunked.
        chunk_size: Number of elements per chunk (the stride between starts).
        rank: Rank whose chunks are returned; forwarded to ``shard_indices``.
        world_size: Number of ranks; forwarded to ``shard_indices``.
        chunk_overlap: Extra elements added to the end of every chunk.
        even_shards: Forwarded unchanged to ``shard_indices``.
        equal_chunks: If true, a trailing partial chunk is dropped (floor
            division); otherwise it is kept (ceiling division).
        shuffle: Forwarded unchanged to ``shard_indices``.
        seed: Forwarded unchanged to ``shard_indices``.

    Returns:
        A list of ``(start, end)`` tuples with
        ``start = chunk_idx * chunk_size`` and
        ``end = start + chunk_size + chunk_overlap``.

    Note:
        ``end`` is not clamped to ``num_elements``; it can point past the last
        element when ``chunk_overlap > 0`` or when a partial chunk is kept.
    """
    # Adding (chunk_size - 1) before the floor division rounds the count up,
    # which keeps the trailing partial chunk when equal_chunks is disabled.
    padded_total = num_elements if equal_chunks else num_elements + chunk_size - 1
    num_chunks = padded_total // chunk_size

    assigned = shard_indices(
        num_chunks,
        rank,
        world_size,
        shuffle=shuffle,
        even_shards=even_shards,
        seed=seed,
    )

    starts = (chunk_idx * chunk_size for chunk_idx in assigned)
    return [(start, start + chunk_size + chunk_overlap) for start in starts]


def sharded_xr_dataset(
ds: xr.Dataset | xr.DataArray,
dim: str,
Expand All @@ -44,29 +69,28 @@ def sharded_xr_dataset(
load: bool = False,
load_kwargs: dict | None = None,
) -> Iterable[xr.Dataset | xr.DataArray]:
num_total_elements = len(ds[dim])

if equal_chunks:
num_chunks = num_total_elements // chunk_size
else:
num_chunks = (num_total_elements + chunk_size - 1) // chunk_size

if rank is None:
rank = dist.get_rank(process_group)
if world_size is None:
world_size = dist.get_world_size(process_group)

chunk_indices = shard_indices(num_chunks, rank, world_size, shuffle=shuffle, even_shards=even_shards, seed=seed)

for chunk_idx in chunk_indices:
start = chunk_idx * chunk_size
end = start + chunk_size + chunk_overlap
num_elements = len(ds[dim])
chunks = chunk_and_shard_indices(
num_elements,
chunk_size,
rank,
world_size,
chunk_overlap=chunk_overlap,
even_shards=even_shards,
equal_chunks=equal_chunks,
shuffle=shuffle,
seed=seed,
)
for start, end in chunks:
chunk = ds.isel({dim: slice(start, end)})

if load:
kwargs = load_kwargs or {}
chunk.load(**kwargs)

yield chunk


Expand Down

0 comments on commit a9d9615

Please sign in to comment.