Commit 7143e91
small reordering of parameters
klemen1999 committed Oct 1, 2024
1 parent 1078457 commit 7143e91
Showing 3 changed files with 36 additions and 34 deletions.
10 changes: 4 additions & 6 deletions luxonis_train/config/predefined_models/detection_model.py
@@ -47,26 +47,24 @@ def __init__(
         self,
         variant: VariantLiteral = "light",
         use_neck: bool = True,
+        backbone: str | None = None,
         backbone_params: Params | None = None,
         neck_params: Params | None = None,
         head_params: Params | None = None,
         loss_params: Params | None = None,
-        task_name: str | None = None,
         visualizer_params: Params | None = None,
-        backbone: str | None = None,
+        task_name: str | None = None,
     ):
         self.variant = variant
+        self.use_neck = use_neck
 
         var_config = get_variant(variant)
 
-        self.use_neck = use_neck
         self.backbone_params = backbone_params or var_config.backbone_params
+        self.backbone = backbone or var_config.backbone
         self.neck_params = neck_params or {}
         self.head_params = head_params or {}
         self.loss_params = loss_params or {"n_warmup_epochs": 0}
         self.visualizer_params = visualizer_params or {}
         self.task_name = task_name or "boundingbox"
-        self.backbone = backbone or var_config.backbone
 
     @property
     def nodes(self) -> list[ModelNodeConfig]:
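For callers, a minimal usage sketch (hypothetical: the class name DetectionModel and its import path are inferred from the file path in this diff, not shown in it). Since the constructor is driven by keyword arguments, the reordering only matters for positional calls:

    # Sketch only; DetectionModel and the import path are assumptions.
    from luxonis_train.config.predefined_models.detection_model import (
        DetectionModel,
    )

    model = DetectionModel(
        variant="light",          # default, per VariantLiteral
        backbone="EfficientRep",  # hypothetical backbone name
        task_name="vehicles",     # falls back to "boundingbox" if omitted
    )

Note that a caller passing these arguments positionally would now hand its old backbone_params value to backbone, so keyword arguments are the safe convention after this change.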
28 changes: 14 additions & 14 deletions luxonis_train/config/predefined_models/keypoint_detection_model.py
@@ -47,34 +47,32 @@ def __init__(
         self,
         variant: VariantLiteral = "light",
         use_neck: bool = True,
+        backbone: str | None = None,
         backbone_params: Params | None = None,
         neck_params: Params | None = None,
-        head_type: Literal[
+        head_params: Params | None = None,
+        loss_params: Params | None = None,
+        head: Literal[
             "ImplicitKeypointBBoxHead", "EfficientKeypointBBoxHead"
         ] = "EfficientKeypointBBoxHead",
-        head_params: Params | None = None,
-        loss_params: Params | None = None,
         kpt_visualizer_params: Params | None = None,
         bbox_visualizer_params: Params | None = None,
         bbox_task_name: str | None = None,
         kpt_task_name: str | None = None,
-        backbone: str | None = None,
     ):
         self.variant = variant
+        self.use_neck = use_neck
 
         var_config = get_variant(variant)
 
-        self.use_neck = use_neck
+        self.backbone = backbone or var_config.backbone
         self.backbone_params = backbone_params or var_config.backbone_params
         self.neck_params = neck_params or {}
+        self.head = head
         self.head_params = head_params or {}
         self.loss_params = loss_params or {"n_warmup_epochs": 0}
         self.kpt_visualizer_params = kpt_visualizer_params or {}
         self.bbox_visualizer_params = bbox_visualizer_params or {}
         self.bbox_task_name = bbox_task_name
         self.kpt_task_name = kpt_task_name
-        self.head_type = head_type
-        self.backbone = backbone or var_config.backbone
 
     @property
     def nodes(self) -> list[ModelNodeConfig]:
@@ -107,11 +105,13 @@ def nodes(self) -> list[ModelNodeConfig]:
 
         nodes.append(
             ModelNodeConfig(
-                name=self.head_type,
+                name=self.head,
                 alias="kpt_detection_head",
-                inputs=["kpt_detection_neck"]
-                if self.use_neck
-                else ["kpt_detection_backbone"],
+                inputs=(
+                    ["kpt_detection_neck"]
+                    if self.use_neck
+                    else ["kpt_detection_backbone"]
+                ),
                 freezing=self.head_params.pop("freezing", {}),
                 params=self.head_params,
                 task=task,
@@ -124,7 +124,7 @@ def losses(self) -> list[LossModuleConfig]:
         """Defines the loss module for the keypoint detection task."""
         return [
             LossModuleConfig(
-                name=self.head_type.replace("Head", "Loss"),
+                name=self.head.replace("Head", "Loss"),
                 attached_to="kpt_detection_head",
                 params=self.loss_params,
                 weight=1.0,
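The renamed head parameter feeds both the node and the loss: losses() derives the loss class name from the head class name by string replacement. A small sketch of that convention (illustration only; it mirrors the replace() call above rather than importing the real classes):

    # "...Head" -> "...Loss", as in losses() above.
    head = "EfficientKeypointBBoxHead"  # or "ImplicitKeypointBBoxHead"
    loss_name = head.replace("Head", "Loss")
    assert loss_name == "EfficientKeypointBBoxLoss"

The convention implies that every admissible head value must have a loss class available under the derived name, which is why head is constrained to the two Literal options.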
32 changes: 18 additions & 14 deletions luxonis_train/config/predefined_models/segmentation_model.py
@@ -46,36 +46,36 @@ class SegmentationModel(BasePredefinedModel):
     def __init__(
         self,
         variant: VariantLiteral = "light",
+        task: Literal["binary", "multiclass"] = "binary",
+        backbone: str | None = None,
         backbone_params: Params | None = None,
         head_params: Params | None = None,
         aux_head_params: Params | None = None,
         loss_params: Params | None = None,
         visualizer_params: Params | None = None,
-        task: Literal["binary", "multiclass"] = "binary",
         task_name: str | None = None,
-        backbone: str | None = None,
     ):
         self.variant = variant
 
         var_config = get_variant(variant)
 
+        self.backbone = backbone or var_config.backbone
         self.backbone_params = backbone_params or var_config.backbone_params
         self.head_params = head_params or {}
         self.aux_head_params = aux_head_params or {}
         self.loss_params = loss_params or {}
         self.visualizer_params = visualizer_params or {}
         self.task = task
         self.task_name = task_name or "segmentation"
-        self.backbone = backbone or var_config.backbone
 
     @property
     def nodes(self) -> list[ModelNodeConfig]:
         """Defines the model nodes, including backbone and head."""
         self.head_params.update({"attach_index": -1})
         self.aux_head_params.update({"attach_index": -2})
-        self.aux_head_params.update(
-            {"remove_on_export": True}
-        ) if "remove_on_export" not in self.aux_head_params else None
+        (
+            self.aux_head_params.update({"remove_on_export": True})
+            if "remove_on_export" not in self.aux_head_params
+            else None
+        )
 
         node_list = [
             ModelNodeConfig(
@@ -111,9 +111,11 @@ def losses(self) -> list[LossModuleConfig]:
         """Defines the loss module for the segmentation task."""
         loss_list = [
             LossModuleConfig(
-                name="BCEWithLogitsLoss"
-                if self.task == "binary"
-                else "CrossEntropyLoss",
+                name=(
+                    "BCEWithLogitsLoss"
+                    if self.task == "binary"
+                    else "CrossEntropyLoss"
+                ),
                 alias="segmentation_loss",
                 attached_to="segmentation_head",
                 params=self.loss_params,
@@ -123,9 +125,11 @@ def losses(self) -> list[LossModuleConfig]:
         if self.backbone_params.get("use_aux_heads", False):
             loss_list.append(
                 LossModuleConfig(
-                    name="BCEWithLogitsLoss"
-                    if self.task == "binary"
-                    else "CrossEntropyLoss",
+                    name=(
+                        "BCEWithLogitsLoss"
+                        if self.task == "binary"
+                        else "CrossEntropyLoss"
+                    ),
                     alias="aux_segmentation_loss",
                     attached_to="aux_segmentation_head",
                     params=self.loss_params,
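The task parameter selects the loss through the same conditional in both hunks: binary segmentation gets BCEWithLogitsLoss, multiclass gets CrossEntropyLoss, and the auxiliary copy is attached only when the backbone is configured with use_aux_heads. A condensed sketch of that selection (illustration only, using the names from this diff):

    def pick_loss_name(task: str) -> str:
        # Mirrors the conditional expression in losses().
        return "BCEWithLogitsLoss" if task == "binary" else "CrossEntropyLoss"

    assert pick_loss_name("binary") == "BCEWithLogitsLoss"
    assert pick_loss_name("multiclass") == "CrossEntropyLoss"

Because aux_head_params defaults remove_on_export to True, the auxiliary head and its aux_segmentation_loss exist during training but are stripped from the exported model.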
