Skip to content

Commit

Permalink
Add unit test and refactor TiledDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanmcreynolds committed Jul 25, 2024
1 parent e76c22a commit 079aba4
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 2 deletions.
57 changes: 57 additions & 0 deletions src/_tests/test_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import numpy as np
import pytest
from tiled.catalog import from_uri
from tiled.client import Context, from_context
from tiled.server.app import build_app

from ..tiled_dataset import TiledDataset


@pytest.fixture
def catalog(tmpdir):
adapter = from_uri(
f"sqlite+aiosqlite:///{tmpdir}/catalog.db",
writable_storage=str(tmpdir),
init_if_not_exists=True,
)
yield adapter


@pytest.fixture
def app(catalog):
app = build_app(catalog)
yield app


@pytest.fixture
def context(app):
with Context.from_app(app) as context:
yield context


@pytest.fixture
def client(context):
"Fixture for tests which only read data"
client = from_context(context)
recons_container = client.create_container("reconstructions")
recons_container.write_array(np.zeros((2, 3, 3), dtype=np.int8), key="recon1")
masks_container = client.create_container("masks", metadata={"mask_idx": ["0"]})
masks_container.write_array(np.zeros((1, 3, 3), dtype=np.int8), key="mask1")
yield client


@pytest.mark.asyncio
async def test_tiled_dataset(client):
tiled_dataset = TiledDataset(
client["reconstructions"]["recon1"],
)
assert tiled_dataset
assert tiled_dataset[0].shape == (3, 3)


@pytest.mark.asyncio
async def test_tiled_dataset_with_masks(client):
tiled_dataset = TiledDataset(
client["reconstructions"]["recon1"], mask_tiled_client=client["masks"]
)
assert tiled_dataset[0].shape == (3, 3)
4 changes: 2 additions & 2 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from tiled.client import from_uri
from torchvision import transforms

from network import build_network
from parameters import (
from .network import build_network
from .parameters import (
IOParameters,
MSDNetParameters,
SMSNetEnsembleParameters,
Expand Down

0 comments on commit 079aba4

Please sign in to comment.