From 1078457e78d6499fa3e89f3f531959910a3fa493 Mon Sep 17 00:00:00 2001 From: klemen1999 Date: Tue, 1 Oct 2024 20:04:42 +0200 Subject: [PATCH] fix in classification predefined model --- .../config/predefined_models/classification_model.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/luxonis_train/config/predefined_models/classification_model.py b/luxonis_train/config/predefined_models/classification_model.py index 5cb3dc7d..a69a415f 100644 --- a/luxonis_train/config/predefined_models/classification_model.py +++ b/luxonis_train/config/predefined_models/classification_model.py @@ -12,7 +12,7 @@ from .base_predefined_model import BasePredefinedModel -VariantLiteral: TypeAlias = Literal["lite", "heavy"] +VariantLiteral: TypeAlias = Literal["light", "heavy"] class ClassificationVariant(BaseModel): @@ -46,25 +46,23 @@ class ClassificationModel(BasePredefinedModel): def __init__( self, variant: VariantLiteral = "light", - task: Literal["multiclass", "multilabel"] = "multiclass", + backbone: str | None = None, backbone_params: Params | None = None, head_params: Params | None = None, loss_params: Params | None = None, - task_name: str | None = None, visualizer_params: Params | None = None, - backbone: str | None = None, + task: Literal["multiclass", "multilabel"] = "multiclass", + task_name: str | None = None, ): - self.variant = variant - var_config = get_variant(variant) + self.backbone = backbone or var_config.backbone self.backbone_params = backbone_params or var_config.backbone_params self.head_params = head_params or {} self.loss_params = loss_params or {} self.visualizer_params = visualizer_params or {} self.task = task self.task_name = task_name or "classification" - self.backbone = backbone or var_config.backbone @property def nodes(self) -> list[ModelNodeConfig]: