more strict config types
kozlov721 committed Aug 19, 2024
commit f00a9b4 (1 parent: 9b17a70)
Showing 1 changed file with 36 additions and 24 deletions.
luxonis_train/utils/config.py (36 additions, 24 deletions)
@@ -1,6 +1,6 @@
 import logging
 import sys
-from typing import Annotated, Any, Literal
+from typing import Annotated, Any, Literal, TypeAlias
 
 from luxonis_ml.data import LabelType
 from luxonis_ml.utils import (
@@ -10,20 +10,32 @@
     LuxonisFileSystem,
 )
 from pydantic import Field, model_validator
+from pydantic.types import FilePath, NonNegativeFloat, NonNegativeInt, PositiveInt
 from typing_extensions import Self
 
 logger = logging.getLogger(__name__)
 
+Params: TypeAlias = dict[str, Any]
+
 
 class AttachedModuleConfig(BaseModelExtraForbid):
     name: str
     attached_to: str
     alias: str | None = None
-    params: dict[str, Any] = {}
+    params: Params = {}
 
 
 class LossModuleConfig(AttachedModuleConfig):
-    weight: float = 1.0
+    weight: NonNegativeFloat = 1.0
+
+    @model_validator(mode="after")
+    def validate_weight(self) -> Self:
+        if self.weight == 0:
+            logger.warning(
+                f"Loss '{self.name}' has weight set to 0. "
+                "This loss will not contribute to the training."
+            )
+        return self
 
 
 class MetricModuleConfig(AttachedModuleConfig):
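
The hunk above uses stock pydantic v2 features: a `Params` type alias, the constrained `NonNegativeFloat` type, and an after-validator that warns when a loss weight is zero. Below is a minimal runnable sketch of the same pattern; it substitutes plain `pydantic.BaseModel` for luxonis_ml's `BaseModelExtraForbid` purely to keep the example self-contained, and the loss name "bce" is an arbitrary sample value.

import logging
from typing import Any, TypeAlias

from pydantic import BaseModel, ValidationError, model_validator
from pydantic.types import NonNegativeFloat
from typing_extensions import Self

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

Params: TypeAlias = dict[str, Any]  # same alias as in the commit


class LossModuleConfig(BaseModel):  # BaseModel stands in for BaseModelExtraForbid
    name: str
    weight: NonNegativeFloat = 1.0
    params: Params = {}

    @model_validator(mode="after")
    def validate_weight(self) -> Self:
        if self.weight == 0:
            logger.warning(
                f"Loss '{self.name}' has weight set to 0. "
                "This loss will not contribute to the training."
            )
        return self


LossModuleConfig(name="bce", weight=0.0)  # passes validation but logs the warning
try:
    LossModuleConfig(name="bce", weight=-1.0)  # NonNegativeFloat rejects negatives
except ValidationError as e:
    print(e)

A negative weight now fails when the config is parsed instead of silently producing a loss that subtracts from the training objective.
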
@@ -32,32 +44,32 @@ class MetricModuleConfig(AttachedModuleConfig):
 
 class FreezingConfig(BaseModelExtraForbid):
     active: bool = False
-    unfreeze_after: int | float | None = None
+    unfreeze_after: NonNegativeInt | NonNegativeFloat | None = None
 
 
 class ModelNodeConfig(BaseModelExtraForbid):
     name: str
     alias: str | None = None
     inputs: list[str] = []  # From preceding nodes
     input_sources: list[str] = []  # From data loader
-    params: dict[str, Any] = {}
     freezing: FreezingConfig = FreezingConfig()
     task: str | dict[LabelType, str] | None = None
+    params: Params = {}
 
 
 class PredefinedModelConfig(BaseModelExtraForbid):
     name: str
-    params: dict[str, Any] = {}
     include_nodes: bool = True
     include_losses: bool = True
     include_metrics: bool = True
     include_visualizers: bool = True
+    params: Params = {}
 
 
 class ModelConfig(BaseModelExtraForbid):
     name: str = "model"
     predefined_model: PredefinedModelConfig | None = None
-    weights: str | None = None
+    weights: FilePath | None = None
     nodes: list[ModelNodeConfig] = []
     losses: list[LossModuleConfig] = []
     metrics: list[MetricModuleConfig] = []
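
Changing `weights` from `str` to `FilePath` moves checkpoint-path errors to config-parse time: pydantic's `FilePath` only accepts a path that exists and points to a regular file. A short sketch under the same plain-`BaseModel` assumption; the checkpoint path below is hypothetical.

from pydantic import BaseModel, ValidationError
from pydantic.types import FilePath


class ModelConfig(BaseModel):  # stand-in for BaseModelExtraForbid
    weights: FilePath | None = None


ModelConfig(weights=None)  # weights remain optional
try:
    # Hypothetical path: FilePath requires an existing regular file,
    # so a missing checkpoint fails here rather than at load time.
    ModelConfig(weights="missing/checkpoint.ckpt")
except ValidationError as e:
    print(e)
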
@@ -146,7 +158,7 @@ class LoaderConfig(BaseModelExtraForbid):
     train_view: str = "train"
     val_view: str = "val"
     test_view: str = "test"
-    params: dict[str, Any] = {}
+    params: Params = {}
 
 
 class NormalizeAugmentationConfig(BaseModelExtraForbid):
@@ -160,7 +172,7 @@ class NormalizeAugmentationConfig(BaseModelExtraForbid):
 class AugmentationConfig(BaseModelExtraForbid):
     name: str
     active: bool = True
-    params: dict[str, Any] = {}
+    params: Params = {}
 
 
 class PreprocessingConfig(BaseModelExtraForbid):
@@ -192,23 +204,23 @@ def get_active_augmentations(self) -> list[AugmentationConfig]:
 class CallbackConfig(BaseModelExtraForbid):
     name: str
     active: bool = True
-    params: dict[str, Any] = {}
+    params: Params = {}
 
 
 class OptimizerConfig(BaseModelExtraForbid):
     name: str = "Adam"
-    params: dict[str, Any] = {}
+    params: Params = {}
 
 
 class SchedulerConfig(BaseModelExtraForbid):
     name: str = "ConstantLR"
-    params: dict[str, Any] = {}
+    params: Params = {}
 
 
 class TrainerConfig(BaseModelExtraForbid):
     preprocessing: PreprocessingConfig = PreprocessingConfig()
 
-    accelerator: Literal["auto", "cpu", "gpu"] = "auto"
+    accelerator: Literal["auto", "cpu", "gpu", "tpu"] = "auto"
     devices: int | list[int] | str = "auto"
     strategy: Literal["auto", "ddp"] = "auto"
     num_sanity_val_steps: int = 2
@@ -217,17 +229,17 @@ class TrainerConfig(BaseModelExtraForbid):
     verbose: bool = True
 
     seed: int | None = None
-    batch_size: int = 32
-    accumulate_grad_batches: int = 1
+    batch_size: PositiveInt = 32
+    accumulate_grad_batches: PositiveInt = 1
     use_weighted_sampler: bool = False
-    epochs: int = 100
-    num_workers: int = 2
-    train_metrics_interval: int = -1
-    validation_interval: int = 1
-    num_log_images: int = 4
+    epochs: PositiveInt = 100
+    num_workers: NonNegativeInt = 4
+    train_metrics_interval: Literal[-1] | PositiveInt = -1
+    validation_interval: Literal[-1] | PositiveInt = 1
+    num_log_images: NonNegativeInt = 4
     skip_last_batch: bool = True
     log_sub_losses: bool = True
-    save_top_k: int = 3
+    save_top_k: Literal[-1] | NonNegativeInt = 3
 
     callbacks: list[CallbackConfig] = []
 
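
The interval fields above move from bare `int` to `Literal[-1] | PositiveInt`, which keeps -1 available as a sentinel while rejecting 0 and other negatives; `save_top_k` gets the analogous `Literal[-1] | NonNegativeInt`. A sketch of the pattern, again assuming plain `BaseModel` in place of `BaseModelExtraForbid`:

from typing import Literal

from pydantic import BaseModel, ValidationError
from pydantic.types import NonNegativeInt, PositiveInt


class TrainerConfig(BaseModel):  # stand-in for BaseModelExtraForbid
    train_metrics_interval: Literal[-1] | PositiveInt = -1
    save_top_k: Literal[-1] | NonNegativeInt = 3


TrainerConfig(train_metrics_interval=5)   # any positive interval passes
TrainerConfig(train_metrics_interval=-1)  # the -1 sentinel passes
try:
    TrainerConfig(train_metrics_interval=0)  # neither -1 nor positive: rejected
except ValidationError as e:
    print(e)
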
@@ -256,7 +268,7 @@ def check_validation_interval(self) -> Self:
 
 
 class OnnxExportConfig(BaseModelExtraForbid):
-    opset_version: int = 12
+    opset_version: PositiveInt = 12
     dynamic_axes: dict[str, Any] | None = None
 
 
@@ -305,8 +317,8 @@ class TunerConfig(BaseModelExtraForbid):
     study_name: str = "test-study"
     continue_existing_study: bool = True
     use_pruner: bool = True
-    n_trials: int | None = 15
-    timeout: int | None = None
+    n_trials: PositiveInt | None = 15
+    timeout: PositiveInt | None = None
     storage: StorageConfig = StorageConfig()
     params: Annotated[
         dict[str, list[str | int | float | bool | list]],
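
Finally, `n_trials` and `timeout` tighten from `int | None` to `PositiveInt | None`: `None` is still accepted, but a zero or negative tuning budget now fails validation. One last sketch under the same `BaseModel` assumption.

from pydantic import BaseModel, ValidationError
from pydantic.types import PositiveInt


class TunerConfig(BaseModel):  # stand-in for BaseModelExtraForbid
    n_trials: PositiveInt | None = 15
    timeout: PositiveInt | None = None


TunerConfig(n_trials=None)  # still valid
try:
    TunerConfig(n_trials=0)  # zero trials is now rejected
except ValidationError as e:
    print(e)
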
