Skip to content

Commit

Permalink
fix: requirements & formatting
Browse files: browse the repository at this point in the history
  • Loading branch information
sehoffmann committed Mar 18, 2024
1 parent 720d808 commit 1659466
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 33 deletions.
45 changes: 18 additions & 27 deletions dmlcloud/util/data.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
from typing import Iterable

import numpy as np
import xarray as xr
import torch.distributed as dist
import xarray as xr


def shard_indices(
n: int,
rank: int,
world_size: int,
shuffle: bool=False,
drop_remainder: bool=True,
seed: int=0
) -> list[int]:
n: int, rank: int, world_size: int, shuffle: bool = False, drop_remainder: bool = True, seed: int = 0
) -> list[int]:
indices = np.arange(n)

if shuffle:
Expand All @@ -24,32 +20,27 @@ def shard_indices(


def chunked_xr_dataset(
ds: xr.Dataset | xr.DataArray,
chunk_size: int,
dim: str,
shuffle: bool=False,
drop_remainder: bool=True,
seed: int=0,
rank: int|None=None,
world_size: int|None=None,
process_group: dist.ProcessGroup|None=None,
load: bool = True,
) -> Iterable[xr.Dataset | xr.DataArray]:
ds: xr.Dataset | xr.DataArray,
chunk_size: int,
dim: str,
shuffle: bool = False,
drop_remainder: bool = True,
seed: int = 0,
rank: int | None = None,
world_size: int | None = None,
process_group: dist.ProcessGroup | None = None,
load: bool = True,
) -> Iterable[xr.Dataset | xr.DataArray]:
num_total_elements = len(ds[dim])
num_chunks = num_total_elements // chunk_size

if rank is None:
rank = dist.get_rank(process_group)
if world_size is None:
world_size = dist.get_world_size(process_group)

chunk_indices = shard_indices(
num_chunks,
rank,
world_size,
shuffle=shuffle,
drop_remainder=drop_remainder,
seed=seed
num_chunks, rank, world_size, shuffle=shuffle, drop_remainder=drop_remainder, seed=seed
)

for chunk_idx in chunk_indices:
Expand All @@ -58,4 +49,4 @@ def chunked_xr_dataset(
chunk = ds.isel({dim: slice(start, end)})
if load:
chunk.load()
yield chunk
yield chunk
1 change: 0 additions & 1 deletion dmlcloud/util/distributed.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
from contextlib import contextmanager

import numpy as np
import torch.distributed as dist

from .tcp import find_free_port, get_local_ips
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
torch
numpy
xarray
progress_table>=0.1.20,<1.0.0
omegaconf
7 changes: 2 additions & 5 deletions test/test_data.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import sys

import xarray as xr
import numpy as np
import pytest
from dmlcloud.util.data import shard_indices, chunked_xr_dataset
import xarray as xr
from dmlcloud.util.data import chunked_xr_dataset, shard_indices
from numpy.testing import assert_array_equal


class TestSharding:

def test_types(self):
indices = shard_indices(10, 0, 2, shuffle=False, drop_remainder=False)
assert isinstance(indices, list)
Expand Down Expand Up @@ -43,7 +42,6 @@ def test_shuffling(self):


class TestChunking:

def test_basic(self):
ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset()
world_size = 3
Expand Down Expand Up @@ -73,7 +71,6 @@ def test_basic(self):
assert_array_equal(chunks_2[1]['var'], np.arange(60, 75))
assert_array_equal(chunks_3[1]['var'], np.arange(75, 90))


def test_shuffled(self):
ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset()
world_size = 3
Expand Down

0 comments on commit 1659466

Please sign in to comment.