From f00a9b4f1de57970a8bc6e05f3edbb7c4a5d57c4 Mon Sep 17 00:00:00 2001
From: Martin Kozlovsky
Date: Mon, 19 Aug 2024 16:06:55 +0200
Subject: [PATCH] more strict config types

---
 luxonis_train/utils/config.py | 60 +++++++++++++++++++++--------------
 1 file changed, 36 insertions(+), 24 deletions(-)

diff --git a/luxonis_train/utils/config.py b/luxonis_train/utils/config.py
index 9c4a0f2c..dfa427b5 100644
--- a/luxonis_train/utils/config.py
+++ b/luxonis_train/utils/config.py
@@ -1,6 +1,6 @@
 import logging
 import sys
-from typing import Annotated, Any, Literal
+from typing import Annotated, Any, Literal, TypeAlias
 
 from luxonis_ml.data import LabelType
 from luxonis_ml.utils import (
@@ -10,20 +10,32 @@
     LuxonisFileSystem,
 )
 from pydantic import Field, model_validator
+from pydantic.types import FilePath, NonNegativeFloat, NonNegativeInt, PositiveInt
 from typing_extensions import Self
 
 logger = logging.getLogger(__name__)
 
+Params: TypeAlias = dict[str, Any]
+
 
 class AttachedModuleConfig(BaseModelExtraForbid):
     name: str
     attached_to: str
     alias: str | None = None
-    params: dict[str, Any] = {}
+    params: Params = {}
 
 
 class LossModuleConfig(AttachedModuleConfig):
-    weight: float = 1.0
+    weight: NonNegativeFloat = 1.0
+
+    @model_validator(mode="after")
+    def validate_weight(self) -> Self:
+        if self.weight == 0:
+            logger.warning(
+                f"Loss '{self.name}' has weight set to 0. "
+                "This loss will not contribute to the training."
+            )
+        return self
 
 
 class MetricModuleConfig(AttachedModuleConfig):
@@ -32,7 +44,7 @@ class MetricModuleConfig(AttachedModuleConfig):
 
 class FreezingConfig(BaseModelExtraForbid):
     active: bool = False
-    unfreeze_after: int | float | None = None
+    unfreeze_after: NonNegativeInt | NonNegativeFloat | None = None
 
 
 class ModelNodeConfig(BaseModelExtraForbid):
@@ -40,24 +52,24 @@ class ModelNodeConfig(BaseModelExtraForbid):
     alias: str | None = None
     inputs: list[str] = []  # From preceding nodes
     input_sources: list[str] = []  # From data loader
-    params: dict[str, Any] = {}
     freezing: FreezingConfig = FreezingConfig()
     task: str | dict[LabelType, str] | None = None
+    params: Params = {}
 
 
 class PredefinedModelConfig(BaseModelExtraForbid):
     name: str
-    params: dict[str, Any] = {}
     include_nodes: bool = True
     include_losses: bool = True
     include_metrics: bool = True
     include_visualizers: bool = True
+    params: Params = {}
 
 
 class ModelConfig(BaseModelExtraForbid):
     name: str = "model"
     predefined_model: PredefinedModelConfig | None = None
-    weights: str | None = None
+    weights: FilePath | None = None
     nodes: list[ModelNodeConfig] = []
     losses: list[LossModuleConfig] = []
     metrics: list[MetricModuleConfig] = []
@@ -146,7 +158,7 @@ class LoaderConfig(BaseModelExtraForbid):
     train_view: str = "train"
     val_view: str = "val"
     test_view: str = "test"
-    params: dict[str, Any] = {}
+    params: Params = {}
 
 
 class NormalizeAugmentationConfig(BaseModelExtraForbid):
@@ -160,7 +172,7 @@ class NormalizeAugmentationConfig(BaseModelExtraForbid):
 class AugmentationConfig(BaseModelExtraForbid):
     name: str
     active: bool = True
-    params: dict[str, Any] = {}
+    params: Params = {}
 
 
 class PreprocessingConfig(BaseModelExtraForbid):
@@ -192,23 +204,23 @@ def get_active_augmentations(self) -> list[AugmentationConfig]:
 class CallbackConfig(BaseModelExtraForbid):
     name: str
     active: bool = True
-    params: dict[str, Any] = {}
+    params: Params = {}
 
 
 class OptimizerConfig(BaseModelExtraForbid):
     name: str = "Adam"
-    params: dict[str, Any] = {}
+    params: Params = {}
 
 
 class SchedulerConfig(BaseModelExtraForbid):
     name: str = "ConstantLR"
-    params: dict[str, Any] = {}
+    params: Params = {}
 
 
 class TrainerConfig(BaseModelExtraForbid):
     preprocessing: PreprocessingConfig = PreprocessingConfig()
 
-    accelerator: Literal["auto", "cpu", "gpu"] = "auto"
+    accelerator: Literal["auto", "cpu", "gpu", "tpu"] = "auto"
     devices: int | list[int] | str = "auto"
     strategy: Literal["auto", "ddp"] = "auto"
     num_sanity_val_steps: int = 2
@@ -217,17 +229,17 @@
     verbose: bool = True
 
     seed: int | None = None
-    batch_size: int = 32
-    accumulate_grad_batches: int = 1
+    batch_size: PositiveInt = 32
+    accumulate_grad_batches: PositiveInt = 1
     use_weighted_sampler: bool = False
-    epochs: int = 100
-    num_workers: int = 2
-    train_metrics_interval: int = -1
-    validation_interval: int = 1
-    num_log_images: int = 4
+    epochs: PositiveInt = 100
+    num_workers: NonNegativeInt = 4
+    train_metrics_interval: Literal[-1] | PositiveInt = -1
+    validation_interval: Literal[-1] | PositiveInt = 1
+    num_log_images: NonNegativeInt = 4
     skip_last_batch: bool = True
     log_sub_losses: bool = True
-    save_top_k: int = 3
+    save_top_k: Literal[-1] | NonNegativeInt = 3
 
     callbacks: list[CallbackConfig] = []
 
@@ -256,7 +268,7 @@ def check_validation_interval(self) -> Self:
 
 
 class OnnxExportConfig(BaseModelExtraForbid):
-    opset_version: int = 12
+    opset_version: PositiveInt = 12
     dynamic_axes: dict[str, Any] | None = None
 
 
@@ -305,8 +317,8 @@ class TunerConfig(BaseModelExtraForbid):
     study_name: str = "test-study"
     continue_existing_study: bool = True
     use_pruner: bool = True
-    n_trials: int | None = 15
-    timeout: int | None = None
+    n_trials: PositiveInt | None = 15
+    timeout: PositiveInt | None = None
     storage: StorageConfig = StorageConfig()
     params: Annotated[
         dict[str, list[str | int | float | bool | list]],
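
A quick way to see the stricter types in action after applying the patch -- a minimal sketch, assuming pydantic v2 semantics and the patched luxonis_train/utils/config.py; the "bce" and "head" names are made-up placeholders, not fields from the patch:

# Minimal sketch (assumes pydantic v2 and the patched config module).
# "bce" and "head" below are hypothetical placeholder names.
from pydantic import ValidationError

from luxonis_train.utils.config import LossModuleConfig, TrainerConfig

# weight == 0 is still accepted, but the new validate_weight hook logs a
# warning that the loss will not contribute to the training.
LossModuleConfig(name="bce", attached_to="head", weight=0.0)

# NonNegativeFloat rejects negative weights at construction time.
try:
    LossModuleConfig(name="bce", attached_to="head", weight=-1.0)
except ValidationError as e:
    print(e)  # weight: Input should be greater than or equal to 0

# Literal[-1] | PositiveInt accepts the -1 sentinel or any positive
# integer, but rejects 0 and other negative values.
TrainerConfig(train_metrics_interval=-1)  # OK
try:
    TrainerConfig(train_metrics_interval=0)
except ValidationError as e:
    print(e)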