From 92502d8044a383a2e9dda47e0d4fff9e06bbe86b Mon Sep 17 00:00:00 2001
From: Nikita
Date: Mon, 27 Jan 2025 17:16:34 +0000
Subject: [PATCH] test: add mnist dataset for ocr recognition testing

---
 tests/integration/conftest.py    | 34 ++++++++++++++++++++++++++++++++
 tests/integration/test_simple.py |  3 +++
 2 files changed, 37 insertions(+)

diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
index ab2fb1e8..3c0b8700 100644
--- a/tests/integration/conftest.py
+++ b/tests/integration/conftest.py
@@ -102,6 +102,40 @@ def CIFAR10_subset_generator():
     return dataset
 
 
+@pytest.fixture(scope="session")
+def mnist_dataset_ocr() -> LuxonisDataset:
+    """Build a small MNIST-based dataset for OCR recognition tests.
+
+    The first 1000 images of the MNIST test split are saved as PNG files
+    and registered with their digit as a single-character text label.
+    """
+    dataset = LuxonisDataset("mnist_test", delete_existing=True)
+    output_folder = WORK_DIR / "mnist"
+    output_folder.mkdir(parents=True, exist_ok=True)
+    mnist_torch = torchvision.datasets.MNIST(
+        root=output_folder, train=False, download=True
+    )
+    n_samples = 1000
+
+    def MNIST_subset_generator():
+        # `range` comes first so `zip` stops without loading an extra sample.
+        for i, (image, label) in zip(range(n_samples), mnist_torch):
+            path = output_folder / f"mnist_{i}.png"
+            image.save(path)
+            # MNIST labels are the digits 0-9, so the text is just the digit.
+            text = str(label)
+            yield {
+                "file": path,
+                "annotation": {
+                    "metadata": {"text": text, "text_length": len(text)},
+                },
+            }
+
+    dataset.add(MNIST_subset_generator())
+    dataset.make_splits()
+    return dataset
+
+
 @pytest.fixture
 def config(train_overfit: bool) -> dict[str, Any]:
     if train_overfit:  # pragma: no cover
diff --git a/tests/integration/test_simple.py b/tests/integration/test_simple.py
index 52f72989..52670d9e 100644
--- a/tests/integration/test_simple.py
+++ b/tests/integration/test_simple.py
@@ -68,12 +68,15 @@ def test_predefined_models(
     config_file: str,
     coco_dataset: LuxonisDataset,
     cifar10_dataset: LuxonisDataset,
+    mnist_dataset_ocr: LuxonisDataset,
 ):
     config_file = f"configs/{config_file}.yaml"
     opts |= {
         "loader.params.dataset_name": (
             cifar10_dataset.identifier
             if "classification" in config_file
+            else mnist_dataset_ocr.identifier
+            if "ocr_recognition" in config_file
             else coco_dataset.identifier
         ),
         "trainer.epochs": 1,