-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #25 from mlexchange:unit_tests
Unit tests of TiledDataset Class
- Loading branch information
Showing
5 changed files
with
104 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# Example for parameters to excecute | ||
|
||
# I/O | ||
io_parameters: | ||
data_tiled_uri: | ||
data_tiled_api_key: | ||
mask_tiled_uri: | ||
mask_tiled_api_key: | ||
seg_tiled_uri: | ||
uid_save: | ||
uid_retrieve: | ||
models_dir: . | ||
|
||
model_parameters: | ||
network: "TUNet" | ||
num_classes: 3 | ||
num_epochs: 3 | ||
optimizer: "Adam" | ||
criterion: "CrossEntropyLoss" | ||
weights: "[1.0, 2.0, 0.5]" | ||
learning_rate: 0.1 | ||
activation: "ReLU" | ||
normalization: "BatchNorm2d" | ||
convolution: "Conv2d" | ||
|
||
qlty_window: 64 | ||
qlty_step: 32 | ||
qlty_border: 8 | ||
|
||
shuffle_train: True | ||
batch_size_train: 1 | ||
|
||
batch_size_val: 1 | ||
|
||
batch_size_inference: 2 | ||
val_pct: 0.2 | ||
|
||
depth: 4 | ||
base_channels: 8 | ||
growth_rate: 2 | ||
hidden_rate: 1 |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,52 @@ | ||
import numpy as np | ||
|
||
from ..tiled_dataset import TiledDataset | ||
|
||
|
||
def test_tiled_dataset(client): | ||
def test_with_mask_training(client): | ||
tiled_dataset = TiledDataset( | ||
client["reconstructions"]["recon1"], | ||
data_tiled_client=client["reconstructions"]["recon1"], | ||
mask_tiled_client=client["uid0001"], | ||
is_training=True, | ||
) | ||
assert tiled_dataset | ||
assert tiled_dataset.mask_idx == [1] | ||
assert len(tiled_dataset) == 1 | ||
assert len(tiled_dataset[0]) == 2 | ||
# Check data | ||
assert tiled_dataset[0][0].shape == (3, 3) | ||
assert not np.all(tiled_dataset[0][0]) # should be all 0s | ||
# Check mask | ||
assert tiled_dataset[0][1].shape == (3, 3) | ||
assert np.all(tiled_dataset[0][1]) # should be all 1s | ||
|
||
|
||
def test_with_mask_inference(client): | ||
tiled_dataset = TiledDataset( | ||
data_tiled_client=client["reconstructions"]["recon1"], | ||
mask_tiled_client=client["uid0001"], | ||
is_training=False, | ||
) | ||
assert tiled_dataset | ||
assert tiled_dataset.mask_idx == [1] | ||
assert len(tiled_dataset) == 1 | ||
# Check data | ||
assert tiled_dataset[0].shape == (3, 3) | ||
assert not np.all(tiled_dataset[0]) # should be all 0s | ||
|
||
|
||
def test_tiled_dataset_with_masks(client): | ||
def test_no_mask_inference(client): | ||
tiled_dataset = TiledDataset( | ||
client["reconstructions"]["recon1"], mask_tiled_client=client["uid0001"] | ||
data_tiled_client=client["reconstructions"]["recon1"], | ||
is_training=False, | ||
) | ||
assert tiled_dataset | ||
assert len(tiled_dataset) == 2 | ||
# Check data | ||
assert tiled_dataset[0].shape == (3, 3) | ||
assert not np.all(tiled_dataset[0]) # should be all 0s | ||
|
||
|
||
# TODO: Test qlty cropping within tiled_dataset. | ||
# Since this part has been moved to the training script and performed outside, | ||
# this is not on higher priority. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# TODO: test general yaml loading, check param type | ||
|
||
# TODO: test model params loading, check pydantic class type / format | ||
|
||
# TODO: check dir creation? How to handle file system change during pytest? | ||
|
||
# TODO: load TiledDataset from fixture client, test already done. | ||
|
||
# TODO: test data and mask array dim and shape | ||
|
||
# TODO: test train_loader and val_loader from crop_split_load func, check length | ||
|
||
# TODO: test build_network. How to deal with lengthy func? test all network options? | ||
|
||
# TODO: test weights and criterion? | ||
|
||
# TODO: test dvc? | ||
|
||
# TODO: test trainer building | ||
|
||
# TODO: test 1 epoch, check param saving |