From 7f12681b26b5969768b279546a9fd63b87972f01 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Wed, 25 Sep 2024 10:10:42 +0200 Subject: [PATCH] fixed tests --- tests/integration/conftest.py | 2 +- tests/integration/test_detection.py | 2 +- tests/integration/test_segmentation.py | 2 +- tests/unittests/test_base_node.py | 2 +- tests/unittests/test_utils/test_boxutils.py | 11 ++++++----- 5 files changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index ef5a2142..97189476 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -248,7 +248,7 @@ def CIFAR10_subset_generator(): @pytest.fixture def config(train_overfit: bool) -> dict[str, Any]: - if train_overfit: + if train_overfit: # pragma: no cover epochs = 100 else: epochs = 1 diff --git a/tests/integration/test_detection.py b/tests/integration/test_detection.py index fb184b6f..0ae63e6d 100644 --- a/tests/integration/test_detection.py +++ b/tests/integration/test_detection.py @@ -78,7 +78,7 @@ def train_and_test( model = LuxonisModel(config, opts) model.train() results = model.test(view="val") - if train_overfit: + if train_overfit: # pragma: no cover for name, value in results.items(): if "/map_50" in name or "/kpt_map_medium" in name: assert value > 0.8, f"{name} = {value} (expected > 0.8)" diff --git a/tests/integration/test_segmentation.py b/tests/integration/test_segmentation.py index c24e6fb9..72db55eb 100644 --- a/tests/integration/test_segmentation.py +++ b/tests/integration/test_segmentation.py @@ -117,7 +117,7 @@ def train_and_test( model = LuxonisModel(config, opts) model.train() results = model.test(view="val") - if train_overfit: + if train_overfit: # pragma: no cover for name, value in results.items(): if "metric" in name: assert value > 0.8, f"{name} = {value} (expected > 0.8)" diff --git a/tests/unittests/test_base_node.py b/tests/unittests/test_base_node.py index 79d9ec49..3ed284c3 100644 --- a/tests/unittests/test_base_node.py +++ b/tests/unittests/test_base_node.py @@ -76,7 +76,7 @@ def test_invalid(packet: Packet[Tensor]): node.wrap({"inp": torch.rand(3, 224, 224)}) -def tets_in_sizes(): +def test_in_sizes(): node = DummyNode( input_shapes=[{"features": [Size((3, 224, 224)) for _ in range(3)]}] ) diff --git a/tests/unittests/test_utils/test_boxutils.py b/tests/unittests/test_utils/test_boxutils.py index 2b05a428..938ddc3e 100644 --- a/tests/unittests/test_utils/test_boxutils.py +++ b/tests/unittests/test_utils/test_boxutils.py @@ -1,3 +1,5 @@ +from typing import Literal + import pytest import torch @@ -13,7 +15,10 @@ def generate_random_bboxes( - n_bboxes: int, max_width: int, max_height: int, format: str = "xyxy" + n_bboxes: int, + max_width: int, + max_height: int, + format: Literal["xyxy", "xywh", "cxcywh"], ): x1y1 = torch.rand(n_bboxes, 2) * torch.tensor( [max_width - 1, max_height - 1] @@ -33,10 +38,6 @@ def generate_random_bboxes( elif format == "cxcywh": cxcy = x1y1 + wh / 2 bboxes = torch.cat((cxcy, wh), dim=1) - else: - raise ValueError( - "Unsupported format. Choose from 'xyxy', 'xywh', 'cxcywh'." - ) return bboxes