Skip to content

Commit

Permalink
fix: unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed Dec 27, 2024
1 parent f0454e5 commit 974c9ef
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 234 deletions.
2 changes: 1 addition & 1 deletion dmlcloud/core/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def add_callback(self, callback: StageCallback):
self.callbacks.append(callback)

def log(self, name: str, value: Any, reduction: str = 'mean', prefixed: bool = True):
if prefixed:
if prefixed and self.metric_prefix:
name = f'{self.metric_prefix}/{name}'
self.tracker.log(name, value, reduction)

Expand Down
7 changes: 4 additions & 3 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import pytest
from dmlcloud.core.distributed import deinitialize_torch_distributed, init_process_group_dummy
import torch
from dmlcloud.core.distributed import init


@pytest.fixture
def torch_distributed():
init_process_group_dummy()
init(kind='dummy')
yield
deinitialize_torch_distributed()
torch.distributed.destroy_process_group()
2 changes: 1 addition & 1 deletion test/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
import torch
import xarray as xr
from dmlcloud.util.data import interleave_batches, shard_indices, sharded_xr_dataset, ShardedXrDataset
from dmlcloud.data import interleave_batches, shard_indices, sharded_xr_dataset, ShardedXrDataset
from numpy.testing import assert_array_equal
from torch.utils.data import DataLoader, IterableDataset

Expand Down
208 changes: 0 additions & 208 deletions test/test_metrics.py

This file was deleted.

53 changes: 32 additions & 21 deletions test/test_smoke.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,56 @@
import sys

import dmlcloud as dml
import pytest
import torch
from dmlcloud.core.pipeline import TrainingPipeline
from dmlcloud.core.stage import TrainValStage


class DummyDataset(torch.utils.data.Dataset):
def __len__(self):
return 8
return 256

def __getitem__(self, idx):
return torch.randn(10), torch.randint(0, 10, size=(1,)).item()
x = torch.randn(10)
y = x.sum() * 0.1
return x, y


class DummyStage(TrainValStage):
class DummyStage(dml.Stage):
def pre_stage(self):
self.model = torch.nn.Linear(10, 10)
self.pipeline.register_model('linear', self.model)
self.train_dl = torch.utils.data.DataLoader(DummyDataset(), batch_size=32)

self.optimizer = torch.optim.SGD(self.model.parameters(), lr=1e-3)
self.pipeline.register_optimizer('sgd', self.optimizer)
model = torch.nn.Sequential(
torch.nn.Linear(10, 32),
torch.nn.Linear(32, 1),
)
self.model = dml.wrap_ddp(model, self.device)
self.optim = torch.optim.Adam(self.model.parameters(), lr=dml.scale_lr(1e-2))
self.loss = torch.nn.L1Loss()

self.pipeline.register_dataset('train', torch.utils.data.DataLoader(DummyDataset(), batch_size=4))
self.pipeline.register_dataset('val', torch.utils.data.DataLoader(DummyDataset(), batch_size=4))
def run_epoch(self):
for x, y in self.train_dl:
self.optim.zero_grad()

self.loss = torch.nn.CrossEntropyLoss()
x, y = x.to(self.device), y.to(self.device)
output = self.model(x)
loss = self.loss(output[:, 0], y)
loss.backward()

def step(self, batch):
x, y = batch
x, y = x.to(self.device), y.to(self.device)
output = self.model(x)
loss = self.loss(output, y)
return loss
self.optim.step()

self.log('train/loss', loss)


class TestSmoke:
def test_smoke(self, torch_distributed):
pipeline = TrainingPipeline()
pipeline.append_stage(DummyStage(), max_epochs=1)
pipeline.run()
pipe = dml.Pipeline()
stage = DummyStage(epochs=3)
pipe.append(stage)
pipe.run()

assert stage.current_epoch == 3
assert 'train/loss' in stage.history
assert stage.history.last()['train/loss'] < 0.1


if __name__ == '__main__':
Expand Down

0 comments on commit 974c9ef

Please sign in to comment.