Skip to content

Commit

Permalink
Merge pull request #1 from luxonis/fix/empty-labels-bbox-loss
Browse files Browse the repository at this point in the history
Crashing `AdaptiveDetectionLoss`
  • Loading branch information
kozlov721 authored Jan 12, 2024
2 parents 270ec4f + c559448 commit 25fdf9a
Showing 1 changed file with 6 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Literal
from typing import Literal, cast

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -209,8 +209,11 @@ def forward(
def _preprocess_target(self, target: Tensor, batch_size: int, scale_tensor: Tensor):
"""Preprocess target in shape [batch_size, N, 5] where N is maximum number of
instances in one image."""
sample_ids, counts = torch.unique(target[:, 0].int(), return_counts=True)
out_target = torch.zeros(batch_size, counts.max(), 5, device=target.device)
sample_ids, counts = cast(
tuple[Tensor, Tensor], torch.unique(target[:, 0].int(), return_counts=True)
)
c_max = int(counts.max()) if counts.numel() > 0 else 0
out_target = torch.zeros(batch_size, c_max, 5, device=target.device)
out_target[:, :, 0] = -1
for id, count in zip(sample_ids, counts):
out_target[id, :count] = target[target[:, 0] == id][:, 1:]
Expand Down

0 comments on commit 25fdf9a

Please sign in to comment.