diff --git a/configs/coco_model.yaml b/configs/coco_model.yaml index 23516bea..dd0256b8 100644 --- a/configs/coco_model.yaml +++ b/configs/coco_model.yaml @@ -26,10 +26,47 @@ model: params: conf_thres: 0.25 iou_thres: 0.45 + losses: + name: ImplicitKeypointBBoxLoss + params: + keypoint_regression_loss_weight: 0.5 + keypoint_visibility_loss_weight: 0.7 + bbox_loss_weight: 0.05 + objectness_loss_weight: 0.2 + metrics: + - name: ObjectKeypointSimilarity + is_main_metric: true + - name: MeanAveragePrecisionKeypoints + + visualizers: + name: MultiVisualizer + attached_to: ImplicitKeypointBBoxHead + params: + visualizers: + - name: KeypointVisualizer + params: + nonvisible_color: blue + - name: BBoxVisualizer + params: + colors: + person: "#FF5055" - name: SegmentationHead inputs: - RepPANNeck + losses: + name: BCEWithLogitsLoss + metrics: + - name: F1Score + params: + task: binary + - name: JaccardIndex + params: + task: binary + visualizers: + name: SegmentationVisualizer + params: + colors: "#FF5055" - name: EfficientBBoxHead inputs: @@ -37,55 +74,12 @@ model: params: conf_thres: 0.75 iou_thres: 0.45 - - losses: - - name: AdaptiveDetectionLoss - attached_to: EfficientBBoxHead - - name: BCEWithLogitsLoss - attached_to: SegmentationHead - - name: ImplicitKeypointBBoxLoss - attached_to: ImplicitKeypointBBoxHead - params: - keypoint_regression_loss_weight: 0.5 - keypoint_visibility_loss_weight: 0.7 - bbox_loss_weight: 0.05 - objectness_loss_weight: 0.2 - - metrics: - - name: ObjectKeypointSimilarity - is_main_metric: true - attached_to: ImplicitKeypointBBoxHead - - name: MeanAveragePrecisionKeypoints - attached_to: ImplicitKeypointBBoxHead - - name: MeanAveragePrecision - attached_to: EfficientBBoxHead - - name: F1Score - attached_to: SegmentationHead - params: - task: binary - - name: JaccardIndex - attached_to: SegmentationHead - params: - task: binary - - visualizers: - - name: MultiVisualizer - attached_to: ImplicitKeypointBBoxHead - params: - visualizers: - - name: KeypointVisualizer - params: - nonvisible_color: blue - - name: BBoxVisualizer - params: - colors: - person: "#FF5055" - - name: SegmentationVisualizer - attached_to: SegmentationHead - params: - colors: "#FF5055" - - name: BBoxVisualizer - attached_to: EfficientBBoxHead + losses: + name: AdaptiveDetectionLoss + metrics: + name: MeanAveragePrecision + visualizers: + name: BBoxVisualizer tracker: project_name: coco_test diff --git a/configs/resnet_model.yaml b/configs/resnet_model.yaml index bb9f8f62..fd0b66cd 100644 --- a/configs/resnet_model.yaml +++ b/configs/resnet_model.yaml @@ -8,26 +8,21 @@ model: download_weights: True - name: ClassificationHead - inputs: - - ResNet - losses: - - name: CrossEntropyLoss - attached_to: ClassificationHead + losses: + name: CrossEntropyLoss - metrics: - - name: Accuracy - is_main_metric: true - attached_to: ClassificationHead + metrics: + name: Accuracy + is_main_metric: true - visualizers: - - name: ClassificationVisualizer - attached_to: ClassificationHead - params: - font_scale: 0.5 - color: [255, 0, 0] - thickness: 2 - include_plot: True + visualizers: + name: ClassificationVisualizer + params: + font_scale: 0.5 + color: [255, 0, 0] + thickness: 2 + include_plot: True loader: params: diff --git a/luxonis_train/config/config.py b/luxonis_train/config/config.py index 931bcd56..76a9baae 100644 --- a/luxonis_train/config/config.py +++ b/luxonis_train/config/config.py @@ -1,5 +1,6 @@ import logging import sys +import warnings from typing import Annotated, Any, Literal, TypeAlias from luxonis_ml.enums import DatasetType @@ -83,6 +84,44 @@ class ModelConfig(BaseModelExtraForbid): visualizers: list[AttachedModuleConfig] = [] outputs: list[str] = [] + @field_validator("nodes", mode="before") + @classmethod + def validate_nodes(cls, nodes: Any) -> Any: + logged_general_warning = False + if not isinstance(nodes, list): + return nodes + names = [] + last_body_index: int | None = None + for i, node in enumerate(nodes): + name = node.get("alias", node.get("name")) + if name is None: + raise ValueError( + f"Node {i} does not specify the `name` field." + ) + if "Head" in name and last_body_index is None: + last_body_index = i - 1 + names.append(name) + if i > 0 and "inputs" not in node: + if last_body_index is not None: + prev_name = names[last_body_index] + else: + prev_name = names[i - 1] + + if not logged_general_warning: + logger.warning( + f"Field `inputs` not specified for node '{name}'. " + "Assuming the model follows a linear multi-head topology " + "(backbone -> (neck?) -> head1, head2, ...). " + "If this is incorrect, please specify the `inputs` field explicitly." + ) + logged_general_warning = True + + logger.warning( + f"Setting `inputs` of '{name}' to '{prev_name}'. " + ) + node["inputs"] = [prev_name] + return nodes + @model_validator(mode="after") def check_predefined_model(self) -> Self: from .predefined_models.base_predefined_model import MODELS @@ -170,6 +209,30 @@ def check_unique_names(self) -> Self: names.add(name) return self + @model_validator(mode="before") + @classmethod + def check_attached_modules(cls, data: Params) -> Params: + if "nodes" not in data: + return data + for section in ["losses", "metrics", "visualizers"]: + if section not in data: + data[section] = [] + else: + warnings.warn( + f"Field `model.{section}` is deprecated. " + f"Please specify `{section}`under " + "the node they are attached to." + ) + for node in data["nodes"]: + if section in node: + cfg = node.pop(section) + if not isinstance(cfg, list): + cfg = [cfg] + for c in cfg: + c["attached_to"] = node.get("alias", node.get("name")) + data[section] += cfg + return data + class TrackerConfig(BaseModelExtraForbid): project_name: str | None = None diff --git a/requirements-config.txt b/requirements-config.txt new file mode 100644 index 00000000..0a7b2625 --- /dev/null +++ b/requirements-config.txt @@ -0,0 +1 @@ +luxonis-ml[data,utils]@git+https://github.com/luxonis/luxonis-ml.git@dev