Skip to content

Commit

Permalink
added Self types
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlov721 committed Aug 3, 2024
1 parent d7d39f3 commit eb89c86
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions luxonis_train/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class ModelConfig(CustomBaseModel):
outputs: list[str] = []

@model_validator(mode="after")
def check_predefined_model(self):
def check_predefined_model(self) -> Self:
from luxonis_train.utils.registry import MODELS

if self.predefined_model:
Expand All @@ -85,7 +85,7 @@ def check_predefined_model(self):
return self

@model_validator(mode="after")
def check_graph(self):
def check_graph(self) -> Self:
from luxonis_train.utils.general import is_acyclic

graph = {node.alias or node.name: node.inputs for node in self.nodes}
Expand All @@ -104,7 +104,7 @@ def check_graph(self):
return self

@model_validator(mode="after")
def check_unique_names(self):
def check_unique_names(self) -> Self:
for section, objects in [
("nodes", self.nodes),
("losses", self.losses),
Expand Down Expand Up @@ -165,7 +165,7 @@ class PreprocessingConfig(CustomBaseModel):
augmentations: list[AugmentationConfig] = []

@model_validator(mode="after")
def check_normalize(self):
def check_normalize(self) -> Self:
if self.normalize.active:
self.augmentations.append(
AugmentationConfig(name="Normalize", params=self.normalize.params)
Expand Down Expand Up @@ -227,7 +227,7 @@ class TrainerConfig(CustomBaseModel):
scheduler: SchedulerConfig = SchedulerConfig()

@model_validator(mode="after")
def check_num_workes_platform(self):
def check_num_workes_platform(self) -> Self:
if (
sys.platform == "win32" or sys.platform == "darwin"
) and self.num_workers != 0:
Expand All @@ -238,7 +238,7 @@ def check_num_workes_platform(self):
return self

@model_validator(mode="after")
def check_validation_interval(self):
def check_validation_interval(self) -> Self:
if self.validation_interval > self.epochs:
logger.warning(
"Setting `validation_interval` same as `epochs` otherwise no checkpoint would be generated."
Expand Down Expand Up @@ -272,7 +272,7 @@ class ExportConfig(CustomBaseModel):
upload_url: str | None = None

@model_validator(mode="after")
def check_values(self):
def check_values(self) -> Self:
def pad_values(values: float | list[float] | None):
if values is None:
return None
Expand Down

0 comments on commit eb89c86

Please sign in to comment.