Skip to content

Commit

Permalink
Merge pull request #25 from mlexchange:unit_tests
Browse files Browse the repository at this point in the history
Unit tests of TiledDataset Class
  • Loading branch information
dylanmcreynolds authored Jun 18, 2024
2 parents e1d57e6 + 1a092c7 commit e76c22a
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@ def client(context):
client = from_context(context)
recons_container = client.create_container("reconstructions")
recons_container.write_array(np.zeros((2, 3, 3), dtype=np.int8), key="recon1")
masks_container = client.create_container("uid0001", metadata={"mask_idx": ["0"]})
masks_container.write_array(np.zeros((1, 3, 3), dtype=np.int8), key="mask")
masks_container = client.create_container("uid0001", metadata={"mask_idx": ["1"]})
masks_container.write_array(np.ones((1, 3, 3), dtype=np.int8), key="mask")
yield client
41 changes: 41 additions & 0 deletions src/_tests/example_tunet.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Example for parameters to excecute

# I/O
io_parameters:
data_tiled_uri:
data_tiled_api_key:
mask_tiled_uri:
mask_tiled_api_key:
seg_tiled_uri:
uid_save:
uid_retrieve:
models_dir: .

model_parameters:
network: "TUNet"
num_classes: 3
num_epochs: 3
optimizer: "Adam"
criterion: "CrossEntropyLoss"
weights: "[1.0, 2.0, 0.5]"
learning_rate: 0.1
activation: "ReLU"
normalization: "BatchNorm2d"
convolution: "Conv2d"

qlty_window: 64
qlty_step: 32
qlty_border: 8

shuffle_train: True
batch_size_train: 1

batch_size_val: 1

batch_size_inference: 2
val_pct: 0.2

depth: 4
base_channels: 8
growth_rate: 2
hidden_rate: 1
Empty file added src/_tests/test_inference.py
Empty file.
44 changes: 40 additions & 4 deletions src/_tests/test_tiled_dataset.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,52 @@
import numpy as np

from ..tiled_dataset import TiledDataset


def test_tiled_dataset(client):
def test_with_mask_training(client):
tiled_dataset = TiledDataset(
client["reconstructions"]["recon1"],
data_tiled_client=client["reconstructions"]["recon1"],
mask_tiled_client=client["uid0001"],
is_training=True,
)
assert tiled_dataset
assert tiled_dataset.mask_idx == [1]
assert len(tiled_dataset) == 1
assert len(tiled_dataset[0]) == 2
# Check data
assert tiled_dataset[0][0].shape == (3, 3)
assert not np.all(tiled_dataset[0][0]) # should be all 0s
# Check mask
assert tiled_dataset[0][1].shape == (3, 3)
assert np.all(tiled_dataset[0][1]) # should be all 1s


def test_with_mask_inference(client):
tiled_dataset = TiledDataset(
data_tiled_client=client["reconstructions"]["recon1"],
mask_tiled_client=client["uid0001"],
is_training=False,
)
assert tiled_dataset
assert tiled_dataset.mask_idx == [1]
assert len(tiled_dataset) == 1
# Check data
assert tiled_dataset[0].shape == (3, 3)
assert not np.all(tiled_dataset[0]) # should be all 0s


def test_tiled_dataset_with_masks(client):
def test_no_mask_inference(client):
tiled_dataset = TiledDataset(
client["reconstructions"]["recon1"], mask_tiled_client=client["uid0001"]
data_tiled_client=client["reconstructions"]["recon1"],
is_training=False,
)
assert tiled_dataset
assert len(tiled_dataset) == 2
# Check data
assert tiled_dataset[0].shape == (3, 3)
assert not np.all(tiled_dataset[0]) # should be all 0s


# TODO: Test qlty cropping within tiled_dataset.
# Since this part has been moved to the training script and performed outside,
# this is not on higher priority.
21 changes: 21 additions & 0 deletions src/_tests/test_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# TODO: test general yaml loading, check param type

# TODO: test model params loading, check pydantic class type / format

# TODO: check dir creation? How to handle file system change during pytest?

# TODO: load TiledDataset from fixture client, test already done.

# TODO: test data and mask array dim and shape

# TODO: test train_loader and val_loader from crop_split_load func, check length

# TODO: test build_network. How to deal with lengthy func? test all network options?

# TODO: test weights and criterion?

# TODO: test dvc?

# TODO: test trainer building

# TODO: test 1 epoch, check param saving

0 comments on commit e76c22a

Please sign in to comment.