Skip to content

Commit

Permalink
fix in classification predefined model
Browse files Browse the repository at this point in the history
  • Loading branch information
klemen1999 committed Oct 1, 2024
1 parent c987592 commit 1078457
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions luxonis_train/config/predefined_models/classification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from .base_predefined_model import BasePredefinedModel

VariantLiteral: TypeAlias = Literal["lite", "heavy"]
VariantLiteral: TypeAlias = Literal["light", "heavy"]


class ClassificationVariant(BaseModel):
Expand Down Expand Up @@ -46,25 +46,23 @@ class ClassificationModel(BasePredefinedModel):
def __init__(
self,
variant: VariantLiteral = "light",
task: Literal["multiclass", "multilabel"] = "multiclass",
backbone: str | None = None,
backbone_params: Params | None = None,
head_params: Params | None = None,
loss_params: Params | None = None,
task_name: str | None = None,
visualizer_params: Params | None = None,
backbone: str | None = None,
task: Literal["multiclass", "multilabel"] = "multiclass",
task_name: str | None = None,
):
self.variant = variant

var_config = get_variant(variant)

self.backbone = backbone or var_config.backbone
self.backbone_params = backbone_params or var_config.backbone_params
self.head_params = head_params or {}
self.loss_params = loss_params or {}
self.visualizer_params = visualizer_params or {}
self.task = task
self.task_name = task_name or "classification"
self.backbone = backbone or var_config.backbone

@property
def nodes(self) -> list[ModelNodeConfig]:
Expand Down

0 comments on commit 1078457

Please sign in to comment.