Skip to content

Commit

Permalink
add LDF creation and adjust to new nn archive format
Browse files Browse the repository at this point in the history
  • Loading branch information
jkbmrz committed Mar 13, 2024
1 parent 7302204 commit f2536f6
Showing 1 changed file with 68 additions and 29 deletions.
97 changes: 68 additions & 29 deletions tests/unittests/test_core/test_archiver.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,83 @@
import pytest
import torch
import torchvision
import io
import json
import os
import random
import shutil
import tarfile

import cv2
import numpy as np
import onnx
import json
import torch
import torchvision
import yaml
from luxonis_ml.data import LuxonisDataset

import luxonis_train
from luxonis_train.core import Archiver


class TestArchiver:

@classmethod
def setup_class(cls):
"""Create and load all files required for testing."""

luxonis_train_parent_dir = os.path.dirname(
os.path.dirname(luxonis_train.__file__)
)
tmp_test_path = os.path.join(
cls.tmp_path = os.path.join(
luxonis_train_parent_dir,
"tests",
"unittests",
"test_core",
)

# make LDF
os.mkdir(os.path.join(cls.tmp_path, "images"))
cls.ldf_name = "dummyLDF"
labels = ["label1", "label2", "label3"]

def classification_dataset_generator():
for i in range(10):
img = np.random.randint(0, 256, (50, 50, 3), dtype=np.uint8)
img_file_path = os.path.join(cls.tmp_path, "images", f"img{i}.png")
cv2.imwrite(img_file_path, img)
yield {
"file": img_file_path,
"type": "classification",
"value": True,
"class": random.choice(labels),
}

if LuxonisDataset.exists(cls.ldf_name):
print("Deleting existing dataset")
LuxonisDataset(cls.ldf_name).delete_dataset()
dataset = LuxonisDataset(cls.ldf_name)
dataset.add(classification_dataset_generator)
dataset.set_classes(list(labels))
dataset.make_splits()

# make config
config_dict = {
"model": {"name": "dummy", "predefined_model": {"name": "DetectionModel"}},
"dataset": {
"name": "dummyldf"
}, # TODO: set LDF name automatically - choose one random LDF or make a random LDF!
"model": {
"name": "test_model",
"predefined_model": {"name": "ClassificationModel"},
},
"dataset": {"name": cls.ldf_name},
"archiver": {
"archive_name": "tmp_nn_archive",
"archive_save_directory": cls.tmp_path,
},
}
cls.config_path = os.path.join(tmp_test_path, "tmp_config.yaml")
cls.config_path = os.path.join(cls.tmp_path, "tmp_config.yaml")
with open(cls.config_path, "w") as yaml_file:
yaml_str = yaml.dump(config_dict)
yaml_file.write(yaml_str)

# make model
model = torchvision.models.squeezenet1_0(pretrained=False)
cls.model_path = "tmp_squeezenet1_0.onnx"
model = torchvision.models.mobilenet_v2(pretrained=False)
cls.model_path = os.path.join(cls.tmp_path, "tmp_mobilenet_v2.onnx")

n, c, h, w = 1, 3, 224, 224
input_shape = torch.randn(n, c, h, w)
cls.input_names = ["TestInput"]
Expand All @@ -61,43 +97,46 @@ def setup_class(cls):

# load archive files into memory
with tarfile.open(cls.archive_path, mode="r") as tar:
cls.archive_fnames = tar.getnames() # List all the contents of the tar file
cls.archive_fnames = tar.getnames()
for fname in cls.archive_fnames:
f = tar.extractfile(fname)
if fname.endswith(".json"):
json_file = tar.extractfile(fname)
cls.json_dict = json.load(json_file)
cls.json_dict = json.load(f)
elif fname.endswith(".onnx"):
onnx_file = tar.extractfile(fname)
cls.onnx_model = onnx.load(onnx_file)
model_bytes = f.read()
model_io = io.BytesIO(model_bytes)
cls.onnx_model = onnx.load(model_io)

@classmethod
def teardown_class(cls):
"""Remove all created files."""
os.remove(cls.archive_path)
os.remove(cls.config_path)
os.remove(cls.model_path)
LuxonisDataset(cls.ldf_name).delete_dataset()
shutil.rmtree(os.path.join(cls.tmp_path, "images"))

def test_archive_suffix(self):
def test_archive_creation(self):
"""Test if nn_archive was created."""
assert self.archive_path.endswith("tar.gz")
assert os.path.exists(self.archive_path)

def test_archive_suffix(self):
"""Test if nn_archive is compressed using xz option (should be the default
option)."""
assert self.archive_path.endswith("tar.xz")

def test_archive_contents(self):
"""Test if nn_archive consists of exactly one JSON and one ONNX file."""
"""Test if nn_archive consists of exactly one JSON file named config.json and
one ONNX file."""
assert (
len(self.archive_fnames) == 2
and any([fname.endswith(".json") for fname in self.archive_fnames])
and any([fname == "config.json" for fname in self.archive_fnames])
and any([fname.endswith(".onnx") for fname in self.archive_fnames])
)

def test_onnx(self):
"""Test if archived ONNX model is valid."""
try:
onnx.checker.check_model(self.onnx_model, full_check=True)
except onnx.checker.ValidationError as e:
print("The model is invalid: %s" % e)
assert False
else:
assert True
assert onnx.checker.check_model(self.onnx_model, full_check=True) is None

def test_config_inputs(self):
"""Test if archived config inputs are valid."""
Expand Down

0 comments on commit f2536f6

Please sign in to comment.