diff --git a/luxonis_train/attached_modules/losses/fomo_localization_loss.py b/luxonis_train/attached_modules/losses/fomo_localization_loss.py index 0ad1ea60..2857710d 100644 --- a/luxonis_train/attached_modules/losses/fomo_localization_loss.py +++ b/luxonis_train/attached_modules/losses/fomo_localization_loss.py @@ -20,7 +20,7 @@ class FOMOLocalizationLoss(BaseLoss[Tensor, Tensor]): (TaskType.BOUNDINGBOX, TaskType.KEYPOINTS) ] - def __init__(self, object_weight: float = 1000, **kwargs: Any): + def __init__(self, object_weight: float = 500, **kwargs: Any): """FOMO Localization Loss for object detection using heatmaps. @type object_weight: float diff --git a/luxonis_train/nodes/heads/fomo_head.py b/luxonis_train/nodes/heads/fomo_head.py index 7a8943e2..cc9ea1a0 100644 --- a/luxonis_train/nodes/heads/fomo_head.py +++ b/luxonis_train/nodes/heads/fomo_head.py @@ -2,6 +2,7 @@ from typing import Any, List import torch +import torch.nn.functional as F from torch import Tensor, nn from luxonis_train.enums import TaskType @@ -74,8 +75,8 @@ def wrap(self, heatmap: Tensor) -> Packet[Tensor]: } def _heatmap_to_kpts(self, heatmap: Tensor) -> List[Tensor]: - """Convert heatmap to keypoint pairs, ensuring all tensors are - on the same device.""" + """Convert heatmap to keypoint pairs using local-max NMS so that + only the strongest local peak in a neighborhood is retained.""" device = heatmap.device batch_size, num_classes, height, width = heatmap.shape @@ -84,23 +85,39 @@ def _heatmap_to_kpts(self, heatmap: Tensor) -> List[Tensor]: kpts_per_img = [] for c in range(num_classes): - y_indices, x_indices = torch.where( - torch.sigmoid(heatmap[batch_idx, c, :, :]) > 0.5 - ) - + prob_map = torch.sigmoid(heatmap[batch_idx, c, :, :]) + + pooled_map = ( + F.max_pool2d( + prob_map.unsqueeze(0).unsqueeze(0), # shape [1,1,H,W] + kernel_size=3, + stride=1, + padding=1, + ) + .squeeze(0) + .squeeze(0) + ) # back to [H,W] + + threshold = 0.5 + keep = (prob_map == pooled_map) & (prob_map > threshold) + + y_indices, x_indices = torch.where(keep) kpts = [] for y, x in zip(y_indices, x_indices): kpt_x = x.item() / width * self.original_img_size[1] kpt_y = y.item() / height * self.original_img_size[0] - kpts.append([kpt_x, kpt_y, 2]) + + kpts.append([kpt_x, kpt_y, float(prob_map[y, x])]) kpts_per_img.append(kpts) if all(len(kpt) == 0 for kpt in kpts_per_img): - kpts_per_img = [[[0, 0, 0]]] # One keypoint per object + kpts_per_img = [[[0, 0, 0.0]]] batch_kpts.append( - torch.tensor(kpts_per_img, device=device).permute(1, 0, 2) + torch.tensor( + kpts_per_img, device=device, dtype=torch.float32 + ).permute(1, 0, 2) ) return batch_kpts