From c559448f34c4dcefd3d7fdf399be2785ec2e95c7 Mon Sep 17 00:00:00 2001
From: Martin Kozlovsky
Date: Fri, 12 Jan 2024 16:14:04 +0100
Subject: [PATCH] fixed crashing when no labels are present in the batch

---
 .../attached_modules/losses/adaptive_detection_loss.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py
index 89c18f67..af1a7e6a 100644
--- a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py
+++ b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py
@@ -1,4 +1,4 @@
-from typing import Literal
+from typing import Literal, cast
 
 import torch
 import torch.nn.functional as F
@@ -209,8 +209,11 @@ def forward(
     def _preprocess_target(self, target: Tensor, batch_size: int, scale_tensor: Tensor):
         """Preprocess target in shape [batch_size, N, 5] where N is maximum number of
         instances in one image."""
-        sample_ids, counts = torch.unique(target[:, 0].int(), return_counts=True)
-        out_target = torch.zeros(batch_size, counts.max(), 5, device=target.device)
+        sample_ids, counts = cast(
+            tuple[Tensor, Tensor], torch.unique(target[:, 0].int(), return_counts=True)
+        )
+        c_max = int(counts.max()) if counts.numel() > 0 else 0
+        out_target = torch.zeros(batch_size, c_max, 5, device=target.device)
         out_target[:, :, 0] = -1
         for id, count in zip(sample_ids, counts):
             out_target[id, :count] = target[target[:, 0] == id][:, 1:]
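
A minimal sketch (not part of the patch) of the failure mode this change guards against: when the batch contains no labels, `torch.unique(..., return_counts=True)` returns an empty `counts` tensor and `counts.max()` raises a RuntimeError, so the fix falls back to zero instances. The `batch_size` value and `[N, 5]` target layout below are illustrative assumptions mirroring `_preprocess_target`.

```python
import torch

# A batch with no labels at all: zero rows of [image_id, x, y, w, h].
target = torch.zeros(0, 5)
batch_size = 4  # arbitrary example value

sample_ids, counts = torch.unique(target[:, 0].int(), return_counts=True)

# Before the fix: counts.numel() == 0, so counts.max() raises a
# RuntimeError (reduction over an empty tensor has no defined max).
# counts.max()

# After the fix: treat an empty counts tensor as "0 instances per image".
c_max = int(counts.max()) if counts.numel() > 0 else 0
out_target = torch.zeros(batch_size, c_max, 5, device=target.device)
print(out_target.shape)  # torch.Size([4, 0, 5]) -- empty but valid, no crash
```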