Skip to content

Commit

Permalink
Fix Segmentation Model Aux Head Export (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlov721 committed Oct 9, 2024
1 parent d6ddda1 commit 1d5c3d7
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 21 deletions.
8 changes: 3 additions & 5 deletions luxonis_train/config/predefined_models/segmentation_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,6 @@ def nodes(self) -> list[ModelNodeConfig]:
"""Defines the model nodes, including backbone and head."""
self.head_params.update({"attach_index": -1})
self.aux_head_params.update({"attach_index": -2})
(
self.aux_head_params.update({"remove_on_export": True})
if "remove_on_export" not in self.aux_head_params
else None
)

node_list = [
ModelNodeConfig(
Expand All @@ -106,6 +101,9 @@ def nodes(self) -> list[ModelNodeConfig]:
freezing=self.aux_head_params.pop("freezing", {}),
params=self.aux_head_params,
task=self.task_name,
remove_on_export=self.aux_head_params.pop(
"remove_on_export", True
),
)
)
return node_list
Expand Down
16 changes: 0 additions & 16 deletions luxonis_train/nodes/heads/ddrnet_segmentation_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,19 +91,3 @@ def forward(self, inputs: Tensor) -> Tensor:
if self.export:
return x.argmax(dim=1).to(dtype=torch.int32)
return x

def set_export_mode(self, mode: bool = True) -> None:
"""Sets the module to export mode.
Replaces the forward method with a constant empty tensor.
@warning: The replacement is destructive and cannot be undone.
@type mode: bool
@param mode: Whether to set the export mode to True or False.
Defaults to True.
"""
super().set_export_mode(mode)
if self.export and self.attach_index != -1:
logger.info("Removing the auxiliary head.")

self.forward = lambda inputs: torch.tensor([])

0 comments on commit 1d5c3d7

Please sign in to comment.