diff --git a/luxonis_train/config/predefined_models/segmentation_model.py b/luxonis_train/config/predefined_models/segmentation_model.py index 8a281131..5d2582d5 100644 --- a/luxonis_train/config/predefined_models/segmentation_model.py +++ b/luxonis_train/config/predefined_models/segmentation_model.py @@ -75,11 +75,6 @@ def nodes(self) -> list[ModelNodeConfig]: """Defines the model nodes, including backbone and head.""" self.head_params.update({"attach_index": -1}) self.aux_head_params.update({"attach_index": -2}) - ( - self.aux_head_params.update({"remove_on_export": True}) - if "remove_on_export" not in self.aux_head_params - else None - ) node_list = [ ModelNodeConfig( @@ -106,6 +101,9 @@ def nodes(self) -> list[ModelNodeConfig]: freezing=self.aux_head_params.pop("freezing", {}), params=self.aux_head_params, task=self.task_name, + remove_on_export=self.aux_head_params.pop( + "remove_on_export", True + ), ) ) return node_list diff --git a/luxonis_train/nodes/heads/ddrnet_segmentation_head.py b/luxonis_train/nodes/heads/ddrnet_segmentation_head.py index 964208b0..39293fed 100644 --- a/luxonis_train/nodes/heads/ddrnet_segmentation_head.py +++ b/luxonis_train/nodes/heads/ddrnet_segmentation_head.py @@ -91,19 +91,3 @@ def forward(self, inputs: Tensor) -> Tensor: if self.export: return x.argmax(dim=1).to(dtype=torch.int32) return x - - def set_export_mode(self, mode: bool = True) -> None: - """Sets the module to export mode. - - Replaces the forward method with a constant empty tensor. - - @warning: The replacement is destructive and cannot be undone. - @type mode: bool - @param mode: Whether to set the export mode to True or False. - Defaults to True. - """ - super().set_export_mode(mode) - if self.export and self.attach_index != -1: - logger.info("Removing the auxiliary head.") - - self.forward = lambda inputs: torch.tensor([])