diff --git a/luxonis_train/config/predefined_models/classification_model.py b/luxonis_train/config/predefined_models/classification_model.py index 5cb3dc7d..a69a415f 100644 --- a/luxonis_train/config/predefined_models/classification_model.py +++ b/luxonis_train/config/predefined_models/classification_model.py @@ -12,7 +12,7 @@ from .base_predefined_model import BasePredefinedModel -VariantLiteral: TypeAlias = Literal["lite", "heavy"] +VariantLiteral: TypeAlias = Literal["light", "heavy"] class ClassificationVariant(BaseModel): @@ -46,25 +46,23 @@ class ClassificationModel(BasePredefinedModel): def __init__( self, variant: VariantLiteral = "light", - task: Literal["multiclass", "multilabel"] = "multiclass", + backbone: str | None = None, backbone_params: Params | None = None, head_params: Params | None = None, loss_params: Params | None = None, - task_name: str | None = None, visualizer_params: Params | None = None, - backbone: str | None = None, + task: Literal["multiclass", "multilabel"] = "multiclass", + task_name: str | None = None, ): - self.variant = variant - var_config = get_variant(variant) + self.backbone = backbone or var_config.backbone self.backbone_params = backbone_params or var_config.backbone_params self.head_params = head_params or {} self.loss_params = loss_params or {} self.visualizer_params = visualizer_params or {} self.task = task self.task_name = task_name or "classification" - self.backbone = backbone or var_config.backbone @property def nodes(self) -> list[ModelNodeConfig]: