Skip to content

Commit

Permalink
remove loss scaling
Browse files Browse the repository at this point in the history
  • Loading branch information
JSabadin committed Dec 10, 2024
1 parent 0f842e1 commit b6be8d4
Showing 1 changed file with 5 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def prepare(
target_kpts = self.get_label(labels, TaskType.KEYPOINTS)
target_bbox = self.get_label(labels, TaskType.BOUNDINGBOX)

self.batch_size = pred_scores.shape[0]
batch_size = pred_scores.shape[0]
n_kpts = (target_kpts.shape[1] - 2) // 3

self._init_parameters(feats)
Expand All @@ -112,16 +112,14 @@ def prepare(
pred_kpts = self.dist2kpts_noscale(
self.anchor_points_strided,
pred_kpts.view(
self.batch_size,
batch_size,
-1,
n_kpts,
3,
),
)

target_bbox = self._preprocess_bbox_target(
target_bbox, self.batch_size
)
target_bbox = self._preprocess_bbox_target(target_bbox, batch_size)

gt_bbox_labels = target_bbox[:, :, :1]
gt_xyxy = target_bbox[:, :, 1:]
Expand All @@ -141,7 +139,7 @@ def prepare(
)

batched_kpts = self._preprocess_kpts_target(
target_kpts, self.batch_size, self.gt_kpts_scale
target_kpts, batch_size, self.gt_kpts_scale
)
assigned_gt_idx_expanded = assigned_gt_idx.unsqueeze(-1).unsqueeze(-1)
selected_keypoints = batched_kpts.gather(
Expand Down Expand Up @@ -234,7 +232,7 @@ def forward(
"visibility": visibility_loss.detach(),
}

return loss * self.batch_size, sub_losses
return loss, sub_losses

def _preprocess_kpts_target(
self, kpts_target: Tensor, batch_size: int, scale_tensor: Tensor
Expand Down

0 comments on commit b6be8d4

Please sign in to comment.