From 04a6ab6cce82d6ac841d847f3bc6384c1006de94 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Thu, 16 Jan 2025 11:31:48 +0100 Subject: [PATCH] fix export bug --- luxonis_train/nodes/heads/fomo_head.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/luxonis_train/nodes/heads/fomo_head.py b/luxonis_train/nodes/heads/fomo_head.py index 5190ca9a..9f557a8a 100644 --- a/luxonis_train/nodes/heads/fomo_head.py +++ b/luxonis_train/nodes/heads/fomo_head.py @@ -89,18 +89,12 @@ def _apply_nms_if_needed(self, heatmap: Tensor) -> Tensor: if not self.use_nms: return heatmap - return ( - F.max_pool2d( - heatmap.unsqueeze(0).unsqueeze( - 0 - ), # Add dummy batch/channel dimensions - kernel_size=3, - stride=1, - padding=1, - ) - .squeeze(0) - .squeeze(0) - ) # Remove dummy dimensions + return F.max_pool2d( + heatmap, + kernel_size=3, + stride=1, + padding=1, + ) def _heatmap_to_kpts(self, heatmap: Tensor) -> List[Tensor]: """Convert heatmap to keypoint pairs using local-max NMS so that