From 6a6f2dbd072a1dc9ccf60e243f78f12ffa11bb1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Kozlovsk=C3=BD?= Date: Thu, 3 Oct 2024 19:17:58 +0200 Subject: [PATCH] Change DDRNet Head Output Dtype (#88) --- luxonis_train/nodes/heads/ddrnet_segmentation_head.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/luxonis_train/nodes/heads/ddrnet_segmentation_head.py b/luxonis_train/nodes/heads/ddrnet_segmentation_head.py index e2ebe2f3..964208b0 100644 --- a/luxonis_train/nodes/heads/ddrnet_segmentation_head.py +++ b/luxonis_train/nodes/heads/ddrnet_segmentation_head.py @@ -83,13 +83,13 @@ def __init__( ) def forward(self, inputs: Tensor) -> Tensor: - x = self.relu(self.bn1(inputs)) + x: Tensor = self.relu(self.bn1(inputs)) x = self.conv1(x) x = self.relu(self.bn2(x)) x = self.conv2(x) x = self.upscale(x) if self.export: - return x.argmax(dim=1) + return x.argmax(dim=1).to(dtype=torch.int32) return x def set_export_mode(self, mode: bool = True) -> None: