From f2536f6e7355adb2f4c5c2bf3a6ef5aaa964d28e Mon Sep 17 00:00:00 2001 From: jkbmrz Date: Wed, 13 Mar 2024 09:05:04 +0100 Subject: [PATCH] add LDF creation and adjust to new nn archive format --- tests/unittests/test_core/test_archiver.py | 97 +++++++++++++++------- 1 file changed, 68 insertions(+), 29 deletions(-) diff --git a/tests/unittests/test_core/test_archiver.py b/tests/unittests/test_core/test_archiver.py index 13225ef6..4cadf200 100644 --- a/tests/unittests/test_core/test_archiver.py +++ b/tests/unittests/test_core/test_archiver.py @@ -1,18 +1,23 @@ -import pytest -import torch -import torchvision +import io +import json import os +import random +import shutil import tarfile + +import cv2 +import numpy as np import onnx -import json +import torch +import torchvision import yaml +from luxonis_ml.data import LuxonisDataset import luxonis_train from luxonis_train.core import Archiver class TestArchiver: - @classmethod def setup_class(cls): """Create and load all files required for testing.""" @@ -20,28 +25,59 @@ def setup_class(cls): luxonis_train_parent_dir = os.path.dirname( os.path.dirname(luxonis_train.__file__) ) - tmp_test_path = os.path.join( + cls.tmp_path = os.path.join( luxonis_train_parent_dir, "tests", "unittests", "test_core", ) + # make LDF + os.mkdir(os.path.join(cls.tmp_path, "images")) + cls.ldf_name = "dummyLDF" + labels = ["label1", "label2", "label3"] + + def classification_dataset_generator(): + for i in range(10): + img = np.random.randint(0, 256, (50, 50, 3), dtype=np.uint8) + img_file_path = os.path.join(cls.tmp_path, "images", f"img{i}.png") + cv2.imwrite(img_file_path, img) + yield { + "file": img_file_path, + "type": "classification", + "value": True, + "class": random.choice(labels), + } + + if LuxonisDataset.exists(cls.ldf_name): + print("Deleting existing dataset") + LuxonisDataset(cls.ldf_name).delete_dataset() + dataset = LuxonisDataset(cls.ldf_name) + dataset.add(classification_dataset_generator) + dataset.set_classes(list(labels)) + dataset.make_splits() + # make config config_dict = { - "model": {"name": "dummy", "predefined_model": {"name": "DetectionModel"}}, - "dataset": { - "name": "dummyldf" - }, # TODO: set LDF name automatically - choose one random LDF or make a random LDF! + "model": { + "name": "test_model", + "predefined_model": {"name": "ClassificationModel"}, + }, + "dataset": {"name": cls.ldf_name}, + "archiver": { + "archive_name": "tmp_nn_archive", + "archive_save_directory": cls.tmp_path, + }, } - cls.config_path = os.path.join(tmp_test_path, "tmp_config.yaml") + cls.config_path = os.path.join(cls.tmp_path, "tmp_config.yaml") with open(cls.config_path, "w") as yaml_file: yaml_str = yaml.dump(config_dict) yaml_file.write(yaml_str) # make model - model = torchvision.models.squeezenet1_0(pretrained=False) - cls.model_path = "tmp_squeezenet1_0.onnx" + model = torchvision.models.mobilenet_v2(pretrained=False) + cls.model_path = os.path.join(cls.tmp_path, "tmp_mobilenet_v2.onnx") + n, c, h, w = 1, 3, 224, 224 input_shape = torch.randn(n, c, h, w) cls.input_names = ["TestInput"] @@ -61,14 +97,15 @@ def setup_class(cls): # load archive files into memory with tarfile.open(cls.archive_path, mode="r") as tar: - cls.archive_fnames = tar.getnames() # List all the contents of the tar file + cls.archive_fnames = tar.getnames() for fname in cls.archive_fnames: + f = tar.extractfile(fname) if fname.endswith(".json"): - json_file = tar.extractfile(fname) - cls.json_dict = json.load(json_file) + cls.json_dict = json.load(f) elif fname.endswith(".onnx"): - onnx_file = tar.extractfile(fname) - cls.onnx_model = onnx.load(onnx_file) + model_bytes = f.read() + model_io = io.BytesIO(model_bytes) + cls.onnx_model = onnx.load(model_io) @classmethod def teardown_class(cls): @@ -76,28 +113,30 @@ def teardown_class(cls): os.remove(cls.archive_path) os.remove(cls.config_path) os.remove(cls.model_path) + LuxonisDataset(cls.ldf_name).delete_dataset() + shutil.rmtree(os.path.join(cls.tmp_path, "images")) - def test_archive_suffix(self): + def test_archive_creation(self): """Test if nn_archive was created.""" - assert self.archive_path.endswith("tar.gz") + assert os.path.exists(self.archive_path) + + def test_archive_suffix(self): + """Test if nn_archive is compressed using xz option (should be the default + option).""" + assert self.archive_path.endswith("tar.xz") def test_archive_contents(self): - """Test if nn_archive consists of exactly one JSON and one ONNX file.""" + """Test if nn_archive consists of exactly one JSON file named config.json and + one ONNX file.""" assert ( len(self.archive_fnames) == 2 - and any([fname.endswith(".json") for fname in self.archive_fnames]) + and any([fname == "config.json" for fname in self.archive_fnames]) and any([fname.endswith(".onnx") for fname in self.archive_fnames]) ) def test_onnx(self): """Test if archived ONNX model is valid.""" - try: - onnx.checker.check_model(self.onnx_model, full_check=True) - except onnx.checker.ValidationError as e: - print("The model is invalid: %s" % e) - assert False - else: - assert True + assert onnx.checker.check_model(self.onnx_model, full_check=True) is None def test_config_inputs(self): """Test if archived config inputs are valid."""