Skip to content

Commit

Permalink
add ~NMS, smaller loss weight
Browse files Browse the repository at this point in the history
  • Loading branch information
JSabadin committed Jan 16, 2025
1 parent 5d85270 commit bbfd989
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class FOMOLocalizationLoss(BaseLoss[Tensor, Tensor]):
(TaskType.BOUNDINGBOX, TaskType.KEYPOINTS)
]

def __init__(self, object_weight: float = 1000, **kwargs: Any):
def __init__(self, object_weight: float = 500, **kwargs: Any):
"""FOMO Localization Loss for object detection using heatmaps.
@type object_weight: float
Expand Down
35 changes: 26 additions & 9 deletions luxonis_train/nodes/heads/fomo_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, List

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from luxonis_train.enums import TaskType
Expand Down Expand Up @@ -74,8 +75,8 @@ def wrap(self, heatmap: Tensor) -> Packet[Tensor]:
}

def _heatmap_to_kpts(self, heatmap: Tensor) -> List[Tensor]:
"""Convert heatmap to keypoint pairs, ensuring all tensors are
on the same device."""
"""Convert heatmap to keypoint pairs using local-max NMS so that
only the strongest local peak in a neighborhood is retained."""
device = heatmap.device
batch_size, num_classes, height, width = heatmap.shape

Expand All @@ -84,23 +85,39 @@ def _heatmap_to_kpts(self, heatmap: Tensor) -> List[Tensor]:
kpts_per_img = []

for c in range(num_classes):
y_indices, x_indices = torch.where(
torch.sigmoid(heatmap[batch_idx, c, :, :]) > 0.5
)

prob_map = torch.sigmoid(heatmap[batch_idx, c, :, :])

pooled_map = (
F.max_pool2d(
prob_map.unsqueeze(0).unsqueeze(0), # shape [1,1,H,W]
kernel_size=3,
stride=1,
padding=1,
)
.squeeze(0)
.squeeze(0)
) # back to [H,W]

threshold = 0.5
keep = (prob_map == pooled_map) & (prob_map > threshold)

y_indices, x_indices = torch.where(keep)
kpts = []
for y, x in zip(y_indices, x_indices):
kpt_x = x.item() / width * self.original_img_size[1]
kpt_y = y.item() / height * self.original_img_size[0]
kpts.append([kpt_x, kpt_y, 2])

kpts.append([kpt_x, kpt_y, float(prob_map[y, x])])

kpts_per_img.append(kpts)

if all(len(kpt) == 0 for kpt in kpts_per_img):
kpts_per_img = [[[0, 0, 0]]] # One keypoint per object
kpts_per_img = [[[0, 0, 0.0]]]

batch_kpts.append(
torch.tensor(kpts_per_img, device=device).permute(1, 0, 2)
torch.tensor(
kpts_per_img, device=device, dtype=torch.float32
).permute(1, 0, 2)
)

return batch_kpts

0 comments on commit bbfd989

Please sign in to comment.