diff --git a/luxonis_train/utils/config.py b/luxonis_train/utils/config.py index 687dfdd1..2e8460ca 100644 --- a/luxonis_train/utils/config.py +++ b/luxonis_train/utils/config.py @@ -63,7 +63,7 @@ class ModelConfig(CustomBaseModel): outputs: list[str] = [] @model_validator(mode="after") - def check_predefined_model(self): + def check_predefined_model(self) -> Self: from luxonis_train.utils.registry import MODELS if self.predefined_model: @@ -85,7 +85,7 @@ def check_predefined_model(self): return self @model_validator(mode="after") - def check_graph(self): + def check_graph(self) -> Self: from luxonis_train.utils.general import is_acyclic graph = {node.alias or node.name: node.inputs for node in self.nodes} @@ -104,7 +104,7 @@ def check_graph(self): return self @model_validator(mode="after") - def check_unique_names(self): + def check_unique_names(self) -> Self: for section, objects in [ ("nodes", self.nodes), ("losses", self.losses), @@ -165,7 +165,7 @@ class PreprocessingConfig(CustomBaseModel): augmentations: list[AugmentationConfig] = [] @model_validator(mode="after") - def check_normalize(self): + def check_normalize(self) -> Self: if self.normalize.active: self.augmentations.append( AugmentationConfig(name="Normalize", params=self.normalize.params) @@ -227,7 +227,7 @@ class TrainerConfig(CustomBaseModel): scheduler: SchedulerConfig = SchedulerConfig() @model_validator(mode="after") - def check_num_workes_platform(self): + def check_num_workes_platform(self) -> Self: if ( sys.platform == "win32" or sys.platform == "darwin" ) and self.num_workers != 0: @@ -238,7 +238,7 @@ def check_num_workes_platform(self): return self @model_validator(mode="after") - def check_validation_interval(self): + def check_validation_interval(self) -> Self: if self.validation_interval > self.epochs: logger.warning( "Setting `validation_interval` same as `epochs` otherwise no checkpoint would be generated." @@ -272,7 +272,7 @@ class ExportConfig(CustomBaseModel): upload_url: str | None = None @model_validator(mode="after") - def check_values(self): + def check_values(self) -> Self: def pad_values(values: float | list[float] | None): if values is None: return None