From 7143e916a6faef3941c2272c56a76b48e970d6a3 Mon Sep 17 00:00:00 2001 From: klemen1999 Date: Tue, 1 Oct 2024 20:11:54 +0200 Subject: [PATCH] small reordering of parameters --- .../predefined_models/detection_model.py | 10 +++--- .../keypoint_detection_model.py | 28 ++++++++-------- .../predefined_models/segmentation_model.py | 32 +++++++++++-------- 3 files changed, 36 insertions(+), 34 deletions(-) diff --git a/luxonis_train/config/predefined_models/detection_model.py b/luxonis_train/config/predefined_models/detection_model.py index 53fda1eb..86597cca 100644 --- a/luxonis_train/config/predefined_models/detection_model.py +++ b/luxonis_train/config/predefined_models/detection_model.py @@ -47,26 +47,24 @@ def __init__( self, variant: VariantLiteral = "light", use_neck: bool = True, + backbone: str | None = None, backbone_params: Params | None = None, neck_params: Params | None = None, head_params: Params | None = None, loss_params: Params | None = None, - task_name: str | None = None, visualizer_params: Params | None = None, - backbone: str | None = None, + task_name: str | None = None, ): - self.variant = variant - self.use_neck = use_neck var_config = get_variant(variant) + self.use_neck = use_neck self.backbone_params = backbone_params or var_config.backbone_params + self.backbone = backbone or var_config.backbone self.neck_params = neck_params or {} self.head_params = head_params or {} self.loss_params = loss_params or {"n_warmup_epochs": 0} self.visualizer_params = visualizer_params or {} self.task_name = task_name or "boundingbox" - self.backbone = backbone or var_config.backbone @property def nodes(self) -> list[ModelNodeConfig]: diff --git a/luxonis_train/config/predefined_models/keypoint_detection_model.py b/luxonis_train/config/predefined_models/keypoint_detection_model.py index 38820d7e..09fcd095 100644 --- a/luxonis_train/config/predefined_models/keypoint_detection_model.py +++ b/luxonis_train/config/predefined_models/keypoint_detection_model.py @@ 
-47,34 +47,32 @@ def __init__( self, variant: VariantLiteral = "light", use_neck: bool = True, + backbone: str | None = None, backbone_params: Params | None = None, neck_params: Params | None = None, - head_params: Params | None = None, - loss_params: Params | None = None, - head_type: Literal[ + head: Literal[ "ImplicitKeypointBBoxHead", "EfficientKeypointBBoxHead" ] = "EfficientKeypointBBoxHead", + head_params: Params | None = None, + loss_params: Params | None = None, kpt_visualizer_params: Params | None = None, bbox_visualizer_params: Params | None = None, bbox_task_name: str | None = None, kpt_task_name: str | None = None, - backbone: str | None = None, ): - self.variant = variant - self.use_neck = use_neck - var_config = get_variant(variant) + self.use_neck = use_neck + self.backbone = backbone or var_config.backbone self.backbone_params = backbone_params or var_config.backbone_params self.neck_params = neck_params or {} + self.head = head self.head_params = head_params or {} self.loss_params = loss_params or {"n_warmup_epochs": 0} self.kpt_visualizer_params = kpt_visualizer_params or {} self.bbox_visualizer_params = bbox_visualizer_params or {} self.bbox_task_name = bbox_task_name self.kpt_task_name = kpt_task_name - self.head_type = head_type - self.backbone = backbone or var_config.backbone @property def nodes(self) -> list[ModelNodeConfig]: @@ -107,11 +105,13 @@ def nodes(self) -> list[ModelNodeConfig]: nodes.append( ModelNodeConfig( - name=self.head_type, + name=self.head, alias="kpt_detection_head", - inputs=["kpt_detection_neck"] - if self.use_neck - else ["kpt_detection_backbone"], + inputs=( + ["kpt_detection_neck"] + if self.use_neck + else ["kpt_detection_backbone"] + ), freezing=self.head_params.pop("freezing", {}), params=self.head_params, task=task, @@ -124,7 +124,7 @@ def losses(self) -> list[LossModuleConfig]: """Defines the loss module for the keypoint detection task.""" return [ LossModuleConfig( - name=self.head_type.replace("Head", 
"Loss"), + name=self.head.replace("Head", "Loss"), attached_to="kpt_detection_head", params=self.loss_params, weight=1.0, diff --git a/luxonis_train/config/predefined_models/segmentation_model.py b/luxonis_train/config/predefined_models/segmentation_model.py index 8f03d480..6a155506 100644 --- a/luxonis_train/config/predefined_models/segmentation_model.py +++ b/luxonis_train/config/predefined_models/segmentation_model.py @@ -46,19 +46,18 @@ class SegmentationModel(BasePredefinedModel): def __init__( self, variant: VariantLiteral = "light", - task: Literal["binary", "multiclass"] = "binary", + backbone: str | None = None, backbone_params: Params | None = None, head_params: Params | None = None, aux_head_params: Params | None = None, loss_params: Params | None = None, visualizer_params: Params | None = None, + task: Literal["binary", "multiclass"] = "binary", task_name: str | None = None, - backbone: str | None = None, ): - self.variant = variant - var_config = get_variant(variant) + self.backbone = backbone or var_config.backbone self.backbone_params = backbone_params or var_config.backbone_params self.head_params = head_params or {} self.aux_head_params = aux_head_params or {} @@ -66,16 +65,17 @@ def __init__( self.visualizer_params = visualizer_params or {} self.task = task self.task_name = task_name or "segmentation" - self.backbone = backbone or var_config.backbone @property def nodes(self) -> list[ModelNodeConfig]: """Defines the model nodes, including backbone and head.""" self.head_params.update({"attach_index": -1}) self.aux_head_params.update({"attach_index": -2}) - self.aux_head_params.update( - {"remove_on_export": True} - ) if "remove_on_export" not in self.aux_head_params else None + ( + self.aux_head_params.update({"remove_on_export": True}) + if "remove_on_export" not in self.aux_head_params + else None + ) node_list = [ ModelNodeConfig( @@ -111,9 +111,11 @@ def losses(self) -> list[LossModuleConfig]: """Defines the loss module for the segmentation 
task.""" loss_list = [ LossModuleConfig( - name="BCEWithLogitsLoss" - if self.task == "binary" - else "CrossEntropyLoss", + name=( + "BCEWithLogitsLoss" + if self.task == "binary" + else "CrossEntropyLoss" + ), alias="segmentation_loss", attached_to="segmentation_head", params=self.loss_params, @@ -123,9 +125,11 @@ def losses(self) -> list[LossModuleConfig]: if self.backbone_params.get("use_aux_heads", False): loss_list.append( LossModuleConfig( - name="BCEWithLogitsLoss" - if self.task == "binary" - else "CrossEntropyLoss", + name=( + "BCEWithLogitsLoss" + if self.task == "binary" + else "CrossEntropyLoss" + ), alias="aux_segmentation_loss", attached_to="aux_segmentation_head", params=self.loss_params,