From fccdc96b5337689519662def2726aeee921d46d4 Mon Sep 17 00:00:00 2001
From: Jernej Sabadin <116955183+JSabadin@users.noreply.github.com>
Date: Thu, 11 Jul 2024 18:41:34 +0200
Subject: [PATCH] New Keypoint Heads and Losses (#40)

Co-authored-by: klemen1999
Co-authored-by: Martin Kozlovsky
Co-authored-by: GitHub Actions
---
 README.md                                     |  14 +
 configs/coco_model.yaml                       |   2 +-
 .../attached_modules/losses/README.md         |  18 +-
 .../attached_modules/losses/__init__.py       |   2 +
 .../losses/adaptive_detection_loss.py         |   5 +-
 .../losses/efficient_keypoint_bbox_loss.py    | 391 ++++++++++++++++++
 .../losses/implicit_keypoint_bbox_loss.py     |  51 ++-
 .../attached_modules/losses/keypoint_loss.py  |  88 ++--
 .../attached_modules/metrics/README.md        |   2 +
 .../metrics/mean_average_precision.py         |   4 +-
 .../mean_average_precision_keypoints.py       |  41 +-
 .../metrics/object_keypoint_similarity.py     | 192 ++++++---
 .../visualizers/keypoint_visualizer.py        |   2 +-
 luxonis_train/core/archiver.py                |  10 +-
 luxonis_train/nodes/README.md                 |  14 +
 luxonis_train/nodes/__init__.py               |   2 +
 luxonis_train/nodes/efficient_bbox_head.py    |   6 +-
 .../nodes/efficient_keypoint_bbox_head.py     | 207 ++++++++++
 .../nodes/enums/head_categorization.py        |   2 +
 .../utils/assigners/atts_assigner.py          |   6 +-
 luxonis_train/utils/assigners/tal_assigner.py |  13 +-
 luxonis_train/utils/boxutils.py               |  14 +-
 media/coverage_badge.svg                      |   4 +-
 .../test_assigners/test_atts_assigner.py      |   3 +-
 .../test_assigners/test_tal_assigner.py       |   6 +-
 25 files changed, 963 insertions(+), 136 deletions(-)
 create mode 100644 luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py
 create mode 100644 luxonis_train/nodes/efficient_keypoint_bbox_head.py

diff --git a/README.md b/README.md
index a612b59e..873fe2c9 100644
--- a/README.md
+++ b/README.md
@@ -50,6 +50,12 @@ For instructions on how to create a dataset in the LDF, follow the
 [examples](https://github.com/luxonis/luxonis-ml/tree/main/examples) in
 the [luxonis-ml](https://github.com/luxonis/luxonis-ml) repository.
 
+To inspect dataset images by split (train, val, test), use the command:
+
+```bash
+luxonis_train data inspect --config <config.yaml> --view <train/val/test>
+```
+
 ## Training
 
 Once you've created your `config.yaml` file you can train the model using this command:
@@ -66,6 +72,14 @@ luxonis_train train --config config.yaml trainer.batch_size 8 trainer.epochs 10
 
 where key and value are space separated and sub-keys are dot (`.`) separated. If the configuration field is a list, then key/sub-key should be a number (e.g. `trainer.preprocessing.augmentations.0.name RotateCustom`).
 
+## Evaluating
+
+To evaluate the model on a specific dataset split (train, test, or val), use the following command:
+
+```bash
+luxonis_train eval --config <config.yaml> --view <train/test/val>
+```
+
 ## Tuning
 
 To improve training performance you can use `Tuner` for hyperparameter optimization.
diff --git a/configs/coco_model.yaml b/configs/coco_model.yaml
index cad138a5..9af25feb 100755
--- a/configs/coco_model.yaml
+++ b/configs/coco_model.yaml
@@ -46,7 +46,7 @@ model:
     - name: ImplicitKeypointBBoxLoss
       attached_to: ImplicitKeypointBBoxHead
       params:
-        keypoint_distance_loss_weight: 0.5
+        keypoint_regression_loss_weight: 0.5
         keypoint_visibility_loss_weight: 0.7
         bbox_loss_weight: 0.05
         objectness_loss_weight: 0.2
diff --git a/luxonis_train/attached_modules/losses/README.md b/luxonis_train/attached_modules/losses/README.md
index aafbc440..c5b1d348 100644
--- a/luxonis_train/attached_modules/losses/README.md
+++ b/luxonis_train/attached_modules/losses/README.md
@@ -11,6 +11,7 @@ List of all the available loss functions.
 - [SoftmaxFocalLoss](#softmaxfocalloss)
 - [AdaptiveDetectionLoss](#adaptivedetectionloss)
 - [ImplicitKeypointBBoxLoss](#implicitkeypointbboxloss)
+- [EfficientKeypointBBoxLoss](#efficientkeypointbboxloss)
 
 ## CrossEntropyLoss
 
@@ -97,10 +98,25 @@ Keypoint Similarity Loss](https://arxiv.org/ftp/arxiv/papers/2204/2204.06806.pdf
 | label_smoothing                 | float         | 0.0               | Smoothing for [SmoothBCEWithLogitsLoss](#smoothbcewithlogitsloss) for classification loss. |
 | min_objectness_iou              | float         | 0.0               | Minimum objectness IoU.                    |
 | bbox_loss_weight                | float         | 0.05              | Weight for bbox detection sub-loss.        |
-| keypoint_distance_loss_weight   | float         | 0.10              | Weight for keypoint distance sub-loss.     |
+| keypoint_regression_loss_weight | float         | 0.5               | Weight for OKS sub-loss.                   |
 | keypoint_visibility_loss_weight | float         | 0.6               | Weight for keypoint visibility sub-loss.   |
 | class_loss_weight               | float         | 0.6               | Weight for classification sub-loss.        |
 | objectness_loss_weight          | float         | 0.7               | Weight for objectness sub-loss.            |
 | anchor_threshold                | float         | 4.0               | Threshold for matching anchors to targets. |
 | bias                            | float         | 0.5               | Bias for matching anchors to targets.      |
 | balance                         | list\[float\] | \[4.0, 1.0, 0.4\] | Balance for objectness loss.               |
+
+## EfficientKeypointBBoxLoss
+
+Adapted from [YOLO-Pose: Enhancing YOLO for Multi Person Pose Estimation Using Object
+Keypoint Similarity Loss](https://arxiv.org/ftp/arxiv/papers/2204/2204.06806.pdf).
+
+| Key                   | Type                                              | Default value | Description                                                                                 |
+| --------------------- | ------------------------------------------------- | ------------- | ------------------------------------------------------------------------------------------- |
+| viz_pw                | float                                             | 1.0           | Power for [BCEWithLogitsLoss](#bcewithlogitsloss) for keypoint visibility.                  |
+| n_warmup_epochs       | int                                               | 4             | Number of epochs where the ATSS assigner is used; after that we switch to the TAL assigner. |
+| iou_type              | Literal\["none", "giou", "diou", "ciou", "siou"\] | "giou"        | IoU type used for bbox regression sub-loss.                                                 |
+| class_loss_weight     | float                                             | 1.0           | Weight used for the classification sub-loss.                                                |
+| iou_loss_weight       | float                                             | 2.5           | Weight used for the IoU sub-loss.                                                           |
+| regr_kpts_loss_weight | float                                             | 1.5           | Weight used for the OKS sub-loss.                                                           |
+| vis_kpts_loss_weight  | float                                             | 1.0           | Weight used for the keypoint visibility sub-loss.                                           |
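Both new OKS-style losses normalize squared keypoint distances by per-keypoint sigmas and by the (factor-scaled) bbox area, and average only over visible keypoints. As a reading aid, here is a minimal sketch of that regression term, distilled from the `EfficientKeypointBBoxLoss.forward` code later in this patch (the standalone function and argument names are illustrative, not part of the API):

```python
import torch


def oks_regression_loss(pred_kpts, gt_kpts, sigmas, area, eps=1e-9):
    # pred_kpts, gt_kpts: [n_det, n_kpts, 3] as (x, y, visibility);
    # sigmas: [n_kpts]; area: [n_det], bbox area already multiplied by area_factor.
    d = (gt_kpts[..., 0] - pred_kpts[..., 0]) ** 2 + (
        gt_kpts[..., 1] - pred_kpts[..., 1]
    ) ** 2
    # COCO-style normalization: (2 * sigma)^2 and twice the scaled area
    e = d / ((2 * sigmas) ** 2 * ((area.view(-1, 1) + eps) * 2))
    mask = (gt_kpts[..., 2] > 0).float()  # only visible keypoints contribute
    per_detection = ((1 - torch.exp(-e)) * mask).sum(dim=1) / (mask.sum(dim=1) + eps)
    return per_detection.mean()
```

The visibility sub-loss is a plain `BCEWithLogitsLoss` on the third keypoint channel against this visibility mask.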
| diff --git a/luxonis_train/attached_modules/losses/__init__.py b/luxonis_train/attached_modules/losses/__init__.py index 737373d2..28585504 100644 --- a/luxonis_train/attached_modules/losses/__init__.py +++ b/luxonis_train/attached_modules/losses/__init__.py @@ -2,6 +2,7 @@ from .base_loss import BaseLoss from .bce_with_logits import BCEWithLogitsLoss from .cross_entropy import CrossEntropyLoss +from .efficient_keypoint_bbox_loss import EfficientKeypointBBoxLoss from .implicit_keypoint_bbox_loss import ImplicitKeypointBBoxLoss from .keypoint_loss import KeypointLoss from .sigmoid_focal_loss import SigmoidFocalLoss @@ -12,6 +13,7 @@ "AdaptiveDetectionLoss", "BCEWithLogitsLoss", "CrossEntropyLoss", + "EfficientKeypointBBoxLoss", "ImplicitKeypointBBoxLoss", "KeypointLoss", "BaseLoss", diff --git a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py index 21291bfa..83660463 100644 --- a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py +++ b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py @@ -100,7 +100,6 @@ def prepare( feats = outputs["features"] pred_scores = outputs["class_scores"][0] pred_distri = outputs["distributions"][0] - batch_size = pred_scores.shape[0] device = pred_scores.device @@ -142,6 +141,7 @@ def prepare( assigned_bboxes, assigned_scores, mask_positive, + _, ) = self.atts_assigner( anchors, n_anchors_list, @@ -157,7 +157,8 @@ def prepare( assigned_bboxes, assigned_scores, mask_positive, - ) = self.tal_assigner.forward( + _, + ) = self.tal_assigner( pred_scores.detach(), pred_bboxes.detach() * stride_tensor, anchor_points, diff --git a/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py b/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py new file mode 100644 index 00000000..4fc2a7c0 --- /dev/null +++ b/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py @@ -0,0 +1,391 @@ +from typing import Literal, cast + +import torch +import torch.nn.functional as F +from pydantic import Field +from torch import Tensor, nn +from torchvision.ops import box_convert +from typing_extensions import Annotated + +from luxonis_train.attached_modules.metrics.object_keypoint_similarity import ( + get_area_factor, + get_sigmas, +) +from luxonis_train.nodes import EfficientKeypointBBoxHead +from luxonis_train.utils.assigners import ATSSAssigner, TaskAlignedAssigner +from luxonis_train.utils.boxutils import ( + IoUType, + anchors_for_fpn_features, + compute_iou_loss, + dist2bbox, +) +from luxonis_train.utils.types import ( + BaseProtocol, + IncompatibleException, + Labels, + LabelType, + Packet, +) + +from .base_loss import BaseLoss +from .bce_with_logits import BCEWithLogitsLoss + + +class Protocol(BaseProtocol): + features: list[Tensor] + class_scores: Annotated[list[Tensor], Field(min_length=1, max_length=1)] + distributions: Annotated[list[Tensor], Field(min_length=1, max_length=1)] + + +class EfficientKeypointBBoxLoss( + BaseLoss[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor] +): + node: EfficientKeypointBBoxHead + + class NodePacket(Packet[Tensor]): + features: list[Tensor] + class_scores: Tensor + distributions: Tensor + + def __init__( + self, + n_warmup_epochs: int = 4, + iou_type: IoUType = "giou", + reduction: Literal["sum", "mean"] = "mean", + class_bbox_loss_weight: float = 1.0, + iou_loss_weight: float = 2.5, + viz_pw: float = 1.0, + regr_kpts_loss_weight: float = 1.5, + vis_kpts_loss_weight: float = 1.0, + sigmas: 
list[float] | None = None, + area_factor: float | None = None, + **kwargs, + ): + """BBox loss adapted from U{YOLOv6: A Single-Stage Object Detection Framework for Industrial Applications + }. It combines IoU based bbox regression loss and varifocal loss + for classification. + Code is adapted from U{https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/models}. + + @type n_warmup_epochs: int + @param n_warmup_epochs: Number of epochs where ATSS assigner is used, after that we switch to TAL assigner. + @type iou_type: L{IoUType} + @param iou_type: IoU type used for bbox regression loss. + @type reduction: Literal["sum", "mean"] + @param reduction: Reduction type for loss. + @type class_bbox_loss_weight: float + @param class_bbox_loss_weight: Weight of classification loss for bounding boxes. + @type regr_kpts_loss_weight: float + @param regr_kpts_loss_weight: Weight of regression loss for keypoints. + @type vis_kpts_loss_weight: float + @param vis_kpts_loss_weight: Weight of visibility loss for keypoints. + @type iou_loss_weight: float + @param iou_loss_weight: Weight of IoU loss. + @type sigmas: list[float] | None + @param sigmas: Sigmas used in KeypointLoss for OKS metric. If None then use COCO ones if possible or default ones. Defaults to C{None}. + @type area_factor: float | None + @param area_factor: Factor by which we multiply bbox area which is used in KeypointLoss. If None then use default one. Defaults to C{None}. + @type kwargs: dict + @param kwargs: Additional arguments to pass to L{BaseLoss}. + """ + super().__init__( + required_labels=[LabelType.BOUNDINGBOX], protocol=Protocol, **kwargs + ) + + if not isinstance(self.node, EfficientKeypointBBoxHead): + raise IncompatibleException( + f"Loss `{self.__class__.__name__}` is only " + "compatible with nodes of type `EfficientKeypointBBoxHead`." 
+ ) + self.iou_type: IoUType = iou_type + self.reduction = reduction + self.n_classes = self.node.n_classes + self.stride = self.node.stride + self.grid_cell_size = self.node.grid_cell_size + self.grid_cell_offset = self.node.grid_cell_offset + self.original_img_size = self.node.original_in_shape[1:] + self.n_heads = self.node.n_heads + self.n_kps = self.node.n_keypoints + + self.b_cross_entropy = BCEWithLogitsLoss( + pos_weight=torch.tensor([viz_pw]), **kwargs + ) + self.sigmas = get_sigmas( + sigmas=sigmas, n_keypoints=self.n_kps, class_name=self.__class__.__name__ + ) + self.area_factor = get_area_factor( + area_factor, class_name=self.__class__.__name__ + ) + + self.n_warmup_epochs = n_warmup_epochs + self.atts_assigner = ATSSAssigner(topk=9, n_classes=self.n_classes) + self.tal_assigner = TaskAlignedAssigner( + topk=13, n_classes=self.n_classes, alpha=1.0, beta=6.0 + ) + + self.varifocal_loss = VarifocalLoss() + self.class_bbox_loss_weight = class_bbox_loss_weight + self.iou_loss_weight = iou_loss_weight + self.regr_kpts_loss_weight = regr_kpts_loss_weight + self.vis_kpts_loss_weight = vis_kpts_loss_weight + + def prepare( + self, outputs: Packet[Tensor], labels: Labels + ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + feats = outputs["features"] + pred_scores = outputs["class_scores"][0] + pred_distri = outputs["distributions"][0] + pred_kpts = outputs["keypoints_raw"][0] + + batch_size = pred_scores.shape[0] + device = pred_scores.device + + target_bbox = labels["boundingbox"][0].to(device) + target_kpts = labels["keypoints"][0].to(device) + n_kpts = (target_kpts.shape[1] - 2) // 3 + + gt_bboxes_scale = torch.tensor( + [ + self.original_img_size[1], + self.original_img_size[0], + self.original_img_size[1], + self.original_img_size[0], + ], + device=device, + ) + gt_kpts_scale = torch.tensor( + [ + self.original_img_size[1], + self.original_img_size[0], + ], + device=device, + ) + ( + anchors, + anchor_points, + n_anchors_list, + stride_tensor, + ) = anchors_for_fpn_features( + feats, + self.stride, + self.grid_cell_size, + self.grid_cell_offset, + multiply_with_stride=True, + ) + + anchor_points_strided = anchor_points / stride_tensor + pred_bboxes = dist2bbox(pred_distri, anchor_points_strided) + pred_kpts = self.dist2kpts_noscale( + anchor_points_strided, pred_kpts.view(batch_size, -1, n_kpts, 3) + ) + + target_bbox = self._preprocess_bbox_target( + target_bbox, batch_size, gt_bboxes_scale + ) + + gt_bbox_labels = target_bbox[:, :, :1] + gt_xyxy = target_bbox[:, :, 1:] + mask_gt = (gt_xyxy.sum(-1, keepdim=True) > 0).float() + + if self._epoch < self.n_warmup_epochs: + ( + assigned_labels, + assigned_bboxes, + assigned_scores, + mask_positive, + assigned_gt_idx, + ) = self.atts_assigner( + anchors, + n_anchors_list, + gt_bbox_labels, + gt_xyxy, + mask_gt, + pred_bboxes.detach() * stride_tensor, + ) + else: + ( + assigned_labels, + assigned_bboxes, + assigned_scores, + mask_positive, + assigned_gt_idx, + ) = self.tal_assigner( + pred_scores.detach(), + pred_bboxes.detach() * stride_tensor, + anchor_points, + gt_bbox_labels, + gt_xyxy, + mask_gt, + ) + + batched_kpts = self._preprocess_kpts_target( + target_kpts, batch_size, gt_kpts_scale + ) + assigned_gt_idx_expanded = assigned_gt_idx.unsqueeze(-1).unsqueeze(-1) + selected_keypoints = batched_kpts.gather( + 1, assigned_gt_idx_expanded.expand(-1, -1, self.n_kps, 3) + ) + xy_components = selected_keypoints[:, :, :, :2] + normalized_xy = xy_components / stride_tensor.view(1, -1, 1, 1) + 
selected_keypoints = torch.cat( + (normalized_xy, selected_keypoints[:, :, :, 2:]), dim=-1 + ) + gt_kpt = selected_keypoints[mask_positive] + pred_kpts = pred_kpts[mask_positive] + assigned_bboxes = assigned_bboxes / stride_tensor + + area = ( + assigned_bboxes[mask_positive][:, 0] - assigned_bboxes[mask_positive][:, 2] + ) * ( + assigned_bboxes[mask_positive][:, 1] - assigned_bboxes[mask_positive][:, 3] + ) + + return ( + pred_bboxes, + pred_scores, + assigned_bboxes, + assigned_labels, + assigned_scores, + mask_positive, + gt_kpt, + pred_kpts, + area * self.area_factor, + ) + + def forward( + self, + pred_bboxes: Tensor, + pred_scores: Tensor, + assigned_bboxes: Tensor, + assigned_labels: Tensor, + assigned_scores: Tensor, + mask_positive: Tensor, + gt_kpts: Tensor, + pred_kpts: Tensor, + area: Tensor, + ): + device = pred_bboxes.device + sigmas = self.sigmas.to(device) + d = (gt_kpts[..., 0] - pred_kpts[..., 0]).pow(2) + ( + gt_kpts[..., 1] - pred_kpts[..., 1] + ).pow(2) + e = d / ((2 * sigmas).pow(2) * ((area.view(-1, 1) + 1e-9) * 2)) + mask = (gt_kpts[..., 2] > 0).float() + regression_loss = ( + ((1 - torch.exp(-e)) * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-9) + ).mean() + visibility_loss = self.b_cross_entropy.forward(pred_kpts[..., 2], mask) + + one_hot_label = F.one_hot(assigned_labels.long(), self.n_classes + 1)[..., :-1] + loss_cls = self.varifocal_loss(pred_scores, assigned_scores, one_hot_label) + + if assigned_scores.sum() > 1: + loss_cls /= assigned_scores.sum() + + loss_iou = compute_iou_loss( + pred_bboxes, + assigned_bboxes, + assigned_scores, + mask_positive, + reduction="sum", + iou_type=self.iou_type, + bbox_format="xyxy", + )[0] + + loss = ( + self.class_bbox_loss_weight * loss_cls + + self.iou_loss_weight * loss_iou + + regression_loss * self.regr_kpts_loss_weight + + visibility_loss * self.vis_kpts_loss_weight + ) + + sub_losses = { + "class": loss_cls.detach(), + "iou": loss_iou.detach(), + "regression": regression_loss.detach(), + "visibility": visibility_loss.detach(), + } + + return loss, sub_losses + + def _preprocess_bbox_target( + self, bbox_target: Tensor, batch_size: int, scale_tensor: Tensor + ) -> Tensor: + """Preprocess target bboxes in shape [batch_size, N, 5] where N is maximum + number of instances in one image.""" + sample_ids, counts = cast( + tuple[Tensor, Tensor], + torch.unique(bbox_target[:, 0].int(), return_counts=True), + ) + c_max = int(counts.max()) if counts.numel() > 0 else 0 + out_target = torch.zeros(batch_size, c_max, 5, device=bbox_target.device) + out_target[:, :, 0] = -1 + for id, count in zip(sample_ids, counts): + out_target[id, :count] = bbox_target[bbox_target[:, 0] == id][:, 1:] + + scaled_target = out_target[:, :, 1:5] * scale_tensor + out_target[..., 1:] = box_convert(scaled_target, "xywh", "xyxy") + return out_target + + def _preprocess_kpts_target( + self, kpts_target: Tensor, batch_size: int, scale_tensor: Tensor + ) -> Tensor: + """Preprocesses the target keypoints in shape [batch_size, N, n_keypoints, 3] + where N is the maximum number of keypoints in one image.""" + + _, counts = torch.unique(kpts_target[:, 0].int(), return_counts=True) + max_kpts = int(counts.max()) if counts.numel() > 0 else 0 + batched_keypoints = torch.zeros( + (batch_size, max_kpts, self.n_kps, 3), device=kpts_target.device + ) + for i in range(batch_size): + keypoints_i = kpts_target[kpts_target[:, 0] == i] + scaled_keypoints_i = keypoints_i[:, 2:].clone() + batched_keypoints[i, : keypoints_i.shape[0]] = scaled_keypoints_i.view( + -1, self.n_kps, 3 
+ ) + batched_keypoints[i, :, :, :2] *= scale_tensor[:2] + + return batched_keypoints + + def dist2kpts_noscale(self, anchor_points: Tensor, kpts: Tensor) -> Tensor: + """Adjusts and scales predicted keypoints relative to anchor points without + considering image stride.""" + adj_kpts = kpts.clone() + scale = 2.0 + x_adj = anchor_points[:, [0]] - 0.5 + y_adj = anchor_points[:, [1]] - 0.5 + + adj_kpts[..., :2] *= scale + adj_kpts[..., 0] += x_adj + adj_kpts[..., 1] += y_adj + return adj_kpts + + +class VarifocalLoss(nn.Module): + def __init__(self, alpha: float = 0.75, gamma: float = 2.0): + """Varifocal Loss is a loss function for training a dense object detector to predict + the IoU-aware classification score, inspired by focal loss. + Code is adapted from: U{https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/models/losses.py} + + @type alpha: float + @param alpha: alpha parameter in focal loss, default is 0.75. + @type gamma: float + @param gamma: gamma parameter in focal loss, default is 2.0. + """ + + super().__init__() + + self.alpha = alpha + self.gamma = gamma + + def forward( + self, pred_score: Tensor, target_score: Tensor, label: Tensor + ) -> Tensor: + weight = ( + self.alpha * pred_score.pow(self.gamma) * (1 - label) + target_score * label + ) + ce_loss = F.binary_cross_entropy( + pred_score.float(), target_score.float(), reduction="none" + ) + loss = (ce_loss * weight).sum() + return loss diff --git a/luxonis_train/attached_modules/losses/implicit_keypoint_bbox_loss.py b/luxonis_train/attached_modules/losses/implicit_keypoint_bbox_loss.py index 555d0d30..ff530b2a 100644 --- a/luxonis_train/attached_modules/losses/implicit_keypoint_bbox_loss.py +++ b/luxonis_train/attached_modules/losses/implicit_keypoint_bbox_loss.py @@ -45,8 +45,10 @@ def __init__( label_smoothing: float = 0.0, min_objectness_iou: float = 0.0, bbox_loss_weight: float = 0.05, - keypoint_distance_loss_weight: float = 0.10, keypoint_visibility_loss_weight: float = 0.6, + keypoint_regression_loss_weight: float = 0.5, + sigmas: list[float] | None = None, + area_factor: float | None = None, class_loss_weight: float = 0.6, objectness_loss_weight: float = 0.7, anchor_threshold: float = 4.0, @@ -72,10 +74,14 @@ def __init__( @param min_objectness_iou: Minimum objectness iou. Defaults to C{0.0}. @type bbox_loss_weight: float @param bbox_loss_weight: Weight for the bounding box loss. - @type keypoint_distance_loss_weight: float - @param keypoint_distance_loss_weight: Weight for the keypoint distance loss. Defaults to C{0.10}. @type keypoint_visibility_loss_weight: float @param keypoint_visibility_loss_weight: Weight for the keypoint visibility loss. Defaults to C{0.6}. + @type keypoint_regression_loss_weight: float + @param keypoint_regression_loss_weight: Weight for the keypoint regression loss. Defaults to C{0.5}. + @type sigmas: list[float] | None + @param sigmas: Sigmas used in KeypointLoss for OKS metric. If None then use COCO ones if possible or default ones. Defaults to C{None}. + @type area_factor: float | None + @param area_factor: Factor by which we multiply bbox area which is used in KeypointLoss. If None then use default one. Defaults to C{None}. @type class_loss_weight: float @param class_loss_weight: Weight for the class loss. Defaults to C{0.6}. 
@type objectness_loss_weight: float @@ -117,10 +123,10 @@ class Protocol(BaseProtocol): self.min_objectness_iou = min_objectness_iou self.bbox_weight = bbox_loss_weight - self.kpt_distance_weight = keypoint_distance_loss_weight self.class_weight = class_loss_weight self.objectness_weight = objectness_loss_weight self.kpt_visibility_weight = keypoint_visibility_loss_weight + self.keypoint_regression_loss_weight = keypoint_regression_loss_weight self.anchor_threshold = anchor_threshold self.bias = bias @@ -134,9 +140,10 @@ class Protocol(BaseProtocol): **kwargs, ) self.keypoint_loss = KeypointLoss( + n_keypoints=self.n_keypoints, bce_power=viz_pw, - distance_weight=keypoint_distance_loss_weight, - visibility_weight=keypoint_visibility_loss_weight, + sigmas=sigmas, + area_factor=area_factor, **kwargs, ) @@ -169,13 +176,15 @@ def prepare( boxes = labels["boundingbox"][0] nkpts = (kpts.shape[1] - 2) // 3 - targets = torch.zeros((len(boxes), nkpts * 2 + self.box_offset + 1)) + targets = torch.zeros((len(boxes), nkpts * 3 + self.box_offset + 1)) targets[:, :2] = boxes[:, :2] targets[:, 2 : self.box_offset + 1] = box_convert( boxes[:, 2:], "xywh", "cxcywh" ) - targets[:, self.box_offset + 1 :: 2] = kpts[:, 2::3] # insert kp x coordinates - targets[:, self.box_offset + 2 :: 2] = kpts[:, 3::3] # insert kp y coordinates + + targets[:, self.box_offset + 1 :: 3] = kpts[:, 2::3] # insert kp x coordinates + targets[:, self.box_offset + 2 :: 3] = kpts[:, 3::3] # insert kp y coordinates + targets[:, self.box_offset + 3 :: 3] = kpts[:, 4::3] # insert kp visibility n_targets = len(targets) @@ -203,7 +212,6 @@ def prepare( for i in range(self.num_heads): anchor = self.anchors[i] feature_height, feature_width = predictions[i].shape[2:4] - scaled_targets, xy_shifts = match_to_anchor( targets, anchor, @@ -259,7 +267,7 @@ def forward( "objectness": torch.tensor(0.0, device=device), "class": torch.tensor(0.0, device=device), "kpt_visibility": torch.tensor(0.0, device=device), - "kpt_distance": torch.tensor(0.0, device=device), + "kpt_regression": torch.tensor(0.0, device=device), } for pred, class_target, box_target, kpt_target, index, anchor, balance in zip( @@ -284,13 +292,16 @@ def forward( sub_losses["bboxes"] += bbox_loss * self.bbox_weight + area = box_target[:, 2] * box_target[:, 3] + _, kpt_sublosses = self.keypoint_loss.forward( pred_subset[:, self.box_offset + self.n_classes :], kpt_target.to(device), + area.to(device), ) - sub_losses["kpt_distance"] += ( - kpt_sublosses["distance"] * self.kpt_distance_weight + sub_losses["kpt_regression"] += ( + kpt_sublosses["regression"] * self.keypoint_regression_loss_weight ) sub_losses["kpt_visibility"] += ( kpt_sublosses["visibility"] * self.kpt_visibility_weight @@ -326,8 +337,14 @@ def forward( def _create_keypoint_target(self, scaled_targets: Tensor, box_xy_deltas: Tensor): keypoint_target = scaled_targets[:, self.box_offset + 1 : -1] for j in range(self.n_keypoints): - low = 2 * j - high = 2 * (j + 1) - keypoint_mask = keypoint_target[:, low:high] != 0 - keypoint_target[:, low:high][keypoint_mask] -= box_xy_deltas[keypoint_mask] + idx = 3 * j + keypoint_coords = keypoint_target[:, idx : idx + 2] + visibility = keypoint_target[:, idx + 2] + + keypoint_mask = visibility != 0 + keypoint_coords[keypoint_mask] -= box_xy_deltas[keypoint_mask] + + keypoint_target[:, idx : idx + 2] = keypoint_coords + keypoint_target[:, idx + 2] = visibility + return keypoint_target diff --git a/luxonis_train/attached_modules/losses/keypoint_loss.py 
b/luxonis_train/attached_modules/losses/keypoint_loss.py index b1ddd8ba..8a5640cb 100644 --- a/luxonis_train/attached_modules/losses/keypoint_loss.py +++ b/luxonis_train/attached_modules/losses/keypoint_loss.py @@ -4,6 +4,10 @@ from pydantic import Field from torch import Tensor +from luxonis_train.attached_modules.metrics.object_keypoint_similarity import ( + get_area_factor, + get_sigmas, +) from luxonis_train.utils.boxutils import process_keypoints_predictions from luxonis_train.utils.types import ( BaseProtocol, @@ -23,25 +27,44 @@ class Protocol(BaseProtocol): class KeypointLoss(BaseLoss[Tensor, Tensor]): def __init__( self, + n_keypoints: int, bce_power: float = 1.0, - distance_weight: float = 0.1, - visibility_weight: float = 0.6, + sigmas: list[float] | None = None, + area_factor: float | None = None, **kwargs, ): + """Keypoint based loss that is computed from OKS-based regression and visibility + loss. + + @type n_keypoints: int + @param n_keypoints: Number of keypoints. + @type bce_power: float + @param bce_power: Power used for BCE visibility loss. Defaults to C{1.0}. + @param sigmas: Sigmas used for OKS. If None then use COCO ones if possible or + default ones. Defaults to C{None}. + @type area_factor: float | None + @param area_factor: Factor by which we multiply bbox area. If None then use + default one. Defaults to C{None}. + """ + super().__init__( protocol=Protocol, required_labels=[LabelType.KEYPOINTS], **kwargs ) self.b_cross_entropy = BCEWithLogitsLoss( pos_weight=torch.tensor([bce_power]), **kwargs ) - self.distance_weight = distance_weight - self.visibility_weight = visibility_weight + self.sigmas = get_sigmas( + sigmas=sigmas, n_keypoints=n_keypoints, class_name=self.__class__.__name__ + ) + self.area_factor = get_area_factor( + area_factor, class_name=self.__class__.__name__ + ) def prepare(self, inputs: Packet[Tensor], labels: Labels) -> tuple[Tensor, Tensor]: - return torch.cat(inputs["keypoints"], dim=0), labels[LabelType.KEYPOINTS] + return torch.cat(inputs["keypoints"], dim=0), self.get_label(labels)[0] def forward( - self, prediction: Tensor, target: Tensor + self, prediction: Tensor, target: Tensor, area: Tensor ) -> tuple[Tensor, dict[str, Tensor]]: """Computes the keypoint loss and visibility loss for a given prediction and target. @@ -49,29 +72,36 @@ def forward( @type prediction: Tensor @param prediction: Predicted tensor of shape C{[n_detections, n_keypoints * 3]}. @type target: Tensor - @param target: Target tensor of shape C{[n_detections, n_keypoints * 2]}. - @rtype: tuple[Tensor, Tensor] - @return: A tuple containing the keypoint loss tensor of shape C{[1,]} and the - visibility loss tensor of shape C{[1,]}. + @param target: Target tensor of shape C{[n_detections, n_keypoints * 3]}. + @type area: Tensor + @param area: Area tensor of shape C{[n_detections]}. + @rtype: tuple[Tensor, dict[str, Tensor]] + @return: A tuple containing the total loss tensor of shape C{[1,]} and a + dictionary with the regression loss and visibility loss tensors. 
""" - x, y, visibility_score = process_keypoints_predictions(prediction) - gt_x = target[:, 0::2] - gt_y = target[:, 1::2] - - mask = target[:, 0::2] != 0 - visibility_loss = ( - self.b_cross_entropy.forward(visibility_score, mask.float()) - * self.visibility_weight - ) - distance = (x - gt_x) ** 2 + (y - gt_y) ** 2 + device = prediction.device + sigmas = self.sigmas.to(device) - loss_factor = (torch.sum(mask != 0) + torch.sum(mask == 0)) / ( - torch.sum(mask != 0) + 1e-9 - ) - distance_loss = ( - loss_factor - * (torch.log(distance + 1 + 1e-9) * mask).mean() - * self.distance_weight + pred_x, pred_y, pred_v = process_keypoints_predictions(prediction) + gt_x = target[:, 0::3] + gt_y = target[:, 1::3] + gt_v = (target[:, 2::3] > 0).float() + + visibility_loss = self.b_cross_entropy.forward(pred_v, gt_v) + scales = area * self.area_factor + + d = (gt_x - pred_x) ** 2 + (gt_y - pred_y) ** 2 + e = d / (2 * sigmas**2) / (scales.view(-1, 1) + 1e-9) / 2 + + regression_loss_unreduced = 1 - torch.exp(-e) + regression_loss_reduced = (regression_loss_unreduced * gt_v).sum(dim=1) / ( + gt_v.sum(dim=1) + 1e-9 ) - loss = distance_loss + visibility_loss - return loss, {"distance": distance_loss, "visibility": visibility_loss} + regression_loss = regression_loss_reduced.mean() + + total_loss = regression_loss + visibility_loss + + return total_loss, { + "regression": regression_loss, + "visibility": visibility_loss, + } diff --git a/luxonis_train/attached_modules/metrics/README.md b/luxonis_train/attached_modules/metrics/README.md index 4e452158..17735540 100644 --- a/luxonis_train/attached_modules/metrics/README.md +++ b/luxonis_train/attached_modules/metrics/README.md @@ -42,3 +42,5 @@ boxes. ## MeanAveragePrecisionKeypoints Similar to [MeanAveragePrecision](#meanaverageprecision), but uses [OKS](#objectkeypointsimilarity) as `IoU` measure. +For a deeper understanding of how OKS works, please refer to the detailed explanation provided [here](https://learnopencv.com/object-keypoint-similarity/). +Evaluation leverages COCO evaluation framework (COCOeval) to assess mAP performance. 
diff --git a/luxonis_train/attached_modules/metrics/mean_average_precision.py b/luxonis_train/attached_modules/metrics/mean_average_precision.py index 67c010ec..c3eaad7e 100644 --- a/luxonis_train/attached_modules/metrics/mean_average_precision.py +++ b/luxonis_train/attached_modules/metrics/mean_average_precision.py @@ -38,7 +38,9 @@ def update( def prepare( self, outputs: Packet[Tensor], labels: Labels ) -> tuple[list[dict[str, Tensor]], list[dict[str, Tensor]]]: - label = labels[self.task][0] + label = labels["boundingbox"][ + 0 + ] # TODO: Think of a better way to deal with multi-task heads output_nms = self.get_input_tensors(outputs) image_size = self.node.original_in_shape[1:] diff --git a/luxonis_train/attached_modules/metrics/mean_average_precision_keypoints.py b/luxonis_train/attached_modules/metrics/mean_average_precision_keypoints.py index 31bc7557..27df8102 100644 --- a/luxonis_train/attached_modules/metrics/mean_average_precision_keypoints.py +++ b/luxonis_train/attached_modules/metrics/mean_average_precision_keypoints.py @@ -8,6 +8,10 @@ from torch import Tensor from torchvision.ops import box_convert +from luxonis_train.attached_modules.metrics.object_keypoint_similarity import ( + get_area_factor, + get_sigmas, +) from luxonis_train.utils.types import ( BBoxProtocol, KeypointProtocol, @@ -46,7 +50,9 @@ class MeanAveragePrecisionKeypoints(BaseMetric): def __init__( self, - kpt_sigmas: Tensor | None = None, + sigmas: list[float] | None = None, + area_factor: float | None = None, + max_dets: int = 20, box_format: Literal["xyxy", "xywh", "cxcywh"] = "xyxy", **kwargs, ): @@ -59,8 +65,13 @@ def __init__( @type num_keypoints: int @param num_keypoints: Number of keypoints. - @type kpt_sigmas: Tensor or None - @param kpt_sigmas: Sigma for each keypoint to weigh its importance, if None use same weights for all. + @type sigmas: list[float] | None + @param sigmas: Sigma for each keypoint to weigh its importance, if C{None}, then + use COCO if possible otherwise defaults. Defaults to C{None}. + @type area_factor: float | None + @param area_factor: Factor by which we multiply bbox area. If None then use default one. Defaults to C{None}. + @type max_dets: int, + @param max_dets: Maximum number of detections to be considered per image. Defaults to C{20}. @type box_format: Literal["xyxy", "xywh", "cxcywh"] @param box_format: Input bbox format. 
@type kwargs: Any @@ -74,9 +85,9 @@ def __init__( self.n_keypoints = self.node.n_keypoints - if kpt_sigmas is not None and len(kpt_sigmas) != self.n_keypoints: - raise ValueError("Expected kpt_sigmas to be of shape (num_keypoints).") - self.kpt_sigmas = kpt_sigmas or torch.ones(self.n_keypoints) + self.sigmas = get_sigmas(sigmas, self.n_keypoints, self.__class__.__name__) + self.area_factor = get_area_factor(area_factor, self.__class__.__name__) + self.max_dets = max_dets allowed_box_formats = ("xyxy", "xywh", "cxcywh") if box_format not in allowed_box_formats: @@ -214,7 +225,7 @@ def compute(self) -> tuple[Tensor, dict[str, Tensor]]: coco_preds.dataset = self._get_coco_format( self.pred_boxes, self.pred_keypoints, - self.groundtruth_labels, + self.pred_labels, scores=self.pred_scores, ) # type: ignore @@ -223,7 +234,8 @@ def compute(self) -> tuple[Tensor, dict[str, Tensor]]: coco_preds.createIndex() self.coco_eval = COCOeval(coco_target, coco_preds, iouType="keypoints") - self.coco_eval.params.kpt_oks_sigmas = self.kpt_sigmas.cpu().numpy() + self.coco_eval.params.kpt_oks_sigmas = self.sigmas.cpu().numpy() + self.coco_eval.params.maxDets = [self.max_dets] self.coco_eval.evaluate() self.coco_eval.accumulate() @@ -293,19 +305,22 @@ def _get_coco_format( if area is not None and area[image_id][k].cpu().item() > 0: area_stat = area[image_id][k].cpu().tolist() else: - area_stat = image_box[2] * image_box[3] + area_stat = image_box[2] * image_box[3] * self.area_factor + num_keypoints = len( + [i for i in range(2, len(image_kpt), 3) if image_kpt[i] != 0] + ) # number of annotated keypoints annotation = { "id": annotation_id, "image_id": image_id, "bbox": image_box, "area": area_stat, "category_id": image_label, - "iscrowd": crowds[image_id][k].cpu().tolist() - if crowds is not None - else 0, + "iscrowd": ( + crowds[image_id][k].cpu().tolist() if crowds is not None else 0 + ), "keypoints": image_kpt, - "num_keypoints": self.n_keypoints, + "num_keypoints": num_keypoints, } if scores is not None: diff --git a/luxonis_train/attached_modules/metrics/object_keypoint_similarity.py b/luxonis_train/attached_modules/metrics/object_keypoint_similarity.py index c1768012..cfbae11f 100644 --- a/luxonis_train/attached_modules/metrics/object_keypoint_similarity.py +++ b/luxonis_train/attached_modules/metrics/object_keypoint_similarity.py @@ -1,3 +1,5 @@ +import logging + import torch from scipy.optimize import linear_sum_assignment from torch import Tensor @@ -12,22 +14,12 @@ from .base_metric import BaseMetric +logger = logging.getLogger(__name__) + class ObjectKeypointSimilarity( BaseMetric[list[dict[str, Tensor]], list[dict[str, Tensor]]] ): - """Object Keypoint Similarity metric for evaluating keypoint predictions. - - @type n_keypoints: int - @param n_keypoints: Number of keypoints. - @type kpt_sigmas: Tensor - @param kpt_sigmas: Sigma for each keypoint to weigh its importance, if C{None}, then - use same weights for all. - @type use_cocoeval_oks: bool - @param use_cocoeval_oks: Whether to use same OKS formula as in COCOeval or use the - one from definition. 
- """ - is_differentiable: bool = False higher_is_better: bool = True full_state_update: bool = True @@ -41,10 +33,25 @@ class ObjectKeypointSimilarity( def __init__( self, n_keypoints: int | None = None, - kpt_sigmas: Tensor | None = None, - use_cocoeval_oks: bool = False, + sigmas: list[float] | None = None, + area_factor: float | None = None, + use_cocoeval_oks: bool = True, **kwargs, ) -> None: + """Object Keypoint Similarity metric for evaluating keypoint predictions. + + @type n_keypoints: int + @param n_keypoints: Number of keypoints. + @type sigmas: list[float] | None + @param sigmas: Sigma for each keypoint to weigh its importance, if C{None}, then + use COCO if possible otherwise defaults. Defaults to C{None}. + @type area_factor: float | None + @param area_factor: Factor by which we multiply bbox area. If None then use + default one. Defaults to C{None}. + @type use_cocoeval_oks: bool + @param use_cocoeval_oks: Whether to use same OKS formula as in COCOeval or use + the one from definition. Defaults to C{True}. + """ super().__init__( required_labels=[LabelType.KEYPOINTS], protocol=KeypointProtocol, **kwargs ) @@ -55,9 +62,9 @@ def __init__( f"to {self.__class__.__name__}." ) self.n_keypoints = n_keypoints or self.node.n_keypoints - if kpt_sigmas is not None and len(kpt_sigmas) != self.n_keypoints: - raise ValueError("Expected kpt_sigmas to be of shape (num_keypoints).") - self.kpt_sigmas = kpt_sigmas or torch.ones(self.n_keypoints) / self.n_keypoints + + self.sigmas = get_sigmas(sigmas, self.n_keypoints, self.__class__.__name__) + self.area_factor = get_area_factor(area_factor, self.__class__.__name__) self.use_cocoeval_oks = use_cocoeval_oks self.add_state("pred_keypoints", default=[], dist_reduce_fx=None) @@ -93,7 +100,7 @@ def prepare( curr_kpts[:, 1::3] *= image_size[0] curr_bboxs_widths = curr_bboxs[:, 2] - curr_bboxs[:, 0] curr_bboxs_heights = curr_bboxs[:, 3] - curr_bboxs[:, 1] - curr_scales = torch.sqrt(curr_bboxs_widths * curr_bboxs_heights) + curr_scales = curr_bboxs_widths * curr_bboxs_heights * self.area_factor label_list_oks.append({"keypoints": curr_kpts, "scales": curr_scales}) return output_list_oks, label_list_oks @@ -136,7 +143,7 @@ def update( def compute(self) -> Tensor: """Computes the OKS metric based on the inner state.""" - self.kpt_sigmas = self.kpt_sigmas.to(self.device) + self.sigmas = self.sigmas.to(self.device) image_mean_oks = torch.zeros(len(self.groundtruth_keypoints)) for i, (pred_kpts, gt_kpts, gt_scales) in enumerate( zip( @@ -145,7 +152,13 @@ def compute(self) -> Tensor: ): gt_kpts = torch.reshape(gt_kpts, (-1, self.n_keypoints, 3)) # [N, K, 3] - image_ious = self._compute_oks(pred_kpts, gt_kpts, gt_scales) # [M, N] + image_ious = compute_oks( + pred_kpts, + gt_kpts, + gt_scales, + self.sigmas, + self.use_cocoeval_oks, + ) # [M, N] gt_indices, pred_indices = linear_sum_assignment( image_ious.cpu().numpy(), maximize=True ) @@ -156,48 +169,115 @@ def compute(self) -> Tensor: return final_oks - def _compute_oks(self, pred: Tensor, gt: Tensor, scales: Tensor) -> Tensor: - """Compute Object Keypoint Similarity between every GT and prediction. - - @type pred: Tensor[N, K, 3] - @param pred: Predicted keypoints. - @type gt: Tensor[M, K, 3] - @param gt: Groundtruth keypoints. - @type scales: Tensor[M] - @param scales: Scales of the bounding boxes. 
- @rtype: Tensor - @return: Object Keypoint Similarity every pred and gt [M, N] - """ - eps = 1e-7 - distances = (gt[:, None, :, 0] - pred[..., 0]) ** 2 + ( - gt[:, None, :, 1] - pred[..., 1] - ) ** 2 - kpt_mask = gt[..., 2] != 0 # only compute on visible keypoints - if self.use_cocoeval_oks: - # use same formula as in COCOEval script here: - # https://github.com/cocodataset/cocoapi/blob/8c9bcc3cf640524c4c20a9c40e89cb6a2f2fa0e9/PythonAPI/pycocotools/cocoeval.py#L229 - oks = ( - distances - / (2 * self.kpt_sigmas) ** 2 - / (scales[:, None, None] + eps) - / 2 - ) - else: - # use same formula as defined here: https://cocodataset.org/#keypoints-eval - oks = ( - distances - / ((scales[:, None, None] + eps) * self.kpt_sigmas.to(scales.device)) - ** 2 - / 2 - ) - return (torch.exp(-oks) * kpt_mask[:, None]).sum(-1) / ( - kpt_mask.sum(-1)[:, None] + eps +def compute_oks( + pred: Tensor, + gt: Tensor, + scales: Tensor, + sigmas: Tensor, + use_cocoeval_oks: bool, +) -> Tensor: + """Compute Object Keypoint Similarity between every GT and prediction. + + @type pred: Tensor[N, K, 3] + @param pred: Predicted keypoints. + @type gt: Tensor[M, K, 3] + @param gt: Groundtruth keypoints. + @type scales: Tensor[M] + @param scales: Scales of the bounding boxes. + @type sigmas: Tensor + @param sigmas: Sigma for each keypoint to weigh its importance, if C{None}, then use + same weights for all. + @type use_cocoeval_oks: bool + @param use_cocoeval_oks: Whether to use same OKS formula as in COCOeval or use the + one from definition. + @rtype: Tensor + @return: Object Keypoint Similarity every pred and gt [M, N] + """ + eps = 1e-7 + distances = (gt[:, None, :, 0] - pred[..., 0]) ** 2 + ( + gt[:, None, :, 1] - pred[..., 1] + ) ** 2 + kpt_mask = gt[..., 2] != 0 # only compute on visible keypoints + if use_cocoeval_oks: + # use same formula as in COCOEval script here: + # https://github.com/cocodataset/cocoapi/blob/8c9bcc3cf640524c4c20a9c40e89cb6a2f2fa0e9/PythonAPI/pycocotools/cocoeval.py#L229 + oks = distances / (2 * sigmas) ** 2 / (scales[:, None, None] + eps) / 2 + else: + # use same formula as defined here: https://cocodataset.org/#keypoints-eval + oks = ( + distances + / ((scales[:, None, None] + eps) * sigmas.to(scales.device)) ** 2 + / 2 ) + return (torch.exp(-oks) * kpt_mask[:, None]).sum(-1) / ( + kpt_mask.sum(-1)[:, None] + eps + ) + def fix_empty_tensors(input_tensor: Tensor) -> Tensor: """Empty tensors can cause problems in DDP mode, this methods corrects them.""" if input_tensor.numel() == 0 and input_tensor.ndim == 1: return input_tensor.unsqueeze(0) return input_tensor + + +def get_sigmas( + sigmas: list[float] | None, n_keypoints: int, class_name: str | None +) -> Tensor: + """Validate and set the sigma values.""" + if sigmas is not None: + if len(sigmas) == n_keypoints: + return torch.tensor(sigmas, dtype=torch.float32) + else: + error_msg = "The length of the sigmas list must be the same as the number of keypoints." + if class_name: + error_msg = f"[{class_name}] {error_msg}" + raise ValueError(error_msg) + else: + if n_keypoints == 17: + warn_msg = "Default COCO sigmas are being used." + if class_name: + warn_msg = f"[{class_name}] {warn_msg}" + logger.warning(warn_msg) + return torch.tensor( + [ + 0.026, + 0.025, + 0.025, + 0.035, + 0.035, + 0.079, + 0.079, + 0.072, + 0.072, + 0.062, + 0.062, + 0.107, + 0.107, + 0.087, + 0.087, + 0.089, + 0.089, + ], + dtype=torch.float32, + ) + else: + warn_msg = "Default sigma of 0.04 is being used for each keypoint." 
+            if class_name:
+                warn_msg = f"[{class_name}] {warn_msg}"
+            logger.warning(warn_msg)
+            return torch.tensor([0.04] * n_keypoints, dtype=torch.float32)
+
+
+def get_area_factor(area_factor: float | None, class_name: str | None) -> float:
+    """Set the default area factor if not defined."""
+    if area_factor is None:
+        warn_msg = "Default area_factor of 0.53 is being used for bbox area scaling."
+        if class_name:
+            warn_msg = f"[{class_name}] {warn_msg}"
+        logger.warning(warn_msg)
+        return 0.53
+    else:
+        return area_factor
diff --git a/luxonis_train/attached_modules/visualizers/keypoint_visualizer.py b/luxonis_train/attached_modules/visualizers/keypoint_visualizer.py
index 6594912f..18d45ece 100644
--- a/luxonis_train/attached_modules/visualizers/keypoint_visualizer.py
+++ b/luxonis_train/attached_modules/visualizers/keypoint_visualizer.py
@@ -56,7 +56,7 @@ def draw_predictions(
     ) -> Tensor:
         viz = torch.zeros_like(canvas)
         for i in range(len(canvas)):
-            prediction = predictions[i][:, 1:]
+            prediction = predictions[i]
             mask = prediction[..., 2] < visibility_threshold
             visible_kpts = prediction[..., :2] * (~mask).unsqueeze(-1).float()
             viz[i] = draw_keypoints(
diff --git a/luxonis_train/core/archiver.py b/luxonis_train/core/archiver.py
index a42d2ec7..9e2b7c5a 100644
--- a/luxonis_train/core/archiver.py
+++ b/luxonis_train/core/archiver.py
@@ -289,7 +289,13 @@ def _get_head_specific_parameters(
             parameters["max_det"] = head_node.max_det
             parameters["n_keypoints"] = head_node.n_keypoints
             parameters["anchors"] = head_node.anchors.tolist()
-
+        elif head_name == "EfficientKeypointBBoxHead":
+            # or appropriate subtype
+            head_node = self.lightning_module._modules["nodes"][head_alias]
+            parameters["iou_threshold"] = head_node.iou_thres
+            parameters["conf_threshold"] = head_node.conf_thres
+            parameters["max_det"] = head_node.max_det
+            parameters["n_keypoints"] = head_node.n_keypoints
         else:
             raise ValueError("Unknown head name")
         return parameters
@@ -310,6 +316,8 @@ def _get_head_outputs(self, head_name) -> dict:
             head_outputs["predictions"] = self.outputs[0]["name"]
         elif head_name == "ImplicitKeypointBBoxHead":
             head_outputs["predictions"] = self.outputs[0]["name"]
+        elif head_name == "EfficientKeypointBBoxHead":
+            head_outputs["predictions"] = self.outputs[0]["name"]
         else:
             raise ValueError("Unknown head name")
         return head_outputs
diff --git a/luxonis_train/nodes/README.md b/luxonis_train/nodes/README.md
index 637c5026..6a29d237 100644
--- a/luxonis_train/nodes/README.md
+++ b/luxonis_train/nodes/README.md
@@ -20,6 +20,7 @@ arbitrarily as long as the two nodes are compatible with each other.
 - [BiSeNetHead](#bisenethead)
 - [EfficientBBoxHead](#efficientbboxhead)
 - [ImplicitKeypointBBoxHead](#implicitkeypointbboxhead)
+- [EfficientKeypointBBoxHead](#efficientkeypointbboxhead)
 
 Every node takes these parameters:
 
@@ -193,3 +194,16 @@ Adapted from [here](https://arxiv.org/pdf/2207.02696.pdf).
 | init_coco_biases | bool        | True          | Whether to use COCO bias and weight initialization |
 | conf_thres       | float       | 0.25          | confidence threshold for nms (used for evaluation) |
 | iou_thres        | float       | 0.45          | iou threshold for nms (used for evaluation)        |
+
+## EfficientKeypointBBoxHead
+
+Adapted from [here](https://arxiv.org/pdf/2207.02696.pdf).
+
+**Params**
+
+| Key         | Type        | Default value | Description                                        |
+| ----------- | ----------- | ------------- | -------------------------------------------------- |
+| n_keypoints | int \| None | None          | Number of keypoints.                               |
| +| n_heads | int | 3 | Number of output heads | +| conf_thres | float | 0.25 | confidence threshold for nms (used for evaluation) | +| iou_thres | float | 0.45 | iou threshold for nms (used for evaluation) | diff --git a/luxonis_train/nodes/__init__.py b/luxonis_train/nodes/__init__.py index 9a506c1f..4c90abaa 100644 --- a/luxonis_train/nodes/__init__.py +++ b/luxonis_train/nodes/__init__.py @@ -3,6 +3,7 @@ from .classification_head import ClassificationHead from .contextspatial import ContextSpatial from .efficient_bbox_head import EfficientBBoxHead +from .efficient_keypoint_bbox_head import EfficientKeypointBBoxHead from .efficientnet import EfficientNet from .efficientrep import EfficientRep from .implicit_keypoint_bbox_head import ImplicitKeypointBBoxHead @@ -22,6 +23,7 @@ "EfficientBBoxHead", "EfficientNet", "EfficientRep", + "EfficientKeypointBBoxHead", "ImplicitKeypointBBoxHead", "BaseNode", "MicroNet", diff --git a/luxonis_train/nodes/efficient_bbox_head.py b/luxonis_train/nodes/efficient_bbox_head.py index e7b23288..23728af1 100644 --- a/luxonis_train/nodes/efficient_bbox_head.py +++ b/luxonis_train/nodes/efficient_bbox_head.py @@ -50,7 +50,9 @@ def __init__( @type max_det: int @param max_det: Maximum number of detections retained after NMS. Defaults to C{300}. """ - super().__init__(_task_type=LabelType.BOUNDINGBOX, **kwargs) + super().__init__( + _task_type=kwargs.pop("_task_type", LabelType.BOUNDINGBOX), **kwargs + ) self.n_heads = n_heads @@ -126,7 +128,7 @@ def _fit_stride_to_num_heads(self): """Returns correct stride for number of heads and attach index.""" stride = torch.tensor( [ - self.original_in_shape[1] / x[1] # type: ignore + self.original_in_shape[1] / x[2] # type: ignore for x in self.in_sizes[: self.n_heads] ], dtype=torch.int, diff --git a/luxonis_train/nodes/efficient_keypoint_bbox_head.py b/luxonis_train/nodes/efficient_keypoint_bbox_head.py new file mode 100644 index 00000000..dabb62c5 --- /dev/null +++ b/luxonis_train/nodes/efficient_keypoint_bbox_head.py @@ -0,0 +1,207 @@ +from typing import Literal + +import torch +from torch import Tensor, nn + +from luxonis_train.nodes.blocks import ConvModule +from luxonis_train.utils.boxutils import ( + anchors_for_fpn_features, + dist2bbox, + non_max_suppression, +) +from luxonis_train.utils.types import LabelType, Packet + +from .efficient_bbox_head import EfficientBBoxHead + + +class EfficientKeypointBBoxHead(EfficientBBoxHead): + def __init__( + self, + n_keypoints: int | None = None, + n_heads: Literal[2, 3, 4] = 3, + conf_thres: float = 0.25, + iou_thres: float = 0.45, + max_det: int = 300, + **kwargs, + ): + """Head for object and keypoint detection. + + Adapted from U{YOLOv6: A Single-Stage Object Detection Framework for Industrial + Applications}. + + @param n_keypoints: Number of keypoints. If not defined, inferred + from the dataset metadata (if provided). Defaults to C{None}. + @type n_keypoints: int | None + + @param n_heads: Number of output heads. Defaults to C{3}. + B{Note:} Should be same also on neck in most cases. + @type n_heads: int + + @param conf_thres: Threshold for confidence. Defaults to C{0.25}. + @type conf_thres: float + + @param iou_thres: Threshold for IoU. Defaults to C{0.45}. + @type iou_thres: float + + @param max_det: Maximum number of detections retained after NMS. Defaults to C{300}. 
+ @type max_det: int + """ + super().__init__( + n_heads=n_heads, + conf_thres=conf_thres, + iou_thres=iou_thres, + max_det=max_det, + _task_type=LabelType.KEYPOINTS, + **kwargs, + ) + + n_keypoints = n_keypoints or self.dataset_metadata._n_keypoints + + if n_keypoints is None: + raise ValueError( + "Number of keypoints must be specified either in the constructor or " + "in the dataset metadata." + ) + + self.n_keypoints = n_keypoints + self.nk = n_keypoints * 3 + + mid_ch = max(self.in_channels[0] // 4, self.nk) + self.kpt_layers = nn.ModuleList( + nn.Sequential( + ConvModule(x, mid_ch, 3, 1, 1, activation=nn.SiLU()), + ConvModule(mid_ch, mid_ch, 3, 1, 1, activation=nn.SiLU()), + nn.Conv2d(mid_ch, self.nk, 1, 1), + ) + for x in self.in_channels + ) + + def forward( + self, inputs: list[Tensor] + ) -> tuple[list[Tensor], list[Tensor], list[Tensor], list[Tensor]]: + features, cls_score_list, reg_distri_list = super().forward(inputs) + + _, self.anchor_points, _, self.stride_tensor = anchors_for_fpn_features( + features, + self.stride, + self.grid_cell_size, + self.grid_cell_offset, + multiply_with_stride=False, + ) + + kpt_list: list[Tensor] = [] + for i in range(self.n_heads): + kpt_pred = self.kpt_layers[i](inputs[i]) + kpt_list.append(kpt_pred) + + return features, cls_score_list, reg_distri_list, kpt_list + + def wrap( + self, output: tuple[list[Tensor], list[Tensor], list[Tensor], list[Tensor]] + ) -> Packet[Tensor]: + features, cls_score_list, reg_distri_list, kpt_list = output + bs = features[0].shape[0] + if self.export: + outputs = [] + for out_cls, out_reg, out_kpts in zip( + cls_score_list, reg_distri_list, kpt_list, strict=True + ): + chunks = out_kpts.split(3, dim=1) + modified_chunks = [] + for chunk in chunks: + x = chunk[:, 0:1, :, :] + y = chunk[:, 1:2, :, :] + v = torch.sigmoid(chunk[:, 2:3, :, :]) + modified_chunk = torch.cat([x, y, v], dim=1) + modified_chunks.append(modified_chunk) + out_kpts_modified = torch.cat(modified_chunks, dim=1) + out = torch.cat([out_reg, out_cls, out_kpts_modified], dim=1) + outputs.append(out) + return {"outputs": outputs} + cls_tensor = torch.cat( + [cls_score_list[i].flatten(2) for i in range(len(cls_score_list))], dim=2 + ).permute(0, 2, 1) + reg_tensor = torch.cat( + [reg_distri_list[i].flatten(2) for i in range(len(reg_distri_list))], dim=2 + ).permute(0, 2, 1) + kpt_tensor = torch.cat( + [ + kpt_list[i].view(bs, self.nk, -1).flatten(2) + for i in range(len(kpt_list)) + ], + dim=2, + ).permute(0, 2, 1) + + if self.training: + return { + "features": features, + "class_scores": [cls_tensor], + "distributions": [reg_tensor], + "keypoints_raw": [kpt_tensor], + } + + pred_kpt = self._dist2kpts(kpt_tensor) + detections = self._process_to_bbox_and_kps( + (features, cls_tensor, reg_tensor, pred_kpt) + ) + return { + "boundingbox": [detection[:, :6] for detection in detections], + "features": features, + "class_scores": [cls_tensor], + "distributions": [reg_tensor], + "keypoints": [ + detection[:, 6:].reshape(-1, self.n_keypoints, 3) + for detection in detections + ], + "keypoints_raw": [kpt_tensor], + } + + def _dist2kpts(self, kpts): + """Decodes keypoints.""" + y = kpts.clone() + + anchor_points_transposed = self.anchor_points.transpose(0, 1) + stride_tensor = self.stride_tensor.squeeze(-1) + + stride_tensor = stride_tensor.view(1, -1, 1) + anchor_points_x = anchor_points_transposed[0].view(1, -1, 1) + anchor_points_y = anchor_points_transposed[1].view(1, -1, 1) + + y[:, :, 0::3] = (y[:, :, 0::3] * 2.0 + (anchor_points_x - 0.5)) * 
stride_tensor + y[:, :, 1::3] = (y[:, :, 1::3] * 2.0 + (anchor_points_y - 0.5)) * stride_tensor + y[:, :, 2::3] = y[:, :, 2::3].sigmoid() + + return y + + def _process_to_bbox_and_kps( + self, output: tuple[list[Tensor], Tensor, Tensor, Tensor] + ) -> list[Tensor]: + """Performs post-processing of the output and returns bboxs after NMS.""" + features, cls_score_list, reg_dist_list, keypoints = output + + pred_bboxes = dist2bbox(reg_dist_list, self.anchor_points, out_format="xyxy") + + pred_bboxes *= self.stride_tensor + output_merged = torch.cat( + [ + pred_bboxes, + torch.ones( + (features[-1].shape[0], pred_bboxes.shape[1], 1), + dtype=pred_bboxes.dtype, + device=pred_bboxes.device, + ), + cls_score_list, + keypoints, + ], + dim=-1, + ) + + return non_max_suppression( + output_merged, + n_classes=self.n_classes, + conf_thres=self.conf_thres, + iou_thres=self.iou_thres, + bbox_format="xyxy", + max_det=self.max_det, + predicts_objectness=False, + ) diff --git a/luxonis_train/nodes/enums/head_categorization.py b/luxonis_train/nodes/enums/head_categorization.py index 56f98ff3..a2854b3a 100644 --- a/luxonis_train/nodes/enums/head_categorization.py +++ b/luxonis_train/nodes/enums/head_categorization.py @@ -7,6 +7,7 @@ class ImplementedHeads(Enum): ClassificationHead = "Classification" EfficientBBoxHead = "ObjectDetectionYOLO" ImplicitKeypointBBoxHead = "KeypointDetectionYOLO" + EfficientKeypointBBoxHead = "Keypoint" SegmentationHead = "Segmentation" BiSeNetHead = "Segmentation" @@ -17,5 +18,6 @@ class ImplementedHeadsIsSoxtmaxed(Enum): ClassificationHead = False EfficientBBoxHead = None ImplicitKeypointBBoxHead = None + EfficientKeypointBBoxHead = None SegmentationHead = False BiSeNetHead = False diff --git a/luxonis_train/utils/assigners/atts_assigner.py b/luxonis_train/utils/assigners/atts_assigner.py index 26b4dc23..f4989b54 100644 --- a/luxonis_train/utils/assigners/atts_assigner.py +++ b/luxonis_train/utils/assigners/atts_assigner.py @@ -38,7 +38,7 @@ def forward( gt_bboxes: Tensor, mask_gt: Tensor, pred_bboxes: Tensor, - ) -> tuple[Tensor, Tensor, Tensor, Tensor]: + ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: """Assigner's forward method which generates final assignments. @type anchor_bboxes: Tensor @@ -53,7 +53,7 @@ def forward( @param mask_gt: Mask for valid GTs [bs, n_max_boxes, 1] @type pred_bboxes: Tensor @param pred_bboxes: Predicted bboxes of shape [bs, n_anchors, 4] - @rtype: tuple[Tensor, Tensor, Tensor, Tensor] + @rtype: tuple[Tensor, Tensor, Tensor, Tensor, Tensor] @return: Assigned labels of shape [bs, n_anchors], assigned bboxes of shape [bs, n_anchors, 4], assigned scores of shape [bs, n_anchors, n_classes] and output positive mask of shape [bs, n_anchors]. 
@@ -70,6 +70,7 @@ def forward( torch.zeros([self.bs, self.n_anchors, 4]).to(device), torch.zeros([self.bs, self.n_anchors, self.n_classes]).to(device), torch.zeros([self.bs, self.n_anchors]).to(device), + torch.zeros([self.bs, self.n_anchors]).to(device), ) gt_bboxes_flat = gt_bboxes.reshape([-1, 4]) @@ -124,6 +125,7 @@ def forward( assigned_bboxes, assigned_scores, out_mask_positive, + assigned_gt_idx, ) def _get_bbox_center(self, bbox: Tensor) -> Tensor: diff --git a/luxonis_train/utils/assigners/tal_assigner.py b/luxonis_train/utils/assigners/tal_assigner.py index 0765ad6a..08b5b461 100644 --- a/luxonis_train/utils/assigners/tal_assigner.py +++ b/luxonis_train/utils/assigners/tal_assigner.py @@ -50,7 +50,7 @@ def forward( gt_labels: Tensor, gt_bboxes: Tensor, mask_gt: Tensor, - ) -> tuple[Tensor, Tensor, Tensor, Tensor]: + ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: """Assigner's forward method which generates final assignments. @type pred_scores: Tensor @@ -65,7 +65,7 @@ def forward( @param gt_bboxes: Initial GT bboxes [bs, n_max_boxes, 4] @type mask_gt: Tensor @param mask_gt: Mask for valid GTs [bs, n_max_boxes, 1] - @rtype: tuple[Tensor, Tensor, Tensor, Tensor] + @rtype: tuple[Tensor, Tensor, Tensor, Tensor, Tensor] @return: Assigned labels of shape [bs, n_anchors], assigned bboxes of shape [bs, n_anchors, 4], assigned scores of shape [bs, n_anchors, n_classes] and output mask of shape [bs, n_anchors] @@ -80,6 +80,7 @@ def forward( torch.zeros_like(pred_bboxes).to(device), torch.zeros_like(pred_scores).to(device), torch.zeros_like(pred_scores[..., 0]).to(device), + torch.zeros_like(pred_scores[..., 0]).to(device), ) # Compute alignment metric between all bboxes (bboxes of all pyramid levels) and GT @@ -121,7 +122,13 @@ def forward( out_mask_positive = mask_pos_sum.bool() - return assigned_labels, assigned_bboxes, assigned_scores, out_mask_positive + return ( + assigned_labels, + assigned_bboxes, + assigned_scores, + out_mask_positive, + assigned_gt_idx, + ) def _get_alignment_metric( self, diff --git a/luxonis_train/utils/boxutils.py b/luxonis_train/utils/boxutils.py index 3a26cc4f..64a8b8dd 100644 --- a/luxonis_train/utils/boxutils.py +++ b/luxonis_train/utils/boxutils.py @@ -77,12 +77,20 @@ def match_to_anchor( # The boxes and keypoints need to be scaled to the size of the features # First two indices are batch index and class label, # last index is anchor index. Those are not scaled. 
- scale_length = 2 * n_keypoints + box_offset + 2 + scale_length = 3 * n_keypoints + box_offset + 2 scales = torch.ones(scale_length, device=targets.device) - scales[2 : scale_length - 1] = torch.tensor( - [scale_width, scale_height] * (n_keypoints + 2) + + # Scale box and keypoint coordinates, but not visibility + for i in range(n_keypoints): + scales[box_offset + 1 + 3 * i] = scale_width + scales[box_offset + 2 + 3 * i] = scale_height + + scales[2 : box_offset + 1] = torch.tensor( + [scale_width, scale_height, scale_width, scale_height] ) + scaled_targets = targets * scales + if targets.size(1) == 0: return targets[0], torch.zeros(1, device=targets.device) diff --git a/media/coverage_badge.svg b/media/coverage_badge.svg index b750dd9c..6c15cace 100644 --- a/media/coverage_badge.svg +++ b/media/coverage_badge.svg @@ -15,7 +15,7 @@ coverage coverage - 77% - 77% + 75% + 75% diff --git a/tests/unittests/test_utils/test_assigners/test_atts_assigner.py b/tests/unittests/test_utils/test_assigners/test_atts_assigner.py index 4512d9e5..a3801ebb 100644 --- a/tests/unittests/test_utils/test_assigners/test_atts_assigner.py +++ b/tests/unittests/test_utils/test_assigners/test_atts_assigner.py @@ -24,7 +24,7 @@ def test_forward(): mask_gt = torch.rand(bs, n_max_boxes, 1) pred_bboxes = torch.rand(bs, n_anchors, 4) - labels, bboxes, scores, mask = assigner.forward( + labels, bboxes, scores, mask, assigned_gt_idx = assigner.forward( anchor_bboxes, n_level_bboxes, gt_labels, gt_bboxes, mask_gt, pred_bboxes ) @@ -32,6 +32,7 @@ def test_forward(): assert bboxes.shape == (bs, n_anchors, 4) assert scores.shape == (bs, n_anchors, n_classes) assert mask.shape == (bs, n_anchors) + assert assigned_gt_idx.shape == (bs, n_anchors) def test_get_bbox_center(): diff --git a/tests/unittests/test_utils/test_assigners/test_tal_assigner.py b/tests/unittests/test_utils/test_assigners/test_tal_assigner.py index bb2dd912..8f291615 100644 --- a/tests/unittests/test_utils/test_assigners/test_tal_assigner.py +++ b/tests/unittests/test_utils/test_assigners/test_tal_assigner.py @@ -31,7 +31,7 @@ def test_forward(): mask_gt = torch.rand(batch_size, num_max_boxes, 1) # Call the forward method - labels, bboxes, scores, mask = assigner.forward( + labels, bboxes, scores, mask, assigned_gt_idx = assigner.forward( pred_scores, pred_bboxes, anchor_points, gt_labels, gt_bboxes, mask_gt ) @@ -60,6 +60,10 @@ def test_forward(): assert torch.equal( mask, torch.zeros_like(mask) ) # All mask values should be zero as there are no GT boxes + assert assigned_gt_idx.shape == (batch_size, num_anchors) + assert torch.equal( + assigned_gt_idx, torch.zeros_like(assigned_gt_idx) + ) # All assigned_gt_idx values should be zero as there are no GT boxes def test_get_alignment_metric():
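For reference, the keypoint decoding implemented by `EfficientKeypointBBoxHead._dist2kpts` in this patch reduces to the following minimal sketch (illustrative names; `raw_kpts` has shape `[batch, n_anchors, n_keypoints * 3]`, with anchor coordinates and strides broadcastable to `[1, n_anchors, 1]`):

```python
import torch


def decode_keypoints(raw_kpts, anchor_x, anchor_y, stride):
    y = raw_kpts.clone()
    # xy offsets are doubled, shifted by the anchor centers and scaled by stride
    y[:, :, 0::3] = (y[:, :, 0::3] * 2.0 + (anchor_x - 0.5)) * stride
    y[:, :, 1::3] = (y[:, :, 1::3] * 2.0 + (anchor_y - 0.5)) * stride
    # the third channel is squashed into a visibility score
    y[:, :, 2::3] = y[:, :, 2::3].sigmoid()
    return y
```

The head's export branch applies the same sigmoid-on-visibility convention before concatenating its outputs.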