From e6d24419827153d397caa6a1bf773e527d921d62 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Thu, 3 Oct 2024 20:13:51 +0200 Subject: [PATCH 1/3] removed old set_export_mode --- .../nodes/heads/ddrnet_segmentation_head.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/luxonis_train/nodes/heads/ddrnet_segmentation_head.py b/luxonis_train/nodes/heads/ddrnet_segmentation_head.py index 964208b0..39293fed 100644 --- a/luxonis_train/nodes/heads/ddrnet_segmentation_head.py +++ b/luxonis_train/nodes/heads/ddrnet_segmentation_head.py @@ -91,19 +91,3 @@ def forward(self, inputs: Tensor) -> Tensor: if self.export: return x.argmax(dim=1).to(dtype=torch.int32) return x - - def set_export_mode(self, mode: bool = True) -> None: - """Sets the module to export mode. - - Replaces the forward method with a constant empty tensor. - - @warning: The replacement is destructive and cannot be undone. - @type mode: bool - @param mode: Whether to set the export mode to True or False. - Defaults to True. - """ - super().set_export_mode(mode) - if self.export and self.attach_index != -1: - logger.info("Removing the auxiliary head.") - - self.forward = lambda inputs: torch.tensor([]) From 6da3ae4ecfdee2d3baaf2fc8a58e5aa151760566 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Thu, 3 Oct 2024 20:14:01 +0200 Subject: [PATCH 2/3] simplified expression --- .../config/predefined_models/segmentation_model.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/luxonis_train/config/predefined_models/segmentation_model.py b/luxonis_train/config/predefined_models/segmentation_model.py index 8a281131..09b56fcb 100644 --- a/luxonis_train/config/predefined_models/segmentation_model.py +++ b/luxonis_train/config/predefined_models/segmentation_model.py @@ -75,11 +75,8 @@ def nodes(self) -> list[ModelNodeConfig]: """Defines the model nodes, including backbone and head.""" self.head_params.update({"attach_index": -1}) self.aux_head_params.update({"attach_index": -2}) - ( - self.aux_head_params.update({"remove_on_export": True}) - if "remove_on_export" not in self.aux_head_params - else None - ) + if "remove_on_export" not in self.head_params: + self.head_params["remove_on_export"] = True node_list = [ ModelNodeConfig( From 679605ae1dd60edf16d8f9dabaa8278f62896bb0 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Thu, 3 Oct 2024 20:32:45 +0200 Subject: [PATCH 3/3] fixed remove_on_export param --- luxonis_train/config/predefined_models/segmentation_model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/luxonis_train/config/predefined_models/segmentation_model.py b/luxonis_train/config/predefined_models/segmentation_model.py index 09b56fcb..5d2582d5 100644 --- a/luxonis_train/config/predefined_models/segmentation_model.py +++ b/luxonis_train/config/predefined_models/segmentation_model.py @@ -75,8 +75,6 @@ def nodes(self) -> list[ModelNodeConfig]: """Defines the model nodes, including backbone and head.""" self.head_params.update({"attach_index": -1}) self.aux_head_params.update({"attach_index": -2}) - if "remove_on_export" not in self.head_params: - self.head_params["remove_on_export"] = True node_list = [ ModelNodeConfig( @@ -103,6 +101,9 @@ def nodes(self) -> list[ModelNodeConfig]: freezing=self.aux_head_params.pop("freezing", {}), params=self.aux_head_params, task=self.task_name, + remove_on_export=self.aux_head_params.pop( + "remove_on_export", True + ), ) ) return node_list