correct output names for dai-nodes
klemen1999 committed Jan 23, 2025
1 parent 85e5c73 commit 5017243
Showing 4 changed files with 63 additions and 24 deletions.
10 changes: 9 additions & 1 deletion luxonis_train/models/luxonis_lightning.py
@@ -594,9 +594,17 @@ def export_onnx(self, save_path: str, **kwargs) -> list[str]:
                 idx += 1
         else:
             output_names = []
+            running_i = {}  # for the case where export_output_names should be used but the output node's output is split into multiple subnodes
             for node_name, output_name, i in output_order:
                 if node_name in export_output_names_dict:
-                    output_names.append(export_output_names_dict[node_name][i])
+                    running_i[node_name] = (
+                        running_i.get(node_name, -1) + 1
+                    )  # if not present, default to 0; otherwise add 1
+                    output_names.append(
+                        export_output_names_dict[node_name][
+                            running_i[node_name]
+                        ]
+                    )
                 else:
                     output_names.append(f"{node_name}/{output_name}/{i}")
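The `running_i` counter is the core of the fix: `i` is the output index within a single subnode, so when an output node's result is split into multiple subnodes, indexing `export_output_names_dict[node_name]` with `i` repeatedly picks the same name. A per-node running counter walks the name list in order instead. A minimal, self-contained sketch of that behavior (the node name, output names, and `output_order` entries below are hypothetical, chosen so that `i` is 0 for every entry):

export_output_names_dict = {
    "EfficientBBoxHead": [
        "output1_yolov6r2",
        "output2_yolov6r2",
        "output3_yolov6r2",
    ]
}
# Each subnode reports its own output index, so `i` is 0 every time;
# the old code, export_output_names_dict[node_name][i], would have
# emitted "output1_yolov6r2" three times.
output_order = [
    ("EfficientBBoxHead", "outputs_0", 0),
    ("EfficientBBoxHead", "outputs_1", 0),
    ("EfficientBBoxHead", "outputs_2", 0),
]

output_names = []
running_i = {}  # per-node counter, independent of `i`
for node_name, output_name, i in output_order:
    if node_name in export_output_names_dict:
        running_i[node_name] = running_i.get(node_name, -1) + 1
        output_names.append(
            export_output_names_dict[node_name][running_i[node_name]]
        )
    else:
        output_names.append(f"{node_name}/{output_name}/{i}")

print(output_names)
# ['output1_yolov6r2', 'output2_yolov6r2', 'output3_yolov6r2']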

40 changes: 23 additions & 17 deletions luxonis_train/nodes/heads/efficient_bbox_head.py
@@ -82,27 +82,13 @@ def __init__(
                 in_channels=self.in_channels[i],
             )
             self.heads.append(curr_head)
-        if (
-            self.export_output_names is None
-            or len(self.export_output_names) != self.n_heads
-        ):
-            if (
-                self.export_output_names is not None
-                and len(self.export_output_names) != self.n_heads
-            ):
-                logger.warning(
-                    f"Number of provided output names ({len(self.export_output_names)}) "
-                    f"does not match number of heads ({self.n_heads}). "
-                    f"Using default names."
-                )
-            self._export_output_names = [
-                f"output{i+1}_yolov6r2" for i in range(self.n_heads)
-            ]

         if initialize_weights:
             self.initialize_weights()

-        if download_weights:
+        if (
+            download_weights and self.name == "EfficientBBoxHead"
+        ):  # skip download on classes that inherit this one
             weights_path = self.get_variant_weights(initialize_weights)
             if weights_path:
                 self.load_checkpoint(path=weights_path, strict=False)
@@ -111,6 +97,8 @@ def __init__(
                     f"No checkpoint available for {self.name}, skipping."
                 )

+        self.check_export_output_names()
+
     def initialize_weights(self) -> None:
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
@@ -142,6 +130,24 @@ def get_variant_weights(self, initialize_weights: bool) -> str | None:
         else:
             return None

+    def check_export_output_names(self):
+        if (
+            self.export_output_names is None
+            or len(self.export_output_names) != self.n_heads
+        ):
+            if (
+                self.export_output_names is not None
+                and len(self.export_output_names) != self.n_heads
+            ):
+                logger.warning(
+                    f"Number of provided output names ({len(self.export_output_names)}) "
+                    f"does not match number of heads ({self.n_heads}). "
+                    f"Using default names."
+                )
+            self._export_output_names = [
+                f"output{i + 1}_yolov6r2" for i in range(self.n_heads)
+            ]
+
     def forward(
         self, inputs: list[Tensor]
     ) -> tuple[list[Tensor], list[Tensor], list[Tensor]]:
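The validation that previously ran inline in `__init__` now lives in `check_export_output_names`, called at the end of `__init__` so subclasses can override the default names. Its fallback logic as a standalone sketch (the helper function name is made up for illustration; the real method mutates `self._export_output_names` instead of returning a value):

import logging

logger = logging.getLogger(__name__)

def resolve_export_output_names(
    names: list[str] | None, n_heads: int
) -> list[str]:
    # Keep user-provided names only when the count matches the number
    # of heads; otherwise warn (if names were given at all) and fall
    # back to the default YOLOv6-style names.
    if names is not None and len(names) == n_heads:
        return names
    if names is not None:
        logger.warning(
            f"Number of provided output names ({len(names)}) does not "
            f"match number of heads ({n_heads}). Using default names."
        )
    return [f"output{i + 1}_yolov6r2" for i in range(n_heads)]

print(resolve_export_output_names(None, 3))
# ['output1_yolov6r2', 'output2_yolov6r2', 'output3_yolov6r2']
print(resolve_export_output_names(["a", "b"], 3))  # warns, then defaults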
5 changes: 4 additions & 1 deletion luxonis_train/nodes/heads/precision_bbox_head.py
@@ -127,6 +127,9 @@ def __init__(
         self.bias_init()
         self.initialize_weights()

+        self.check_export_output_names()
+
+    def check_export_output_names(self):
         if (
             self.export_output_names is None
             or len(self.export_output_names) != self.n_heads
@@ -141,7 +144,7 @@ def __init__(
                     f"Using default names."
                 )
             self._export_output_names = [
-                f"output{i+1}_yolov8r2" for i in range(self.n_heads)
+                f"output{i + 1}_yolov8" for i in range(self.n_heads)
             ]

     def forward(self, x: list[Tensor]) -> tuple[list[Tensor], list[Tensor]]:
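For `PrecisionBBoxHead` the same refactor also renames the default suffix from `yolov8r2` to `yolov8`. Assuming the usual three detection heads, the defaults change as follows:

n_heads = 3  # assumed; the actual value depends on the model variant
before = [f"output{i + 1}_yolov8r2" for i in range(n_heads)]
after = [f"output{i + 1}_yolov8" for i in range(n_heads)]
print(before)  # ['output1_yolov8r2', 'output2_yolov8r2', 'output3_yolov8r2']
print(after)   # ['output1_yolov8', 'output2_yolov8', 'output3_yolov8']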
32 changes: 27 additions & 5 deletions luxonis_train/nodes/heads/precision_seg_bbox_head.py
@@ -1,3 +1,4 @@
+import logging
 from typing import Any, Literal

 import torch
@@ -14,6 +15,8 @@

 from .precision_bbox_head import PrecisionBBoxHead

+logger = logging.getLogger(__name__)
+

 class PrecisionSegmentBBoxHead(PrecisionBBoxHead):
     tasks: list[TaskType] = [
@@ -61,10 +64,6 @@ def __init__(
         )

         self.n_masks = n_masks
-        self.n_proto = n_proto
-
-        self.proto = SegProto(self.in_channels[0], self.n_proto, self.n_masks)
-
         mid_ch = max(self.in_channels[0] // 4, self.n_masks)
         self.mask_layers = nn.ModuleList(
             nn.Sequential(
@@ -75,7 +74,30 @@ def __init__(
             for x in self.in_channels
         )

-        self._export_output_names = None
+        self.n_proto = n_proto
+        self.proto = SegProto(self.in_channels[0], self.n_proto, self.n_masks)
+
+        self.check_export_output_names()
+
+    def check_export_output_names(self):
+        if (
+            self.export_output_names is None
+            or len(self.export_output_names) != self.n_heads
+        ):
+            if (
+                self.export_output_names is not None
+                and len(self.export_output_names) != self.n_heads
+            ):
+                logger.warning(
+                    f"Number of provided output names ({len(self.export_output_names)}) "
+                    f"does not match number of heads ({self.n_heads}). "
+                    f"Using default names."
+                )
+            self._export_output_names = (
+                [f"output{i + 1}_yolov8" for i in range(self.n_heads)]
+                + [f"output{i + 1}_masks" for i in range(self.n_heads)]
+                + ["protos_output"]
+            )  # export names are applied to the sorted output names

     def forward(
         self, inputs: list[Tensor]
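`PrecisionSegmentBBoxHead` overrides `check_export_output_names` because it exports `2 * n_heads + 1` tensors: one detection output and one mask-coefficient output per head, plus the prototype tensor, with names applied to the sorted output names. With three heads (assumed), the defaults are:

n_heads = 3  # assumed; the actual value depends on the model variant
names = (
    [f"output{i + 1}_yolov8" for i in range(n_heads)]
    + [f"output{i + 1}_masks" for i in range(n_heads)]
    + ["protos_output"]
)
print(names)
# ['output1_yolov8', 'output2_yolov8', 'output3_yolov8',
#  'output1_masks', 'output2_masks', 'output3_masks', 'protos_output']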
