diff --git a/luxonis_train/nodes/heads/ddrnet_segmentation_head.py b/luxonis_train/nodes/heads/ddrnet_segmentation_head.py index 730d890c..89886d45 100644 --- a/luxonis_train/nodes/heads/ddrnet_segmentation_head.py +++ b/luxonis_train/nodes/heads/ddrnet_segmentation_head.py @@ -105,6 +105,6 @@ def forward(self, inputs: Tensor) -> Tensor: x = self.conv2(x) x = self.upscale(x) if self.export: - x = x.argmax(dim=1) if self.n_classes > 1 else (x > 0) + x = x.argmax(dim=1) if self.n_classes > 1 else torch.sigmoid(x) > 0.5 return x.to(dtype=torch.int32) return x