diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index 2b1607ad..16953062 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -268,7 +268,7 @@ def thread_exception_hook(args): self.thread.start() def export( - self, onnx_save_path: str | None = None, *, weights: str | None = None + self, onnx_save_path: str | None = None, *, weights: str | Path | None = None ) -> None: """Runs export. @@ -429,7 +429,6 @@ def _objective(trial: optuna.trial.Trial) -> float: for a in cfg_copy.trainer.preprocessing.augmentations if a.name != "Normalize" ] # manually remove Normalize so it doesn't duplicate it when creating new cfg instance - Config.clear_instance() cfg = Config.get_config(cfg_copy.model_dump(), curr_params) child_tracker.log_hyperparams(curr_params) diff --git a/luxonis_train/utils/loaders/base_loader.py b/luxonis_train/utils/loaders/base_loader.py index e18d7f5e..5e884955 100644 --- a/luxonis_train/utils/loaders/base_loader.py +++ b/luxonis_train/utils/loaders/base_loader.py @@ -103,7 +103,7 @@ def get_classes(self) -> dict[str, list[str]]: @rtype: dict[LabelType, list[str]] @return: A dictionary mapping tasks to their classes. """ - pass + ... def get_n_keypoints(self) -> dict[str, int] | None: """Returns the dictionary defining the semantic skeleton for each class using diff --git a/luxonis_train/utils/loaders/luxonis_loader_torch.py b/luxonis_train/utils/loaders/luxonis_loader_torch.py index 4a8b505e..328f87be 100644 --- a/luxonis_train/utils/loaders/luxonis_loader_torch.py +++ b/luxonis_train/utils/loaders/luxonis_loader_torch.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import Literal, cast +from typing import Literal import numpy as np from luxonis_ml.data import ( @@ -158,13 +158,10 @@ def _parse_dataset( logger.info(f"Parsing dataset from {dataset_dir} with name '{dataset_name}'") - return cast( - LuxonisDataset, - LuxonisParser( - dataset_dir, - dataset_name=dataset_name, - dataset_type=dataset_type, - save_dir="data", - delete_existing=True, - ).parse(), - ) + return LuxonisParser( + dataset_dir, + dataset_name=dataset_name, + dataset_type=dataset_type, + save_dir="data", + delete_existing=True, + ).parse() diff --git a/luxonis_train/utils/types.py b/luxonis_train/utils/types.py index 375ab565..84b8e019 100644 --- a/luxonis_train/utils/types.py +++ b/luxonis_train/utils/types.py @@ -45,15 +45,7 @@ def from_missing_task(cls, task: str, present_tasks: list[str], class_name: str) class BaseProtocol(BaseModel): class Config: arbitrary_types_allowed = True - - @classmethod - def get_task(cls) -> str: - if len(cls.__annotations__) == 1: - return list(cls.__annotations__)[0] - raise ValueError( - "Protocol must have exactly one field for automatic task inference. " - "Implement custom `prepare` method in your attached module." - ) + extra = "forbid" class FeaturesProtocol(BaseProtocol): diff --git a/tests/integration/test_sanity.py b/tests/integration/test_sanity.py index cf7af8aa..5afa385b 100644 --- a/tests/integration/test_sanity.py +++ b/tests/integration/test_sanity.py @@ -10,7 +10,6 @@ from multi_input_modules import * from luxonis_train.core import LuxonisModel -from luxonis_train.utils.config import Config TEST_OUTPUT = Path("tests/integration/_test-output") INFER_PATH = Path("tests/integration/_infer_save_dir") @@ -35,7 +34,6 @@ def manage_out_dir(): @pytest.fixture(scope="function", autouse=True) def clear_files(): - Config.clear_instance() yield STUDY_PATH.unlink(missing_ok=True) ONNX_PATH.unlink(missing_ok=True)