Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplified Node Configuration #79

Merged
merged 6 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 43 additions & 49 deletions configs/coco_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,66 +26,60 @@ model:
params:
conf_thres: 0.25
iou_thres: 0.45
losses:
name: ImplicitKeypointBBoxLoss
params:
keypoint_regression_loss_weight: 0.5
keypoint_visibility_loss_weight: 0.7
bbox_loss_weight: 0.05
objectness_loss_weight: 0.2
metrics:
- name: ObjectKeypointSimilarity
is_main_metric: true
- name: MeanAveragePrecisionKeypoints

visualizers:
name: MultiVisualizer
attached_to: ImplicitKeypointBBoxHead
params:
visualizers:
- name: KeypointVisualizer
params:
nonvisible_color: blue
- name: BBoxVisualizer
params:
colors:
person: "#FF5055"

- name: SegmentationHead
inputs:
- RepPANNeck
losses:
name: BCEWithLogitsLoss
metrics:
- name: F1Score
params:
task: binary
- name: JaccardIndex
params:
task: binary
visualizers:
name: SegmentationVisualizer
params:
colors: "#FF5055"

- name: EfficientBBoxHead
inputs:
- RepPANNeck
params:
conf_thres: 0.75
iou_thres: 0.45

losses:
- name: AdaptiveDetectionLoss
attached_to: EfficientBBoxHead
- name: BCEWithLogitsLoss
attached_to: SegmentationHead
- name: ImplicitKeypointBBoxLoss
attached_to: ImplicitKeypointBBoxHead
params:
keypoint_regression_loss_weight: 0.5
keypoint_visibility_loss_weight: 0.7
bbox_loss_weight: 0.05
objectness_loss_weight: 0.2

metrics:
- name: ObjectKeypointSimilarity
is_main_metric: true
attached_to: ImplicitKeypointBBoxHead
- name: MeanAveragePrecisionKeypoints
attached_to: ImplicitKeypointBBoxHead
- name: MeanAveragePrecision
attached_to: EfficientBBoxHead
- name: F1Score
attached_to: SegmentationHead
params:
task: binary
- name: JaccardIndex
attached_to: SegmentationHead
params:
task: binary

visualizers:
- name: MultiVisualizer
attached_to: ImplicitKeypointBBoxHead
params:
visualizers:
- name: KeypointVisualizer
params:
nonvisible_color: blue
- name: BBoxVisualizer
params:
colors:
person: "#FF5055"
- name: SegmentationVisualizer
attached_to: SegmentationHead
params:
colors: "#FF5055"
- name: BBoxVisualizer
attached_to: EfficientBBoxHead
losses:
name: AdaptiveDetectionLoss
metrics:
name: MeanAveragePrecision
visualizers:
name: BBoxVisualizer

tracker:
project_name: coco_test
Expand Down
29 changes: 12 additions & 17 deletions configs/resnet_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,21 @@ model:
download_weights: True

- name: ClassificationHead
inputs:
- ResNet

losses:
- name: CrossEntropyLoss
attached_to: ClassificationHead
losses:
name: CrossEntropyLoss

metrics:
- name: Accuracy
is_main_metric: true
attached_to: ClassificationHead
metrics:
name: Accuracy
is_main_metric: true

visualizers:
- name: ClassificationVisualizer
attached_to: ClassificationHead
params:
font_scale: 0.5
color: [255, 0, 0]
thickness: 2
include_plot: True
visualizers:
name: ClassificationVisualizer
params:
font_scale: 0.5
color: [255, 0, 0]
thickness: 2
include_plot: True

loader:
params:
Expand Down
63 changes: 63 additions & 0 deletions luxonis_train/config/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import sys
import warnings
from typing import Annotated, Any, Literal, TypeAlias

from luxonis_ml.enums import DatasetType
Expand Down Expand Up @@ -83,6 +84,44 @@ class ModelConfig(BaseModelExtraForbid):
visualizers: list[AttachedModuleConfig] = []
outputs: list[str] = []

@field_validator("nodes", mode="before")
@classmethod
def validate_nodes(cls, nodes: Any) -> Any:
logged_general_warning = False
if not isinstance(nodes, list):
return nodes
names = []
last_body_index: int | None = None
for i, node in enumerate(nodes):
name = node.get("alias", node.get("name"))
if name is None:
raise ValueError(
f"Node {i} does not specify the `name` field."
)
if "Head" in name and last_body_index is None:
kozlov721 marked this conversation as resolved.
Show resolved Hide resolved
last_body_index = i - 1
names.append(name)
if i > 0 and "inputs" not in node:
if last_body_index is not None:
prev_name = names[last_body_index]
else:
prev_name = names[i - 1]

if not logged_general_warning:
logger.warning(
f"Field `inputs` not specified for node '{name}'. "
"Assuming the model follows a linear multi-head topology "
"(backbone -> (neck?) -> head1, head2, ...). "
"If this is incorrect, please specify the `inputs` field explicitly."
)
logged_general_warning = True

logger.warning(
f"Setting `inputs` of '{name}' to '{prev_name}'. "
)
node["inputs"] = [prev_name]
return nodes

@model_validator(mode="after")
def check_predefined_model(self) -> Self:
from .predefined_models.base_predefined_model import MODELS
Expand Down Expand Up @@ -170,6 +209,30 @@ def check_unique_names(self) -> Self:
names.add(name)
return self

@model_validator(mode="before")
@classmethod
def check_attached_modules(cls, data: Params) -> Params:
if "nodes" not in data:
return data
for section in ["losses", "metrics", "visualizers"]:
if section not in data:
data[section] = []
else:
warnings.warn(
f"Field `model.{section}` is deprecated. "
f"Please specify `{section}`under "
"the node they are attached to."
)
for node in data["nodes"]:
if section in node:
cfg = node.pop(section)
if not isinstance(cfg, list):
cfg = [cfg]
for c in cfg:
c["attached_to"] = node.get("alias", node.get("name"))
data[section] += cfg
return data


class TrackerConfig(BaseModelExtraForbid):
project_name: str | None = None
Expand Down
1 change: 1 addition & 0 deletions requirements-config.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
luxonis-ml[data,utils]@git+https://github.com/luxonis/luxonis-ml.git@dev