Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Segmentation Model Aux Head Export #90

Merged
merged 3 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions luxonis_train/config/predefined_models/segmentation_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,6 @@ def nodes(self) -> list[ModelNodeConfig]:
"""Defines the model nodes, including backbone and head."""
self.head_params.update({"attach_index": -1})
self.aux_head_params.update({"attach_index": -2})
(
self.aux_head_params.update({"remove_on_export": True})
if "remove_on_export" not in self.aux_head_params
else None
)

node_list = [
ModelNodeConfig(
Expand All @@ -106,6 +101,9 @@ def nodes(self) -> list[ModelNodeConfig]:
freezing=self.aux_head_params.pop("freezing", {}),
params=self.aux_head_params,
task=self.task_name,
remove_on_export=self.aux_head_params.pop(
"remove_on_export", True
),
)
)
return node_list
Expand Down
16 changes: 0 additions & 16 deletions luxonis_train/nodes/heads/ddrnet_segmentation_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,19 +91,3 @@ def forward(self, inputs: Tensor) -> Tensor:
if self.export:
return x.argmax(dim=1).to(dtype=torch.int32)
return x

def set_export_mode(self, mode: bool = True) -> None:
"""Sets the module to export mode.
Replaces the forward method with a constant empty tensor.
@warning: The replacement is destructive and cannot be undone.
@type mode: bool
@param mode: Whether to set the export mode to True or False.
Defaults to True.
"""
super().set_export_mode(mode)
if self.export and self.attach_index != -1:
logger.info("Removing the auxiliary head.")

self.forward = lambda inputs: torch.tensor([])
Loading