Skip to content

Commit

Permalink
test: add mnist dataset for ocr recognition testing
Browse files Browse the repository at this point in the history
  • Loading branch information
sokovninn committed Jan 27, 2025
1 parent 76cddf5 commit 92502d8
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
41 changes: 41 additions & 0 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,47 @@ def CIFAR10_subset_generator():
return dataset


@pytest.fixture(scope="session")
def mnist_dataset_ocr() -> LuxonisDataset:
dataset = LuxonisDataset("mnist_test", delete_existing=True)
output_folder = WORK_DIR / "mnist"
output_folder.mkdir(parents=True, exist_ok=True)
mnist_torch = torchvision.datasets.MNIST(
root=output_folder, train=False, download=True
)

classes = [
"0",
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
]

def MNIST_subset_generator():
for i, (image, label) in enumerate(mnist_torch):
if i == 1000:
break
path = output_folder / f"mnist_{i}.png"
image.save(path)
print(path, label)
yield {
"file": path,
"annotation": {
"metadata": {"text": classes[label], "text_length": 1},
},
}

dataset.add(MNIST_subset_generator())
dataset.make_splits()
return dataset


@pytest.fixture
def config(train_overfit: bool) -> dict[str, Any]:
if train_overfit: # pragma: no cover
Expand Down
3 changes: 3 additions & 0 deletions tests/integration/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,15 @@ def test_predefined_models(
config_file: str,
coco_dataset: LuxonisDataset,
cifar10_dataset: LuxonisDataset,
mnist_dataset_ocr: LuxonisDataset,
):
config_file = f"configs/{config_file}.yaml"
opts |= {
"loader.params.dataset_name": (
cifar10_dataset.identifier
if "classification" in config_file
else mnist_dataset_ocr.identifier
if "ocr_recognition" in config_file
else coco_dataset.identifier
),
"trainer.epochs": 1,
Expand Down

0 comments on commit 92502d8

Please sign in to comment.