From fb1cbef2fd5268c8eea659cf32b0e0a994633036 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Wed, 4 Dec 2024 13:43:37 +0100 Subject: [PATCH 01/31] feat: new det and seg heads --- luxonis_train/assigners/tal_assigner.py | 6 - .../attached_modules/losses/__init__.py | 4 + .../losses/adaptive_detection_loss.py | 9 +- .../losses/efficient_keypoint_bbox_loss.py | 18 +- .../losses/precision_dfl_detection_loss.py | 293 +++++++++++++++++ .../losses/precision_dlf_segmentation_loss.py | 306 ++++++++++++++++++ luxonis_train/nodes/blocks/__init__.py | 6 + luxonis_train/nodes/blocks/blocks.py | 129 ++++++++ luxonis_train/nodes/heads/__init__.py | 4 + .../nodes/heads/precision_bbox_head.py | 234 ++++++++++++++ .../nodes/heads/precision_seg_bbox_head.py | 188 +++++++++++ luxonis_train/utils/__init__.py | 2 + luxonis_train/utils/boundingbox.py | 40 ++- tests/integration/test_detection.py | 26 ++ 14 files changed, 1247 insertions(+), 18 deletions(-) create mode 100644 luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py create mode 100644 luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py create mode 100644 luxonis_train/nodes/heads/precision_bbox_head.py create mode 100644 luxonis_train/nodes/heads/precision_seg_bbox_head.py diff --git a/luxonis_train/assigners/tal_assigner.py b/luxonis_train/assigners/tal_assigner.py index c9435afa..51566a05 100644 --- a/luxonis_train/assigners/tal_assigner.py +++ b/luxonis_train/assigners/tal_assigner.py @@ -250,10 +250,4 @@ def _get_final_assignments( torch.full_like(assigned_scores, 0), ) - assigned_labels = torch.where( - mask_pos_sum.bool(), - assigned_labels, - torch.full_like(assigned_labels, self.n_classes), - ) - return assigned_labels, assigned_bboxes, assigned_scores diff --git a/luxonis_train/attached_modules/losses/__init__.py b/luxonis_train/attached_modules/losses/__init__.py index ff0bafc8..32b33174 100644 --- a/luxonis_train/attached_modules/losses/__init__.py +++ b/luxonis_train/attached_modules/losses/__init__.py @@ -7,6 +7,8 @@ from .ohem_bce_with_logits import OHEMBCEWithLogitsLoss from .ohem_cross_entropy import OHEMCrossEntropyLoss from .ohem_loss import OHEMLoss +from .precision_dfl_detection_loss import PrecisionDFLDetectionLoss +from .precision_dlf_segmentation_loss import PrecisionDFLSegmentationLoss from .reconstruction_segmentation_loss import ReconstructionSegmentationLoss from .sigmoid_focal_loss import SigmoidFocalLoss from .smooth_bce_with_logits import SmoothBCEWithLogitsLoss @@ -26,4 +28,6 @@ "OHEMCrossEntropyLoss", "OHEMBCEWithLogitsLoss", "FOMOLocalizationLoss", + "PrecisionDFLDetectionLoss", + "PrecisionDFLSegmentationLoss", ] diff --git a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py index a81d5a45..521a26f1 100644 --- a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py +++ b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py @@ -56,9 +56,9 @@ def __init__( @type reduction: Literal["sum", "mean"] @param reduction: Reduction type for loss. @type class_loss_weight: float - @param class_loss_weight: Weight of classification loss. + @param class_loss_weight: Weight of classification loss. Defaults to 1.0. For optimal results, multiply with accumulate_grad_batches. @type iou_loss_weight: float - @param iou_loss_weight: Weight of IoU loss. + @param iou_loss_weight: Weight of IoU loss. Defaults to 2.5. For optimal results, multiply with accumulate_grad_batches. """ super().__init__(**kwargs) @@ -133,6 +133,11 @@ def forward( assigned_scores: Tensor, mask_positive: Tensor, ): + assigned_labels = torch.where( + mask_positive > 0, + assigned_labels, + torch.full_like(assigned_labels, self.n_classes), + ) one_hot_label = F.one_hot(assigned_labels.long(), self.n_classes + 1)[ ..., :-1 ] diff --git a/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py b/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py index 701a3c72..5dc3e564 100644 --- a/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py +++ b/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py @@ -56,11 +56,11 @@ def __init__( @type class_loss_weight: float @param class_loss_weight: Weight of classification loss for bounding boxes. @type regr_kpts_loss_weight: float - @param regr_kpts_loss_weight: Weight of regression loss for keypoints. + @param regr_kpts_loss_weight: Weight of regression loss for keypoints. Defaults to 12.0. For optimal results, multiply with accumulate_grad_batches. @type vis_kpts_loss_weight: float - @param vis_kpts_loss_weight: Weight of visibility loss for keypoints. + @param vis_kpts_loss_weight: Weight of visibility loss for keypoints. Defaults to 1.0. For optimal results, multiply with accumulate_grad_batches. @type iou_loss_weight: float - @param iou_loss_weight: Weight of IoU loss. + @param iou_loss_weight: Weight of IoU loss. Defaults to 2.5. For optimal results, multiply with accumulate_grad_batches. @type sigmas: list[float] | None @param sigmas: Sigmas used in keypoint loss for OKS metric. If None then use COCO ones if possible or default ones. Defaults to C{None}. @type area_factor: float | None @@ -103,7 +103,7 @@ def prepare( target_kpts = self.get_label(labels, TaskType.KEYPOINTS) target_bbox = self.get_label(labels, TaskType.BOUNDINGBOX) - batch_size = pred_scores.shape[0] + self.batch_size = pred_scores.shape[0] n_kpts = (target_kpts.shape[1] - 2) // 3 self._init_parameters(feats) @@ -112,14 +112,16 @@ def prepare( pred_kpts = self.dist2kpts_noscale( self.anchor_points_strided, pred_kpts.view( - batch_size, + self.batch_size, -1, n_kpts, 3, ), ) - target_bbox = self._preprocess_bbox_target(target_bbox, batch_size) + target_bbox = self._preprocess_bbox_target( + target_bbox, self.batch_size + ) gt_bbox_labels = target_bbox[:, :, :1] gt_xyxy = target_bbox[:, :, 1:] @@ -139,7 +141,7 @@ def prepare( ) batched_kpts = self._preprocess_kpts_target( - target_kpts, batch_size, self.gt_kpts_scale + target_kpts, self.batch_size, self.gt_kpts_scale ) assigned_gt_idx_expanded = assigned_gt_idx.unsqueeze(-1).unsqueeze(-1) selected_keypoints = batched_kpts.gather( @@ -232,7 +234,7 @@ def forward( "visibility": visibility_loss.detach(), } - return loss, sub_losses + return loss * self.batch_size, sub_losses def _preprocess_kpts_target( self, kpts_target: Tensor, batch_size: int, scale_tensor: Tensor diff --git a/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py b/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py new file mode 100644 index 00000000..d682aeea --- /dev/null +++ b/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py @@ -0,0 +1,293 @@ +import logging +from typing import Any, cast + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torchvision.ops import box_convert + +from luxonis_train.assigners import TaskAlignedAssigner +from luxonis_train.enums import TaskType +from luxonis_train.nodes import PrecisionBBoxHead +from luxonis_train.utils import ( + Labels, + Packet, + anchors_for_fpn_features, + bbox2dist, + bbox_iou, + dist2bbox, +) + +from .base_loss import BaseLoss + +logger = logging.getLogger(__name__) + + +class PrecisionDFLDetectionLoss( + BaseLoss[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor] +): + node: PrecisionBBoxHead + supported_tasks: list[TaskType] = [TaskType.BOUNDINGBOX] + + def __init__( + self, + reg_max: int = 16, + tal_topk: int = 10, + class_loss_weight: float = 0.5, + bbox_loss_weight: float = 7.5, + dfl_loss_weight: float = 1.5, + **kwargs: Any, + ): + """BBox loss adapted from U{Real-Time Flying Object Detection with YOLOv8 + } + + @type reg_max: int + @param reg_max: Maximum number of regression channels. Defaults to 16. + @type tal_topk: int + @param tal_topk: Number of anchors considered in selection. Defaults to 10. + @type class_loss_weight: float + @param class_loss_weight: Weight for classification loss. Defaults to 0.5. For optimal results, multiply with accumulate_grad_batches. + @type bbox_loss_weight: float + @param bbox_loss_weight: Weight for bbox loss. Defaults to 7.5. For optimal results, multiply with accumulate_grad_batches. + @type dfl_loss_weight: float + @param dfl_loss_weight: Weight for DFL loss. Defaults to 1.5. For optimal results, multiply with accumulate_grad_batches. + """ + super().__init__(**kwargs) + self.stride = self.node.stride + self.grid_cell_size = self.node.grid_cell_size + self.grid_cell_offset = self.node.grid_cell_offset + self.original_img_size = self.original_in_shape[1:] + + self.class_loss_weight = class_loss_weight + self.bbox_loss_weight = bbox_loss_weight + self.dfl_loss_weight = dfl_loss_weight + + self.assigner = TaskAlignedAssigner( + n_classes=self.n_classes, topk=tal_topk, alpha=0.5, beta=6.0 + ) + self.bbox_loss = CustomBboxLoss(reg_max) + self.proj = torch.arange(reg_max, dtype=torch.float) + self.bce = nn.BCEWithLogitsLoss(reduction="none") + + def prepare( + self, inputs: Packet[Tensor], labels: Labels + ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + feats = self.get_input_tensors(inputs, "features") + self._init_parameters(feats) + self.batch_size = feats[0].shape[0] + pred_distri, pred_scores = torch.cat( + [xi.view(self.batch_size, self.node.no, -1) for xi in feats], 2 + ).split((self.node.reg_max * 4, self.n_classes), 1) + target = self.get_label(labels) + pred_distri = pred_distri.permute(0, 2, 1).contiguous() + pred_scores = pred_scores.permute(0, 2, 1).contiguous() + + target = self._preprocess_bbox_target(target, self.batch_size) + + pred_bboxes = self.decode_bbox(self.anchor_points_strided, pred_distri) + + gt_labels = target[:, :, :1] + gt_xyxy = target[:, :, 1:] + mask_gt = (gt_xyxy.sum(-1, keepdim=True) > 0).float() + + _, assigned_bboxes, assigned_scores, mask_positive, _ = self.assigner( + pred_scores.detach().sigmoid(), + (pred_bboxes.detach() * self.stride_tensor).type(gt_xyxy.dtype), + self.anchor_points, + gt_labels, + gt_xyxy, + mask_gt, + ) + + return ( + pred_distri, + pred_bboxes, + pred_scores, + assigned_bboxes / self.stride_tensor, + assigned_scores, + mask_positive, + ) + + def forward( + self, + pred_distri: Tensor, + pred_bboxes: Tensor, + pred_scores: Tensor, + assigned_bboxes: Tensor, + assigned_scores: Tensor, + mask_positive: Tensor, + ): + max_assigned_scores_sum = max(assigned_scores.sum(), 1) + loss_cls = ( + self.bce(pred_scores, assigned_scores) + ).sum() / max_assigned_scores_sum + if mask_positive.sum(): + loss_iou, loss_dfl = self.bbox_loss( + pred_distri, + pred_bboxes, + self.anchor_points_strided, + assigned_bboxes, + assigned_scores, + max_assigned_scores_sum, + mask_positive, + ) + else: + loss_iou = torch.tensor(0.0).to(pred_distri.device) + loss_dfl = torch.tensor(0.0).to(pred_distri.device) + + loss = ( + self.class_loss_weight * loss_cls + + self.bbox_loss_weight * loss_iou + + self.dfl_loss_weight * loss_dfl + ) + sub_losses = { + "class": loss_cls.detach(), + "iou": loss_iou.detach(), + "dfl": loss_dfl.detach(), + } + + return loss * self.batch_size, sub_losses + + def _preprocess_bbox_target( + self, target: Tensor, batch_size: int + ) -> Tensor: + sample_ids, counts = cast( + tuple[Tensor, Tensor], + torch.unique(target[:, 0].int(), return_counts=True), + ) + c_max = int(counts.max()) if counts.numel() > 0 else 0 + out_target = torch.zeros(batch_size, c_max, 5, device=target.device) + out_target[:, :, 0] = -1 + for id, count in zip(sample_ids, counts): + out_target[id, :count] = target[target[:, 0] == id][:, 1:] + + scaled_target = out_target[:, :, 1:5] * self.gt_bboxes_scale + out_target[..., 1:] = box_convert(scaled_target, "xywh", "xyxy") + + return out_target + + def decode_bbox(self, anchor_points: Tensor, pred_dist: Tensor) -> Tensor: + """Decode predicted object bounding box coordinates from anchor + points and distribution. + + @type anchor_points: Tensor + @param anchor_points: Anchor points tensor of shape [N, 4] where + N is the number of anchors. + @type pred_dist: Tensor + @param pred_dist: Predicted distribution tensor of shape + [batch_size, N, 4 * reg_max] where N is the number of + anchors. + @rtype: Tensor + """ + if self.node.dfl: + batch_size, num_anchors, num_channels = pred_dist.shape + dist_probs = pred_dist.view( + batch_size, num_anchors, 4, num_channels // 4 + ).softmax(dim=3) + dist_transformed = dist_probs.matmul( + self.proj.to(anchor_points.device).type(pred_dist.dtype) + ) + return dist2bbox(dist_transformed, anchor_points, out_format="xyxy") + + def _init_parameters(self, features: list[Tensor]): + if not hasattr(self, "gt_bboxes_scale"): + _, self.anchor_points, _, self.stride_tensor = ( + anchors_for_fpn_features( + features, + self.stride, + self.grid_cell_size, + self.grid_cell_offset, + multiply_with_stride=True, + ) + ) + self.gt_bboxes_scale = torch.tensor( + [ + self.original_img_size[1], + self.original_img_size[0], + self.original_img_size[1], + self.original_img_size[0], + ], + device=features[0].device, + ) + self.anchor_points_strided = ( + self.anchor_points / self.stride_tensor + ) + + +class CustomBboxLoss(nn.Module): + def __init__(self, reg_max: int = 16): + """BBox loss that combines IoU and DFL losses. + + @type reg_max: int + @param reg_max: Maximum number of regression channels. Defaults + to 16. + """ + super().__init__() + self.dist_loss = CustomDFLoss(reg_max) if reg_max > 1 else None + + def forward( + self, + pred_dist: Tensor, + pred_bboxes: Tensor, + anchors: Tensor, + targets: Tensor, + scores: Tensor, + total_score: Tensor, + fg_mask: Tensor, + ) -> tuple[Tensor, Tensor]: + score_weights = scores.sum(dim=-1)[fg_mask].unsqueeze(dim=-1) + + iou_vals = bbox_iou( + pred_bboxes[fg_mask], + targets[fg_mask], + iou_type="ciou", + element_wise=True, + ).unsqueeze(dim=-1) + iou_loss_val = ((1.0 - iou_vals) * score_weights).sum() / total_score + + if self.dist_loss is not None: + offset_targets = bbox2dist( + targets, anchors, self.dist_loss.reg_max - 1 + ) + dfl_loss_val = ( + self.dist_loss( + pred_dist[fg_mask].view(-1, self.dist_loss.reg_max), + offset_targets[fg_mask], + ) + * score_weights + ) + dfl_loss_val = dfl_loss_val.sum() / total_score + else: + dfl_loss_val = torch.zeros(1, device=pred_dist.device) + + return iou_loss_val, dfl_loss_val + + +class CustomDFLoss(nn.Module): + def __init__(self, reg_max: int = 16): + """DFL loss that combines classification and regression losses. + + @type reg_max: int + @param reg_max: Maximum number of regression channels. Defaults + to 16. + """ + super().__init__() + self.reg_max = reg_max + + def __call__(self, pred_dist: Tensor, targets: Tensor) -> Tensor: + targets = targets.clamp(0, self.reg_max - 1 - 0.01) + left_target = targets.floor().long() + right_target = left_target + 1 + weight_left = right_target - targets + weight_right = 1.0 - weight_left + + left_val = F.cross_entropy( + pred_dist, left_target.view(-1), reduction="none" + ).view(left_target.shape) + right_val = F.cross_entropy( + pred_dist, right_target.view(-1), reduction="none" + ).view(left_target.shape) + + return (left_val * weight_left + right_val * weight_right).mean( + dim=-1, keepdim=True + ) diff --git a/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py b/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py new file mode 100644 index 00000000..8777cd24 --- /dev/null +++ b/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py @@ -0,0 +1,306 @@ +import logging +from typing import Any + +import torch +import torch.nn.functional as F +from torch import Tensor +from torchvision.ops import box_convert + +from luxonis_train.attached_modules.losses.precision_dfl_detection_loss import ( + PrecisionDFLDetectionLoss, +) +from luxonis_train.enums import TaskType +from luxonis_train.nodes import PrecisionSegmentBBoxHead +from luxonis_train.utils import ( + Labels, + Packet, + apply_bounding_box_to_masks, +) + +logger = logging.getLogger(__name__) + + +class PrecisionDFLSegmentationLoss(PrecisionDFLDetectionLoss): + node: PrecisionSegmentBBoxHead + supported_tasks: list[TaskType] = [ + TaskType.BOUNDINGBOX, + TaskType.SEGMENTATION, + ] + + def __init__( + self, + reg_max: int = 16, + tal_topk: int = 10, + class_loss_weight: float = 0.5, + bbox_loss_weight: float = 7.5, + dfl_loss_weight: float = 1.5, + overlap_mask: bool = True, + **kwargs: Any, + ): + """Instance Segmentation and BBox loss adapted from U{Real-Time Flying Object Detection with YOLOv8 + } + + @type reg_max: int + @param reg_max: Maximum number of regression channels. Defaults to 16. + @type tal_topk: int + @param tal_topk: Number of anchors considered in selection. Defaults to 10. + @type class_loss_weight: float + @param class_loss_weight: Weight for classification loss. Defaults to 0.5. For optimal results, multiply with accumulate_grad_batches. + @type bbox_loss_weight: float + @param bbox_loss_weight: Weight for bbox loss. Defaults to 7.5. For optimal results, multiply with accumulate_grad_batches. + @type dfl_loss_weight: float + @param dfl_loss_weight: Weight for DFL loss. Defaults to 1.5. For optimal results, multiply with accumulate_grad_batches. + """ + super().__init__( + reg_max=reg_max, + tal_topk=tal_topk, + class_loss_weight=class_loss_weight, + bbox_loss_weight=bbox_loss_weight, + dfl_loss_weight=dfl_loss_weight, + **kwargs, + ) + self.overlap = overlap_mask + + def prepare( + self, inputs: Packet[Tensor], labels: Labels + ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + det_feats = self.get_input_tensors(inputs, "features") + proto = self.get_input_tensors(inputs, "prototypes") + pred_mask = self.get_input_tensors(inputs, "mask_coeficients") + self._init_parameters(det_feats) + self.batch_size, _, mask_h, mask_w = proto.shape + pred_distri, pred_scores = torch.cat( + [xi.view(self.batch_size, self.node.no, -1) for xi in det_feats], 2 + ).split((self.node.reg_max * 4, self.n_classes), 1) + target_bbox = self.get_label(labels, TaskType.BOUNDINGBOX) + target_masks = self.get_label( + labels, TaskType.SEGMENTATION + ) # TODO: THIS SHOULD BE REFINED AFTER ANNOTATION REFACTOR IN LUXONIS_ML + if tuple(target_masks.shape[-2:]) != (mask_h, mask_w): + target_masks = F.interpolate( + target_masks, (mask_h, mask_w), mode="nearest" + )[ + 0 + ] # TODO: target_mask should be [1, N_masks, H, W] -> [N_masks, H, W]. Masks are ordered the same way as in target_bbox + + pred_distri = pred_distri.permute(0, 2, 1).contiguous() + pred_scores = pred_scores.permute(0, 2, 1).contiguous() + pred_mask = pred_mask.permute(0, 2, 1).contiguous() + + target_bbox = self._preprocess_bbox_target( + target_bbox, self.batch_size + ) + + pred_bboxes = self.decode_bbox(self.anchor_points_strided, pred_distri) + + gt_labels = target_bbox[:, :, :1] + gt_xyxy = target_bbox[:, :, 1:] + mask_gt = (gt_xyxy.sum(-1, keepdim=True) > 0).float() + + _, assigned_bboxes, assigned_scores, mask_positive, assigned_gt_idx = ( + self.assigner( + pred_scores.detach().sigmoid(), + (pred_bboxes.detach() * self.stride_tensor).type( + gt_xyxy.dtype + ), + self.anchor_points, + gt_labels, + gt_xyxy, + mask_gt, + ) + ) + + return ( + pred_distri, + pred_bboxes, + pred_scores, + assigned_bboxes, + assigned_scores, + mask_positive, + assigned_gt_idx, + pred_mask, + proto, + target_masks, + ) + + def forward( + self, + pred_distri: Tensor, + pred_bboxes: Tensor, + pred_scores: Tensor, + assigned_bboxes: Tensor, + assigned_scores: Tensor, + mask_positive: Tensor, + assigned_gt_idx: Tensor, + pred_masks: Tensor, + proto: Tensor, + target_masks: Tensor, + ): + max_assigned_scores_sum = max(assigned_scores.sum(), 1) + loss_cls = ( + self.bce(pred_scores, assigned_scores) + ).sum() / max_assigned_scores_sum + if mask_positive.sum(): + loss_iou, loss_dfl = self.bbox_loss( + pred_distri, + pred_bboxes, + self.anchor_points_strided, + assigned_bboxes / self.stride_tensor, + assigned_scores, + max_assigned_scores_sum, + mask_positive, + ) + else: + loss_iou = torch.tensor(0.0).to(pred_distri.device) + loss_dfl = torch.tensor(0.0).to(pred_distri.device) + + # TODO: after annotation refactor in luxonis-ml, this dummy batch_idx should be updated + batch_idx = torch.tensor([0], device=proto.device).unsqueeze( + -1 + ) # THAT IS WHAT YOLO uses + + loss_seg = self.calculate_segmentation_loss( + mask_positive, + target_masks, + assigned_gt_idx, + assigned_bboxes, + batch_idx, + proto, + pred_masks, + self.overlap, + ) + + loss = ( + self.class_loss_weight * loss_cls + + self.bbox_loss_weight * loss_iou + + self.dfl_loss_weight * loss_dfl + + loss_seg * self.bbox_loss_weight + ) + sub_losses = { + "class": loss_cls.detach(), + "iou": loss_iou.detach(), + "dfl": loss_dfl.detach(), + "seg": loss_seg.detach(), + } + + return loss * self.batch_size, sub_losses + + # TODO: Modify after adding corect annotation loading + def calculate_segmentation_loss( + self, + fg_mask: torch.Tensor, + masks: torch.Tensor, + target_gt_idx: torch.Tensor, + target_bboxes: torch.Tensor, + batch_idx: torch.Tensor, + proto: torch.Tensor, + pred_masks: torch.Tensor, + overlap: bool, + ) -> torch.Tensor: + """Calculate the loss for instance segmentation. + + Args: + fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive. + masks (torch.Tensor): Ground truth masks of shape (BS, H, W) if `overlap` is False, otherwise (BS, ?, H, W). + target_gt_idx (torch.Tensor): Indexes of ground truth objects for each anchor of shape (BS, N_anchors). + target_bboxes (torch.Tensor): Ground truth bounding boxes for each anchor of shape (BS, N_anchors, 4). + batch_idx (torch.Tensor): Batch indices of shape (N_labels_in_batch, 1). + proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W). + pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32). + imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W). + overlap (bool): Whether the masks in `masks` tensor overlap. + + Returns: + (torch.Tensor): The calculated loss for instance segmentation. + + Notes: + The batch loss can be computed for improved speed at higher memory usage. + For example, pred_mask can be computed as follows: + pred_mask = torch.einsum('in,nhw->ihw', pred, proto) # (i, 32) @ (32, 160, 160) -> (i, 160, 160) + """ + _, _, mask_h, mask_w = proto.shape + loss = 0 + + # Normalize to 0-1 + target_bboxes_normalized = target_bboxes / self.gt_bboxes_scale + + # Areas of target bboxes + marea = box_convert( + target_bboxes_normalized, in_fmt="xyxy", out_fmt="xywh" + )[..., 2:].prod(2) + + # Normalize to mask size + mxyxy = target_bboxes_normalized * torch.tensor( + [mask_w, mask_h, mask_w, mask_h], device=proto.device + ) + + for i, single_i in enumerate( + zip(fg_mask, target_gt_idx, pred_masks, proto, mxyxy, marea, masks) + ): + ( + fg_mask_i, + target_gt_idx_i, + pred_masks_i, + proto_i, + mxyxy_i, + marea_i, + masks_i, + ) = single_i + if fg_mask_i.any(): + mask_idx = target_gt_idx_i[fg_mask_i] + if overlap: + gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1) + gt_mask = gt_mask.float() + else: + gt_mask = masks[batch_idx.view(-1) == i][mask_idx] + + loss += self.single_mask_loss( + gt_mask, + pred_masks_i[fg_mask_i], + proto_i, + mxyxy_i[fg_mask_i], + marea_i[fg_mask_i], + ) + + # WARNING: lines below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove + else: + loss += (proto * 0).sum() + ( + pred_masks * 0 + ).sum() # inf sums may lead to nan loss + + return loss / fg_mask.sum() + + # TODO: Modify after adding corect annotation loading + @staticmethod + def single_mask_loss( + gt_mask: torch.Tensor, + pred: torch.Tensor, + proto: torch.Tensor, + xyxy: torch.Tensor, + area: torch.Tensor, + ) -> torch.Tensor: + """Compute the instance segmentation loss for a single image. + + Args: + gt_mask (torch.Tensor): Ground truth mask of shape (n, H, W), where n is the number of objects. + pred (torch.Tensor): Predicted mask coefficients of shape (n, 32). + proto (torch.Tensor): Prototype masks of shape (32, H, W). + xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (n, 4). + area (torch.Tensor): Area of each ground truth bounding box of shape (n,). + + Returns: + (torch.Tensor): The calculated mask loss for a single image. + + Notes: + The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the + predicted masks from the prototype masks and predicted mask coefficients. + """ + pred_mask = torch.einsum( + "in,nhw->ihw", pred, proto + ) # (n, 32) @ (32, 80, 80) -> (n, 80, 80) + loss = F.binary_cross_entropy_with_logits( + pred_mask, gt_mask, reduction="none" + ) + return ( + apply_bounding_box_to_masks(loss, xyxy).mean(dim=(1, 2)) / area + ).sum() diff --git a/luxonis_train/nodes/blocks/__init__.py b/luxonis_train/nodes/blocks/__init__.py index ce0181c9..71228fbd 100644 --- a/luxonis_train/nodes/blocks/__init__.py +++ b/luxonis_train/nodes/blocks/__init__.py @@ -1,4 +1,5 @@ from .blocks import ( + DFL, AttentionRefinmentBlock, BasicResNetBlock, BlockRepeater, @@ -6,9 +7,11 @@ ConvModule, CSPStackRepBlock, DropPath, + DWConvModule, EfficientDecoupledBlock, FeatureFusionBlock, RepVGGBlock, + SegProto, SpatialPyramidPoolingBlock, SqueezeExciteBlock, UpBlock, @@ -32,4 +35,7 @@ "Bottleneck", "UpscaleOnline", "DropPath", + "SegProto", + "DWConvModule", + "DFL", ] diff --git a/luxonis_train/nodes/blocks/blocks.py b/luxonis_train/nodes/blocks/blocks.py index 25bea7c5..29a2fa9b 100644 --- a/luxonis_train/nodes/blocks/blocks.py +++ b/luxonis_train/nodes/blocks/blocks.py @@ -81,6 +81,90 @@ def _initialize_weights_and_biases(self, prior_prob: float) -> None: module.weight = nn.Parameter(w, requires_grad=True) +class SegProto(nn.Module): + def __init__(self, in_ch, mid_ch=256, out_ch=32): + """Initializes the segmentation prototype generator. + + @type in_ch: int + @param in_ch: Number of input channels. + @type mid_ch: int + @param mid_ch: Number of intermediate channels. Defaults to 256. + @type out_ch: int + @param out_ch: Number of output channels. Defaults to 32. + """ + super().__init__() + self.conv1 = ConvModule( + in_channels=in_ch, + out_channels=mid_ch, + kernel_size=3, + stride=1, + padding=1, + activation=nn.SiLU(), + ) + self.upsample = nn.ConvTranspose2d( + in_channels=mid_ch, + out_channels=mid_ch, + kernel_size=2, + stride=2, + bias=True, + ) + self.conv2 = ConvModule( + in_channels=mid_ch, + out_channels=mid_ch, + kernel_size=3, + stride=1, + padding=1, + activation=nn.SiLU(), + ) + self.conv3 = ConvModule( + in_channels=mid_ch, + out_channels=out_ch, + kernel_size=1, + stride=1, + padding=0, + activation=nn.SiLU(), + ) + + def forward(self, x): + """Defines the forward pass of the segmentation prototype + generator. + + @type x: torch.Tensor + @param x: Input tensor. + @rtype: torch.Tensor + @return: Processed tensor. + """ + return self.conv3(self.conv2(self.upsample(self.conv1(x)))) + + +class DFL(nn.Module): + def __init__(self, channels: int = 16): + """ + Constructs the module with a convolutional layer using the specified input channels. + Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391 + + @type channels: int + @param channels: Number of input channels. Defaults to 16. + + """ + super().__init__() + self.transform = nn.Conv2d( + channels, 1, kernel_size=1, bias=False + ).requires_grad_(False) + weights = torch.arange(channels, dtype=torch.float32) + self.transform.weight.data.copy_(weights.view(1, channels, 1, 1)) + self.num_channels = channels + + def forward(self, input: Tensor): + """Transforms the input tensor and returns the processed + output.""" + batch_size, _, anchors = input.size() + reshaped = input.view(batch_size, 4, self.num_channels, anchors) + softmaxed = reshaped.transpose(2, 1).softmax(dim=1) + processed = self.transform(softmaxed) + return processed.view(batch_size, 4, anchors) + + class ConvModule(nn.Sequential): def __init__( self, @@ -131,6 +215,51 @@ def __init__( ) +class DWConvModule(ConvModule): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + bias: bool = False, + activation: nn.Module | None = None, + ): + """Depth-wise Conv2d + BN + Activation. + + @type in_channels: int + @param in_channels: Number of input channels. + @type out_channels: int + @param out_channels: Number of output channels. + @type kernel_size: int + @param kernel_size: Kernel size. + @type stride: int + @param stride: Stride. Defaults to 1. + @type padding: int + @param padding: Padding. Defaults to 0. + @type dilation: int + @param dilation: Dilation. Defaults to 1. + @type bias: bool + @param bias: Whether to use bias. Defaults to False. + @type activation: L{nn.Module} | None + @param activation: Activation function. If None then nn.Relu. + """ + + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, # Depth-wise convolution + bias=bias, + activation=activation, + ) + + class UpBlock(nn.Sequential): def __init__( self, diff --git a/luxonis_train/nodes/heads/__init__.py b/luxonis_train/nodes/heads/__init__.py index fa4d9b9f..48bdf1e3 100644 --- a/luxonis_train/nodes/heads/__init__.py +++ b/luxonis_train/nodes/heads/__init__.py @@ -5,6 +5,8 @@ from .efficient_bbox_head import EfficientBBoxHead from .efficient_keypoint_bbox_head import EfficientKeypointBBoxHead from .fomo_head import FOMOHead +from .precision_bbox_head import PrecisionBBoxHead +from .precision_seg_bbox_head import PrecisionSegmentBBoxHead from .segmentation_head import SegmentationHead __all__ = [ @@ -16,4 +18,6 @@ "DDRNetSegmentationHead", "DiscSubNetHead", "FOMOHead", + "PrecisionBBoxHead", + "PrecisionSegmentBBoxHead", ] diff --git a/luxonis_train/nodes/heads/precision_bbox_head.py b/luxonis_train/nodes/heads/precision_bbox_head.py new file mode 100644 index 00000000..bfc5f72d --- /dev/null +++ b/luxonis_train/nodes/heads/precision_bbox_head.py @@ -0,0 +1,234 @@ +import logging +import math +from typing import Any, Literal + +import torch +from torch import Tensor, nn + +from luxonis_train.enums import TaskType +from luxonis_train.nodes import BaseNode +from luxonis_train.nodes.blocks import DFL, ConvModule, DWConvModule +from luxonis_train.utils import ( + Packet, + anchors_for_fpn_features, + dist2bbox, + non_max_suppression, +) + +logger = logging.getLogger(__name__) + + +class PrecisionBBoxHead(BaseNode[list[Tensor], list[Tensor]]): + in_channels: list[int] + tasks: list[TaskType] = [TaskType.BOUNDINGBOX] + + def __init__( + self, + reg_max: int = 16, + n_heads: Literal[2, 3, 4] = 3, + conf_thres: float = 0.25, + iou_thres: float = 0.45, + max_det: int = 300, + **kwargs: Any, + ): + """ + Adapted from U{Real-Time Flying Object Detection with YOLOv8 + } + + @type ch: tuple[int] + @param ch: Channels for each detection layer. + @type reg_max: int + @param reg_max: Maximum number of regression channels. + @type n_heads: Literal[2, 3, 4] + @param n_heads: Number of output heads. + @type conf_thres: float + @param conf_thres: Confidence threshold for NMS. + @type iou_thres: float + @param iou_thres: IoU threshold for NMS. + """ + super().__init__(**kwargs) + self.reg_max = reg_max + self.no = self.n_classes + reg_max * 4 + self.n_heads = n_heads + self.conf_thres = conf_thres + self.iou_thres = iou_thres + self.grid_cell_offset = 0.5 + self.grid_cell_size = 5.0 + self.max_det = max_det + + reg_channels = max((16, self.in_channels[0] // 4, reg_max * 4)) + cls_channels = max(self.in_channels[0], min(self.n_classes, 100)) + + self.detection_heads = nn.ModuleList( + nn.Sequential( + # Regression branch + nn.Sequential( + ConvModule( + x, + reg_channels, + kernel_size=3, + padding=1, + activation=nn.SiLU(), + ), + ConvModule( + reg_channels, + reg_channels, + kernel_size=3, + padding=1, + activation=nn.SiLU(), + ), + nn.Conv2d(reg_channels, 4 * self.reg_max, kernel_size=1), + ), + # Classification branch + nn.Sequential( + nn.Sequential( + DWConvModule( + x, + x, + kernel_size=3, + padding=1, + activation=nn.SiLU(), + ), + ConvModule( + x, + cls_channels, + kernel_size=1, + activation=nn.SiLU(), + ), + ), + nn.Sequential( + DWConvModule( + cls_channels, + cls_channels, + kernel_size=3, + padding=1, + activation=nn.SiLU(), + ), + ConvModule( + cls_channels, + cls_channels, + kernel_size=1, + activation=nn.SiLU(), + ), + ), + nn.Conv2d(cls_channels, self.n_classes, kernel_size=1), + ), + ) + for x in self.in_channels + ) + + self.stride = self._fit_stride_to_n_heads() + self.dfl = DFL(reg_max) if reg_max > 1 else nn.Identity() + self.bias_init() + self.initialize_weights() + + def forward(self, x: list[Tensor]) -> list[Tensor]: + for i in range(self.n_heads): + reg_output = self.detection_heads[i][0](x[i]) + cls_output = self.detection_heads[i][1](x[i]) + x[i] = torch.cat((reg_output, cls_output), 1) + return x + + def wrap(self, output: list[Tensor]) -> Packet[Tensor]: + if self.training: + return { + "features": output, + } + y = self._inference(output) + if self.export: + return {self.task: y} + boxes = non_max_suppression( + y, + n_classes=self.n_classes, + conf_thres=self.conf_thres, + iou_thres=self.iou_thres, + bbox_format="xyxy", + max_det=self.max_det, + predicts_objectness=False, + ) + + return { + "features": output, + "boundingbox": boxes, + } + + def _fit_stride_to_n_heads(self): + """Returns correct stride for number of heads and attach + index.""" + stride = torch.tensor( + [ + self.original_in_shape[1] / x[2] # type: ignore + for x in self.in_sizes[: self.n_heads] + ], + dtype=torch.int, + ) + return stride + + def _inference(self, x: list[Tensor], masks: Tensor | None = None): + """Decode predicted bounding boxes and class probabilities based + on multiple-level feature maps.""" + shape = x[0].shape + x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2) + _, self.anchor_points, _, self.strides = anchors_for_fpn_features( + x, self.stride, 0.5 + ) + box, cls = x_cat.split((self.reg_max * 4, self.n_classes), 1) + pred_bboxes = self.decode_bboxes( + self.dfl(box), self.anchor_points.transpose(0, 1) + ) * self.strides.transpose(0, 1) + + if self.export: + return torch.cat( + (pred_bboxes.permute(0, 2, 1), cls.sigmoid().permute(0, 2, 1)), + 1, + ) + + base_output = [ + pred_bboxes.permute(0, 2, 1), + torch.ones( + (shape[0], pred_bboxes.shape[2], 1), + dtype=pred_bboxes.dtype, + device=pred_bboxes.device, + ), + cls.permute(0, 2, 1), + ] + + if masks is not None: + base_output.append(masks.permute(0, 2, 1)) + + output_merged = torch.cat(base_output, dim=-1) + return output_merged + + def decode_bboxes(self, bboxes: Tensor, anchors: Tensor) -> Tensor: + """Decode bounding boxes.""" + return dist2bbox(bboxes, anchors, out_format="xyxy", dim=1) + + def bias_init(self): + """Initialize biases for the detection heads. + + Assumes detection_heads structure with separate regression and + classification branches. + """ + for head, stride in zip(self.detection_heads, self.stride): + reg_branch = head[0] + cls_branch = head[1] + + reg_conv = reg_branch[-1] + reg_conv.bias.data[:] = 1.0 + + cls_conv = cls_branch[-1] + cls_conv.bias.data[: self.n_classes] = math.log( + 5 / self.n_classes / (self.original_in_shape[1] / stride) ** 2 + ) + + def initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + pass + elif isinstance(m, nn.BatchNorm2d): + m.eps = 0.001 + m.momentum = 0.03 + elif isinstance( + m, (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU) + ): + m.inplace = True diff --git a/luxonis_train/nodes/heads/precision_seg_bbox_head.py b/luxonis_train/nodes/heads/precision_seg_bbox_head.py new file mode 100644 index 00000000..5cfc3e60 --- /dev/null +++ b/luxonis_train/nodes/heads/precision_seg_bbox_head.py @@ -0,0 +1,188 @@ +from typing import Any, Literal + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from luxonis_train.enums import TaskType +from luxonis_train.nodes.blocks import ConvModule, SegProto +from luxonis_train.utils import ( + Packet, + apply_bounding_box_to_masks, + non_max_suppression, +) + +from .precision_bbox_head import PrecisionBBoxHead + + +class PrecisionSegmentBBoxHead(PrecisionBBoxHead): + tasks: list[TaskType] = [TaskType.SEGMENTATION, TaskType.BOUNDINGBOX] + + def __init__( + self, + n_heads: Literal[2, 3, 4] = 3, + n_masks: int = 32, + n_proto: int = 256, + conf_thres: float = 0.25, + iou_thres: float = 0.45, + max_det: int = 300, + **kwargs: Any, + ): + """ + Head for instance segmentation and object detection. + Adapted from U{Real-Time Flying Object Detection with YOLOv8 + } + + @type n_heads: Literal[2, 3, 4] + @param n_heads: Number of output heads. Defaults to 3. + @type n_masks: int + @param n_masks: Number of masks. + @type n_proto: int + @param n_proto: Number of prototypes for segmentation. + @type conf_thres: flaot + @param conf_thres: Confidence threshold for NMS. + @type iou_thres: float + @param iou_thres: IoU threshold for NMS. + @type max_det: int + @param max_det: Maximum number of detections retained after NMS. + """ + super().__init__( + n_heads=n_heads, + conf_thres=conf_thres, + iou_thres=iou_thres, + max_det=max_det, + **kwargs, + ) + + self.n_masks = n_masks + self.n_proto = n_proto + + self.proto = SegProto(self.in_channels[0], self.n_proto, self.n_masks) + + mid_ch = max(self.in_channels[0] // 4, self.n_masks) + self.mask_layers = nn.ModuleList( + nn.Sequential( + ConvModule(x, mid_ch, 3, 1, 1, activation=nn.SiLU()), + ConvModule(mid_ch, mid_ch, 3, 1, 1, activation=nn.SiLU()), + nn.Conv2d(mid_ch, self.n_masks, 1, 1), + ) + for x in self.in_channels + ) + + self._export_output_names = None + + def forward( + self, inputs: list[Tensor] + ) -> tuple[list[Tensor], list[Tensor], list[Tensor], list[Tensor]]: + prototypes = self.proto(inputs[0]) + bs = prototypes.shape[0] + mask_coefficients = torch.cat( + [ + self.mask_layers[i](inputs[i]).view(bs, self.n_masks, -1) + for i in range(self.n_heads) + ], + dim=2, + ) + det_outs = super().forward(inputs) + + return det_outs, prototypes, mask_coefficients + + def wrap( + self, output: tuple[list[Tensor], Tensor, Tensor] + ) -> Packet[Tensor]: + det_feats, prototypes, mask_coefficients = output + if self.training: + return { + "features": det_feats, + "prototypes": prototypes, + "mask_coeficients": mask_coefficients, + } + if self.export: + { + self.task: ( + torch.cat([det_feats, mask_coefficients], 1), + prototypes, + ) + } + pred_bboxes = self._inference(det_feats, mask_coefficients) + preds = non_max_suppression( + pred_bboxes, + n_classes=self.n_classes, + conf_thres=self.conf_thres, + iou_thres=self.iou_thres, + bbox_format="xyxy", + max_det=self.max_det, + predicts_objectness=False, + ) + + results = { + "features": det_feats, + "prototypes": prototypes, + "mask_coeficients": mask_coefficients, + "boundingbox": [], + "segmentation": [], # TODO: Sync on how we want to visualize this + } + + for i, pred in enumerate(preds): + results["segmentation"].append( + refine_and_apply_masks( + prototypes[i], + pred[:, 6:], + pred[:, :4], + self.original_in_shape[-2:], + upsample=True, + ) + ) + results["boundingbox"].append(pred[:, :6]) + + return results + + +def refine_and_apply_masks( + mask_prototypes, + predicted_masks, + bounding_boxes, + target_shape, + upsample=False, +): + """Refine and apply masks to bounding boxes based on the mask head + outputs. + + @type mask_prototypes: torch.Tensor + @param mask_prototypes: Tensor of shape [mask_dim, mask_height, + mask_width]. + @type predicted_masks: torch.Tensor + @param predicted_masks: Tensor of shape [num_masks, mask_dim], where + num_masks is the number of detected masks. + @type bounding_boxes: torch.Tensor + @param bounding_boxes: Tensor of shape [num_masks, 4], containing + bounding box coordinates. + @type target_shape: tuple + @param target_shape: Tuple (height, width) representing the + dimensions of the original image. + @type upsample: bool + @param upsample: If True, upsample the masks to the target image + dimensions. Default is False. + @rtype: torch.Tensor + @return: A binary mask tensor of shape [num_masks, height, width], + where the masks are cropped according to their respective + bounding boxes. + """ + channels, proto_h, proto_w = mask_prototypes.shape + img_h, img_w = target_shape + masks_combined = ( + predicted_masks @ mask_prototypes.float().view(channels, -1) + ).view(-1, proto_h, proto_w) + w_scale, h_scale = proto_w / img_w, proto_h / img_h + scaled_boxes = bounding_boxes.clone() + scaled_boxes[:, [0, 2]] *= w_scale + scaled_boxes[:, [1, 3]] *= h_scale + cropped_masks = apply_bounding_box_to_masks(masks_combined, scaled_boxes) + if upsample: + cropped_masks = F.interpolate( + cropped_masks.unsqueeze(0), + size=target_shape, + mode="bilinear", + align_corners=False, + ).squeeze(0) + return (cropped_masks > 0).to(cropped_masks.dtype) diff --git a/luxonis_train/utils/__init__.py b/luxonis_train/utils/__init__.py index 2944dfde..132da4dc 100644 --- a/luxonis_train/utils/__init__.py +++ b/luxonis_train/utils/__init__.py @@ -1,5 +1,6 @@ from .boundingbox import ( anchors_for_fpn_features, + apply_bounding_box_to_masks, bbox2dist, bbox_iou, compute_iou_loss, @@ -41,4 +42,5 @@ "compute_iou_loss", "get_sigmas", "traverse_graph", + "apply_bounding_box_to_masks", ] diff --git a/luxonis_train/utils/boundingbox.py b/luxonis_train/utils/boundingbox.py index e72360c3..ff2af2cf 100644 --- a/luxonis_train/utils/boundingbox.py +++ b/luxonis_train/utils/boundingbox.py @@ -19,6 +19,7 @@ def dist2bbox( distance: Tensor, anchor_points: Tensor, out_format: BBoxFormatType = "xyxy", + dim: int = -1, ) -> Tensor: """Transform distance (ltrb) to box ("xyxy", "xywh" or "cxcywh"). @@ -29,12 +30,14 @@ def dist2bbox( @type out_format: BBoxFormatType @param out_format: BBox output format. Defaults to "xyxy". @rtype: Tensor + @param dim: Dimension to split distance tensor. Defaults to -1. + @rtype: Tensor @return: BBoxes in correct format """ - lt, rb = torch.split(distance, 2, -1) + lt, rb = torch.split(distance, 2, dim=dim) x1y1 = anchor_points - lt x2y2 = anchor_points + rb - bbox = torch.cat([x1y1, x2y2], -1) + bbox = torch.cat([x1y1, x2y2], dim=dim) if out_format in ["xyxy", "xywh", "cxcywh"]: bbox = box_convert(bbox, in_fmt="xyxy", out_fmt=out_format) else: @@ -401,6 +404,39 @@ def anchors_for_fpn_features( ) +def apply_bounding_box_to_masks( + masks: Tensor, bounding_boxes: Tensor +) -> Tensor: + """Crops the given masks to the regions specified by the + corresponding bounding boxes. + + @type masks: Tensor + @param masks: Masks tensor of shape [n, h, w]. + @type bounding_boxes: Tensor + @param bounding_boxes: Bounding boxes tensor of shape [n, 4]. + @rtype: Tensor + @return: Cropped masks tensor of shape [n, h, w]. + """ + _, mask_height, mask_width = masks.shape + left, top, right, bottom = torch.split( + bounding_boxes[:, :, None], 1, dim=1 + ) + width_indices = torch.arange( + mask_width, device=masks.device, dtype=left.dtype + )[None, None, :] + height_indices = torch.arange( + mask_height, device=masks.device, dtype=left.dtype + )[None, :, None] + + cropped_masks = masks * ( + (width_indices >= left) + & (width_indices < right) + & (height_indices >= top) + & (height_indices < bottom) + ) + return cropped_masks + + def compute_iou_loss( pred_bboxes: Tensor, target_bboxes: Tensor, diff --git a/tests/integration/test_detection.py b/tests/integration/test_detection.py index 45e83f0a..8b527ead 100644 --- a/tests/integration/test_detection.py +++ b/tests/integration/test_detection.py @@ -26,6 +26,10 @@ def get_opts_backbone(backbone: str) -> dict[str, Any]: }, "inputs": [backbone], }, + { + "name": "PrecisionBBoxHead", + "inputs": [backbone], + }, ], "losses": [ { @@ -37,6 +41,10 @@ def get_opts_backbone(backbone: str) -> dict[str, Any]: "attached_to": "EfficientKeypointBBoxHead", "params": {"area_factor": 0.5}, }, + { + "name": "PrecisionDFLDetectionLoss", + "attached_to": "PrecisionBBoxHead", + }, ], "metrics": [ { @@ -48,6 +56,10 @@ def get_opts_backbone(backbone: str) -> dict[str, Any]: "alias": "EfficientKeypointBBoxHead-MaP", "attached_to": "EfficientKeypointBBoxHead", }, + { + "name": "MeanAveragePrecision", + "attached_to": "PrecisionBBoxHead", + }, ], } } @@ -72,18 +84,30 @@ def get_opts_variant(variant: str) -> dict[str, Any]: "name": "EfficientBBoxHead", "inputs": ["neck"], }, + { + "name": "PrecisionBBoxHead", + "inputs": ["neck"], + }, ], "losses": [ { "name": "AdaptiveDetectionLoss", "attached_to": "EfficientBBoxHead", }, + { + "name": "PrecisionDFLDetectionLoss", + "attached_to": "PrecisionBBoxHead", + }, ], "metrics": [ { "name": "MeanAveragePrecision", "attached_to": "EfficientBBoxHead", }, + { + "name": "MeanAveragePrecision", + "attached_to": "PrecisionBBoxHead", + }, ], } } @@ -111,6 +135,7 @@ def test_backbones( ): opts = get_opts_backbone(backbone) opts["loader.params.dataset_name"] = parking_lot_dataset.identifier + opts["trainer.epochs"] = 1 train_and_test(config, opts) @@ -122,4 +147,5 @@ def test_variants( ): opts = get_opts_variant(variant) opts["loader.params.dataset_name"] = parking_lot_dataset.identifier + opts["trainer.epochs"] = 1 train_and_test(config, opts) From 6b7710ebebc93bf799f25f9f70d8f714efffc37c Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Tue, 10 Dec 2024 06:40:16 +0100 Subject: [PATCH 02/31] feat: new loader, new visualizer --- .../losses/precision_dlf_segmentation_loss.py | 26 +- .../attached_modules/visualizers/__init__.py | 2 + .../instance_segmentation_visualizer.py | 241 ++++++++++++++++++ .../attached_modules/visualizers/utils.py | 1 + luxonis_train/core/core.py | 42 ++- luxonis_train/enums.py | 1 + luxonis_train/loaders/luxonis_loader_torch.py | 76 ++++-- luxonis_train/loaders/utils.py | 4 + .../nodes/heads/precision_seg_bbox_head.py | 17 +- luxonis_train/utils/dataset_metadata.py | 17 +- 10 files changed, 351 insertions(+), 76 deletions(-) create mode 100644 luxonis_train/attached_modules/visualizers/instance_segmentation_visualizer.py diff --git a/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py b/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py index 8777cd24..7303d4ca 100644 --- a/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py +++ b/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py @@ -24,7 +24,7 @@ class PrecisionDFLSegmentationLoss(PrecisionDFLDetectionLoss): node: PrecisionSegmentBBoxHead supported_tasks: list[TaskType] = [ TaskType.BOUNDINGBOX, - TaskType.SEGMENTATION, + TaskType.INSTANCE_SEGMENTATION, ] def __init__( @@ -73,15 +73,12 @@ def prepare( [xi.view(self.batch_size, self.node.no, -1) for xi in det_feats], 2 ).split((self.node.reg_max * 4, self.n_classes), 1) target_bbox = self.get_label(labels, TaskType.BOUNDINGBOX) - target_masks = self.get_label( - labels, TaskType.SEGMENTATION - ) # TODO: THIS SHOULD BE REFINED AFTER ANNOTATION REFACTOR IN LUXONIS_ML + img_idx = target_bbox[:, 0] + target_masks = self.get_label(labels, TaskType.INSTANCE_SEGMENTATION) if tuple(target_masks.shape[-2:]) != (mask_h, mask_w): target_masks = F.interpolate( - target_masks, (mask_h, mask_w), mode="nearest" - )[ - 0 - ] # TODO: target_mask should be [1, N_masks, H, W] -> [N_masks, H, W]. Masks are ordered the same way as in target_bbox + target_masks.unsqueeze(0), (mask_h, mask_w), mode="nearest" + ).squeeze(0) pred_distri = pred_distri.permute(0, 2, 1).contiguous() pred_scores = pred_scores.permute(0, 2, 1).contiguous() @@ -121,6 +118,7 @@ def prepare( pred_mask, proto, target_masks, + img_idx, ) def forward( @@ -135,6 +133,7 @@ def forward( pred_masks: Tensor, proto: Tensor, target_masks: Tensor, + img_idx: Tensor, ): max_assigned_scores_sum = max(assigned_scores.sum(), 1) loss_cls = ( @@ -154,17 +153,12 @@ def forward( loss_iou = torch.tensor(0.0).to(pred_distri.device) loss_dfl = torch.tensor(0.0).to(pred_distri.device) - # TODO: after annotation refactor in luxonis-ml, this dummy batch_idx should be updated - batch_idx = torch.tensor([0], device=proto.device).unsqueeze( - -1 - ) # THAT IS WHAT YOLO uses - loss_seg = self.calculate_segmentation_loss( mask_positive, target_masks, assigned_gt_idx, assigned_bboxes, - batch_idx, + img_idx, proto, pred_masks, self.overlap, @@ -174,7 +168,7 @@ def forward( self.class_loss_weight * loss_cls + self.bbox_loss_weight * loss_iou + self.dfl_loss_weight * loss_dfl - + loss_seg * self.bbox_loss_weight + + self.bbox_loss_weight * loss_seg ) sub_losses = { "class": loss_cls.detach(), @@ -183,7 +177,7 @@ def forward( "seg": loss_seg.detach(), } - return loss * self.batch_size, sub_losses + return loss, sub_losses # TODO: Modify after adding corect annotation loading def calculate_segmentation_loss( diff --git a/luxonis_train/attached_modules/visualizers/__init__.py b/luxonis_train/attached_modules/visualizers/__init__.py index 50b90471..1bd65f50 100644 --- a/luxonis_train/attached_modules/visualizers/__init__.py +++ b/luxonis_train/attached_modules/visualizers/__init__.py @@ -1,6 +1,7 @@ from .base_visualizer import BaseVisualizer from .bbox_visualizer import BBoxVisualizer from .classification_visualizer import ClassificationVisualizer +from .instance_segmentation_visualizer import InstanceSegmentationVisualizer from .keypoint_visualizer import KeypointVisualizer from .multi_visualizer import MultiVisualizer from .segmentation_visualizer import SegmentationVisualizer @@ -23,6 +24,7 @@ "KeypointVisualizer", "MultiVisualizer", "SegmentationVisualizer", + "InstanceSegmentationVisualizer", "combine_visualizations", "draw_bounding_box_labels", "draw_keypoint_labels", diff --git a/luxonis_train/attached_modules/visualizers/instance_segmentation_visualizer.py b/luxonis_train/attached_modules/visualizers/instance_segmentation_visualizer.py new file mode 100644 index 00000000..63f8aa37 --- /dev/null +++ b/luxonis_train/attached_modules/visualizers/instance_segmentation_visualizer.py @@ -0,0 +1,241 @@ +import logging + +import torch +from torch import Tensor + +from luxonis_train.enums import TaskType +from luxonis_train.utils import Labels, Packet + +from .base_visualizer import BaseVisualizer +from .utils import ( + Color, + draw_bounding_box_labels, + draw_bounding_boxes, + draw_segmentation_labels, + get_color, +) + +logger = logging.getLogger(__name__) + + +class InstanceSegmentationVisualizer(BaseVisualizer[Tensor, Tensor]): + """Visualizer for instance segmentation tasks, supporting the + visualization of predicted and ground truth bounding boxes and + instance masks.""" + + supported_tasks: list[TaskType] = [ + TaskType.INSTANCE_SEGMENTATION, + TaskType.BOUNDINGBOX, + ] + + def __init__( + self, + labels: dict[int, str] | list[str] | None = None, + draw_labels: bool = True, + colors: dict[str, Color] | list[Color] | None = None, + fill: bool = False, + width: int | None = None, + font: str | None = None, + font_size: int | None = None, + alpha: float = 0.6, + **kwargs, + ): + """Initialize the visualizer with customization options for + appearance. + + Parameters: + - labels: A dictionary or list mapping class indices to labels. Defaults to None. + - draw_labels: Whether to draw labels on bounding boxes. Defaults to True. + - colors: Colors for each class. Can be a dictionary or list. Defaults to None. + - fill: Whether to fill bounding boxes. Defaults to False. + - width: Line width for bounding boxes. Defaults to None (adaptive). + - font: Font to use for labels. Defaults to None. + - font_size: Font size for labels. Defaults to None. + - alpha: Transparency for instance masks. Defaults to 0.6. + """ + super().__init__(**kwargs) + + if isinstance(labels, list): + labels = {i: label for i, label in enumerate(labels)} + + self.bbox_labels = labels or { + i: label for i, label in enumerate(self.class_names) + } + + if colors is None: + colors = { + label: get_color(i) for i, label in self.bbox_labels.items() + } + if isinstance(colors, list): + colors = { + self.bbox_labels[i]: color for i, color in enumerate(colors) + } + + self.colors = colors + self.fill = fill + self.width = width + self.font = font + self.font_size = font_size + self.draw_labels = draw_labels + self.alpha = alpha + + def prepare( + self, inputs: Packet[Tensor], labels: Labels | None + ) -> tuple[Tensor, Tensor, list[Tensor], Tensor | None, Tensor | None]: + """ + TODO: Docstring + """ + target_bboxes = labels["boundingbox"][0] + target_masks = labels["instance_segmentation"][0] + predicted_bboxes = inputs["boundingbox"] + predicted_masks = inputs["instance_segmentation"] + + return target_bboxes, target_masks, predicted_bboxes, predicted_masks + + def draw_predictions( + self, + canvas: Tensor, + pred_bboxes: list[Tensor], + pred_masks: list[Tensor], + width: int | None, + label_dict: dict[int, str], + color_dict: dict[str, Color], + draw_labels: bool, + alpha: float, + ) -> Tensor: + """Draw predicted bounding boxes and masks on the canvas.""" + viz = torch.zeros_like(canvas) + + for i in range(len(canvas)): + viz[i] = canvas[i].clone() + prediction = pred_bboxes[i] + masks = pred_masks[i] + prediction_classes = prediction[..., 5].int() + + cls_labels = ( + [label_dict[int(c)] for c in prediction_classes] + if draw_labels and label_dict is not None + else None + ) + cls_colors = ( + [color_dict[label_dict[int(c)]] for c in prediction_classes] + if color_dict is not None and label_dict is not None + else None + ) + + *_, H, W = canvas.shape + width = width or max(1, int(min(H, W) / 100)) + + try: + for j, mask in enumerate(masks): + print(f"mask.sum(): {mask.sum()}") + viz[i] = draw_segmentation_labels( + viz[i], + mask.unsqueeze(0), + colors=[cls_colors[j]], + alpha=alpha, + ).to(canvas.device) + + viz[i] = draw_bounding_boxes( + viz[i], + prediction[:, :4], + width=width, + labels=cls_labels, + colors=cls_colors, + ).to(canvas.device) + except ValueError as e: + logger.warning( + f"Failed to draw bounding boxes or masks: {e}. Skipping visualization." + ) + viz[i] = canvas[i] + + return viz + + @staticmethod + def draw_targets( + canvas: Tensor, + target_bboxes: Tensor, + target_masks: Tensor, + width: int | None, + label_dict: dict[int, str], + color_dict: dict[str, Color], + draw_labels: bool, + alpha: float, + ) -> Tensor: + """Draw ground truth bounding boxes and masks on the canvas.""" + viz = torch.zeros_like(canvas) + + for i in range(len(canvas)): + viz[i] = canvas[i].clone() + image_targets = target_bboxes[target_bboxes[:, 0] == i] + image_masks = target_masks[target_bboxes[:, 0] == i] + target_classes = image_targets[:, 1].int() + + cls_labels = ( + [label_dict[int(c)] for c in target_classes] + if draw_labels and label_dict is not None + else None + ) + cls_colors = ( + [color_dict[label_dict[int(c)]] for c in target_classes] + if color_dict is not None and label_dict is not None + else None + ) + + *_, H, W = canvas.shape + width = width or max(1, int(min(H, W) / 100)) + + for j, (bbox, mask) in enumerate( + zip(image_targets[:, 2:], image_masks) + ): + print(f"sum(mask): {mask.sum()}") + viz[i] = draw_segmentation_labels( + viz[i], + mask.unsqueeze(0), + alpha=alpha, + colors=[cls_colors[j]], + ).to(canvas.device) + viz[i] = draw_bounding_box_labels( + viz[i], + bbox.unsqueeze(0), + width=width, + labels=[cls_labels[j]] if cls_labels else None, + colors=[cls_colors[j]], + ).to(canvas.device) + + return viz + + def forward( + self, + label_canvas: Tensor, + prediction_canvas: Tensor, + target_bboxes: Tensor | None, + target_masks: Tensor | None, + predicted_bboxes: Tensor, + predicted_masks: Tensor, + ) -> tuple[Tensor, Tensor] | Tensor: + """Visualize predictions and ground truth.""" + predictions_viz = self.draw_predictions( + prediction_canvas, + predicted_bboxes, + predicted_masks, + self.width, + self.bbox_labels, + self.colors, + self.draw_labels, + self.alpha, + ) + if target_bboxes is None or target_masks is None: + return predictions_viz + + targets_viz = self.draw_targets( + label_canvas, + target_bboxes, + target_masks, + self.width, + self.bbox_labels, + self.colors, + self.draw_labels, + self.alpha, + ) + return targets_viz, predictions_viz diff --git a/luxonis_train/attached_modules/visualizers/utils.py b/luxonis_train/attached_modules/visualizers/utils.py index 45ec454b..d6d710c6 100644 --- a/luxonis_train/attached_modules/visualizers/utils.py +++ b/luxonis_train/attached_modules/visualizers/utils.py @@ -118,6 +118,7 @@ def draw_segmentation_labels( @rtype: Tensor @return: Image with segmentation labels drawn on. """ + print(f"sum(label): {label.sum()}") masks = label.bool() masks = masks.cpu() img = img.cpu() diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index 2a3f3678..86ee4590 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -13,7 +13,6 @@ import torch.utils.data as torch_data import yaml from lightning.pytorch.utilities import rank_zero_only -from luxonis_ml.data import Augmentations from luxonis_ml.nn_archive import ArchiveGenerator from luxonis_ml.nn_archive.config import CONFIG_VERSION from luxonis_ml.utils import LuxonisFileSystem, reset_logging, setup_logging @@ -113,25 +112,11 @@ def __init__( precision=self.cfg.trainer.precision, ) - self.train_augmentations = Augmentations( - image_size=self.cfg.trainer.preprocessing.train_image_size, - augmentations=[ - i.model_dump() - for i in self.cfg.trainer.preprocessing.get_active_augmentations() - ], - train_rgb=self.cfg.trainer.preprocessing.train_rgb, - keep_aspect_ratio=self.cfg.trainer.preprocessing.keep_aspect_ratio, - ) - self.val_augmentations = Augmentations( - image_size=self.cfg.trainer.preprocessing.train_image_size, - augmentations=[ - i.model_dump() - for i in self.cfg.trainer.preprocessing.get_active_augmentations() - ], - train_rgb=self.cfg.trainer.preprocessing.train_rgb, - keep_aspect_ratio=self.cfg.trainer.preprocessing.keep_aspect_ratio, - only_normalize=True, - ) + self.train_augmentations = [ + i.model_dump() + for i in self.cfg.trainer.preprocessing.get_active_augmentations() + ] + self.val_augmentations = self.train_augmentations self.loaders: dict[str, BaseLoaderTorch] = {} for view in ["train", "val", "test"]: @@ -141,16 +126,23 @@ def __init__( self.cfg.loader.params["delete_existing"] = False self.loaders[view] = Loader( - augmentations=( - self.train_augmentations - if view == "train" - else self.val_augmentations - ), view={ "train": self.cfg.loader.train_view, "val": self.cfg.loader.val_view, "test": self.cfg.loader.test_view, }[view], + augmentation_engine="albumentations", + augmentation_config=( + self.train_augmentations + if view == "train" + else self.val_augmentations + ), + height=self.cfg.trainer.preprocessing.train_image_size[0], + width=self.cfg.trainer.preprocessing.train_image_size[1], + keep_aspect_ratio=self.cfg.trainer.preprocessing.keep_aspect_ratio, + out_image_format="RGB" + if self.cfg.trainer.preprocessing.train_rgb + else "BGR", image_source=self.cfg.loader.image_source, **self.cfg.loader.params, ) diff --git a/luxonis_train/enums.py b/luxonis_train/enums.py index b024d6a9..88ee5c9d 100644 --- a/luxonis_train/enums.py +++ b/luxonis_train/enums.py @@ -10,3 +10,4 @@ class TaskType(str, Enum): KEYPOINTS = "keypoints" LABEL = "label" ARRAY = "array" + INSTANCE_SEGMENTATION = "instance_segmentation" diff --git a/luxonis_train/loaders/luxonis_loader_torch.py b/luxonis_train/loaders/luxonis_loader_torch.py index 230128b5..9cc910e4 100644 --- a/luxonis_train/loaders/luxonis_loader_torch.py +++ b/luxonis_train/loaders/luxonis_loader_torch.py @@ -1,9 +1,8 @@ import logging -from typing import Literal +from typing import List, Literal, Optional, Union import numpy as np from luxonis_ml.data import ( - Augmentations, BucketStorage, BucketType, LuxonisDataset, @@ -25,16 +24,22 @@ class LuxonisLoaderTorch(BaseLoaderTorch): @typechecked def __init__( self, - dataset_name: str | None = None, - dataset_dir: str | None = None, - dataset_type: DatasetType | None = None, - team_id: str | None = None, + dataset_name: Optional[str] = None, + dataset_dir: Optional[str] = None, + dataset_type: Optional[DatasetType] = None, + team_id: Optional[str] = None, bucket_type: Literal["internal", "external"] = "internal", bucket_storage: Literal["local", "s3", "gcs", "azure"] = "local", stream: bool = False, delete_existing: bool = True, - view: str | list[str] = "train", - augmentations: Augmentations | None = None, + view: Union[str, List[str]] = "train", + augmentation_engine: Literal["albumentations"] = "albumentations", + augmentation_config: Optional[Union[List, str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + keep_aspect_ratio: bool = False, + out_image_format: Literal["RGB", "BGR"] = "RGB", + force_resync: bool = False, **kwargs, ): """Torch-compatible loader for Luxonis datasets. @@ -69,15 +74,30 @@ def __init__( because the underlying data might have changed. If C{delete_existing} is set to C{False} and a dataset of the same name already exists, the existing dataset will be used instead of re-parsing the data. - @type view: str | list[str] + @type view: Union[str, List[str]] @param view: A single split or a list of splits that will be used to create a view of the dataset. Each split is a string that represents a subset of the dataset. The available splits depend on the dataset, but usually include 'train', 'val', and 'test'. Defaults to 'train'. - @type augmentations: Augmentations | None - @param augmentations: Augmentations to apply to the data. Defaults to C{None}. + @type augmentation_engine: Literal["albumentations"] + @param augmentation_engine: Engine to use for applying augmentations. + Defaults to 'albumentations'. + @type augmentation_config: List | str | None + @param augmentation_config: Augmentation configuration as a list or path to a + configuration file. Defaults to C{None}. + @type height: int | None + @param height: Optional height to resize the images. + @type width: int | None + @param width: Optional width to resize the images. + @type keep_aspect_ratio: bool + @param keep_aspect_ratio: Flag to maintain aspect ratio during resizing. + @type out_image_format: Literal["RGB", "BGR"] + @param out_image_format: Format of the output images. Defaults to 'RGB'. + @type force_resync: bool + @param force_resync: Force a resynchronization of the dataset. Defaults to False. """ - super().__init__(view=view, augmentations=augmentations, **kwargs) + super().__init__(view=view, **kwargs) + if dataset_dir is not None: self.dataset = self._parse_dataset( dataset_dir, dataset_name, dataset_type, delete_existing @@ -93,11 +113,17 @@ def __init__( bucket_type=BucketType(bucket_type), bucket_storage=BucketStorage(bucket_storage), ) + self.base_loader = LuxonisLoader( dataset=self.dataset, view=self.view, - stream=stream, - augmentations=self.augmentations, + augmentation_engine=augmentation_engine, + augmentation_config=augmentation_config, + height=height, + width=width, + keep_aspect_ratio=keep_aspect_ratio, + out_image_format=out_image_format, + force_resync=force_resync, ) def __len__(self) -> int: @@ -114,9 +140,12 @@ def __getitem__(self, idx: int) -> LuxonisLoaderTorchOutput: img = np.transpose(img, (2, 0, 1)) # HWC to CHW tensor_img = Tensor(img) tensor_labels: dict[str, tuple[Tensor, TaskType]] = {} - for task, (array, label_type) in labels.items(): - tensor_labels[task] = (Tensor(array), TaskType(label_type.value)) - + for task_with_type, array in labels.items(): + task_parts = task_with_type.split("/") + if len(task_parts) != 2: + raise ValueError(f"Invalid task format: {task_with_type}") + _, task_type = task_parts + tensor_labels[task_type] = (Tensor(array), TaskType(task_type)) return {self.image_source: tensor_img}, tensor_labels def get_classes(self) -> dict[str, list[str]]: @@ -130,8 +159,8 @@ def get_n_keypoints(self) -> dict[str, int]: def _parse_dataset( self, dataset_dir: str, - dataset_name: str | None, - dataset_type: DatasetType | None, + dataset_name: Optional[str], + dataset_type: Optional[DatasetType], delete_existing: bool, ) -> LuxonisDataset: if dataset_name is None: @@ -144,21 +173,18 @@ def _parse_dataset( logger.warning( f"Dataset {dataset_name} already exists. " "The dataset will be generated again to ensure the latest data are used. " - "If you don't want to regenerate the dataset every time, set `delete_existing=False`'" + "If you don't want to regenerate the dataset every time, set `delete_existing=False`." ) if dataset_type is None: logger.warning( - "Dataset type is not set. " - "Attempting to infer it from the directory structure. " - "If this fails, please set the dataset type manually. " - f"Supported types are: {', '.join(DatasetType.__members__)}." + "Dataset type is not set. Attempting to infer it from the directory structure. " + "If this fails, please set the dataset type manually." ) logger.info( f"Parsing dataset from {dataset_dir} with name '{dataset_name}'" ) - return LuxonisParser( dataset_dir, dataset_name=dataset_name, diff --git a/luxonis_train/loaders/utils.py b/luxonis_train/loaders/utils.py index b030e218..2e3c4b82 100644 --- a/luxonis_train/loaders/utils.py +++ b/luxonis_train/loaders/utils.py @@ -50,4 +50,8 @@ def collate_fn( label_box.append(l_box) out_labels[task] = torch.cat(label_box, 0), task_type + elif task_type == TaskType.INSTANCE_SEGMENTATION: + masks = [label[task][0] for label in labels] + out_labels[task] = torch.cat(masks, 0), task_type + return out_inputs, out_labels diff --git a/luxonis_train/nodes/heads/precision_seg_bbox_head.py b/luxonis_train/nodes/heads/precision_seg_bbox_head.py index 5cfc3e60..3c869e6d 100644 --- a/luxonis_train/nodes/heads/precision_seg_bbox_head.py +++ b/luxonis_train/nodes/heads/precision_seg_bbox_head.py @@ -16,7 +16,10 @@ class PrecisionSegmentBBoxHead(PrecisionBBoxHead): - tasks: list[TaskType] = [TaskType.SEGMENTATION, TaskType.BOUNDINGBOX] + tasks: list[TaskType] = [ + TaskType.INSTANCE_SEGMENTATION, + TaskType.BOUNDINGBOX, + ] def __init__( self, @@ -120,11 +123,13 @@ def wrap( "prototypes": prototypes, "mask_coeficients": mask_coefficients, "boundingbox": [], - "segmentation": [], # TODO: Sync on how we want to visualize this + "instance_segmentation": [], } - for i, pred in enumerate(preds): - results["segmentation"].append( + for i, pred in enumerate( + preds + ): # TODO: Investigate low seg loss but wrong masks + results["instance_segmentation"].append( refine_and_apply_masks( prototypes[i], pred[:, 6:], @@ -168,6 +173,10 @@ def refine_and_apply_masks( where the masks are cropped according to their respective bounding boxes. """ + if predicted_masks.size(0) == 0 or bounding_boxes.size(0) == 0: + img_h, img_w = target_shape + return torch.zeros(0, img_h, img_w, dtype=torch.uint8) + channels, proto_h, proto_w = mask_prototypes.shape img_h, img_w = target_shape masks_combined = ( diff --git a/luxonis_train/utils/dataset_metadata.py b/luxonis_train/utils/dataset_metadata.py index 3a9cecdf..f79e0232 100644 --- a/luxonis_train/utils/dataset_metadata.py +++ b/luxonis_train/utils/dataset_metadata.py @@ -1,3 +1,5 @@ +import warnings + from luxonis_train.loaders import BaseLoaderTorch @@ -43,10 +45,11 @@ def n_classes(self, task: str | None = None) -> int: """ if task is not None: if task not in self._classes: - raise ValueError( - f"Task '{task}' is not present in the dataset." + # TODO: rework this + warnings.warn( + f"Task '{task}' is not present in the dataset. Ignoring the task argument.", + UserWarning, ) - return len(self._classes[task]) n_classes = len(list(self._classes.values())[0]) for classes in self._classes.values(): if len(classes) != n_classes: @@ -99,10 +102,12 @@ def classes(self, task: str | None = None) -> list[str]: """ if task is not None: if task not in self._classes: - raise ValueError( - f"Task type {task} is not present in the dataset." + # TODO: rework this + warnings.warn( + f"Task '{task}' is not present in the dataset. Ignoring the task argument.", + UserWarning, ) - return self._classes[task] + task = None class_names = list(self._classes.values())[0] for classes in self._classes.values(): if classes != class_names: From 0f842e17d1c17ac4aad58a7e71a7cdb1c9989a88 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Tue, 10 Dec 2024 14:25:29 +0100 Subject: [PATCH 03/31] fix: seg loss, batch vis, and mAP for seg --- .../losses/precision_dfl_detection_loss.py | 4 +- .../losses/precision_dlf_segmentation_loss.py | 168 ++++++------------ .../metrics/mean_average_precision.py | 132 ++++++++++---- .../instance_segmentation_visualizer.py | 119 +++++++------ .../attached_modules/visualizers/utils.py | 1 - .../nodes/heads/precision_bbox_head.py | 4 +- .../nodes/heads/precision_seg_bbox_head.py | 4 +- 7 files changed, 231 insertions(+), 201 deletions(-) diff --git a/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py b/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py index d682aeea..5351ec60 100644 --- a/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py +++ b/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py @@ -39,7 +39,9 @@ def __init__( **kwargs: Any, ): """BBox loss adapted from U{Real-Time Flying Object Detection with YOLOv8 - } + } and from U{YOLOv6: A Single-Stage Object Detection Framework for Industrial Applications + }. + Code is adapted from U{https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/models}. @type reg_max: int @param reg_max: Maximum number of regression channels. Defaults to 16. diff --git a/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py b/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py index 7303d4ca..8808dc2c 100644 --- a/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py +++ b/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py @@ -34,11 +34,12 @@ def __init__( class_loss_weight: float = 0.5, bbox_loss_weight: float = 7.5, dfl_loss_weight: float = 1.5, - overlap_mask: bool = True, **kwargs: Any, ): """Instance Segmentation and BBox loss adapted from U{Real-Time Flying Object Detection with YOLOv8 - } + } and from U{YOLOv6: A Single-Stage Object Detection Framework for Industrial Applications + }. + Code is adapted from U{https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/models}. @type reg_max: int @param reg_max: Maximum number of regression channels. Defaults to 16. @@ -59,7 +60,6 @@ def __init__( dfl_loss_weight=dfl_loss_weight, **kwargs, ) - self.overlap = overlap_mask def prepare( self, inputs: Packet[Tensor], labels: Labels @@ -73,7 +73,7 @@ def prepare( [xi.view(self.batch_size, self.node.no, -1) for xi in det_feats], 2 ).split((self.node.reg_max * 4, self.n_classes), 1) target_bbox = self.get_label(labels, TaskType.BOUNDINGBOX) - img_idx = target_bbox[:, 0] + img_idx = target_bbox[:, 0].unsqueeze(-1) target_masks = self.get_label(labels, TaskType.INSTANCE_SEGMENTATION) if tuple(target_masks.shape[-2:]) != (mask_h, mask_w): target_masks = F.interpolate( @@ -153,7 +153,7 @@ def forward( loss_iou = torch.tensor(0.0).to(pred_distri.device) loss_dfl = torch.tensor(0.0).to(pred_distri.device) - loss_seg = self.calculate_segmentation_loss( + loss_seg = self.compute_segmentation_loss( mask_positive, target_masks, assigned_gt_idx, @@ -161,7 +161,6 @@ def forward( img_idx, proto, pred_masks, - self.overlap, ) loss = ( @@ -179,122 +178,65 @@ def forward( return loss, sub_losses - # TODO: Modify after adding corect annotation loading - def calculate_segmentation_loss( + def compute_segmentation_loss( self, fg_mask: torch.Tensor, - masks: torch.Tensor, - target_gt_idx: torch.Tensor, - target_bboxes: torch.Tensor, - batch_idx: torch.Tensor, + gt_masks: torch.Tensor, + gt_idx: torch.Tensor, + bboxes: torch.Tensor, + batch_ids: torch.Tensor, proto: torch.Tensor, pred_masks: torch.Tensor, - overlap: bool, ) -> torch.Tensor: - """Calculate the loss for instance segmentation. - - Args: - fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive. - masks (torch.Tensor): Ground truth masks of shape (BS, H, W) if `overlap` is False, otherwise (BS, ?, H, W). - target_gt_idx (torch.Tensor): Indexes of ground truth objects for each anchor of shape (BS, N_anchors). - target_bboxes (torch.Tensor): Ground truth bounding boxes for each anchor of shape (BS, N_anchors, 4). - batch_idx (torch.Tensor): Batch indices of shape (N_labels_in_batch, 1). - proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W). - pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32). - imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W). - overlap (bool): Whether the masks in `masks` tensor overlap. - - Returns: - (torch.Tensor): The calculated loss for instance segmentation. - - Notes: - The batch loss can be computed for improved speed at higher memory usage. - For example, pred_mask can be computed as follows: - pred_mask = torch.einsum('in,nhw->ihw', pred, proto) # (i, 32) @ (32, 160, 160) -> (i, 160, 160) + """Compute the segmentation loss for the entire batch. + + @type fg_mask: torch.Tensor + @param fg_mask: Foreground mask. Shape: (B, N_anchor). + @type gt_masks: torch.Tensor + @param gt_masks: Ground truth masks. Shape: (n, H, W). + @type gt_idx: torch.Tensor + @param gt_idx: Ground truth mask indices. Shape: (B, N_anchor). + @type bboxes: torch.Tensor + @param bboxes: Ground truth bounding boxes in xyxy format. + Shape: (B, N_anchor, 4). + @type batch_ids: torch.Tensor + @param batch_ids: Batch indices. Shape: (n, 1). + @type proto: torch.Tensor + @param proto: Prototype masks. Shape: (B, 32, H, W). + @type pred_masks: torch.Tensor + @param pred_masks: Predicted mask coefficients. Shape: (B, + N_anchor, 32). """ - _, _, mask_h, mask_w = proto.shape - loss = 0 - - # Normalize to 0-1 - target_bboxes_normalized = target_bboxes / self.gt_bboxes_scale - - # Areas of target bboxes - marea = box_convert( - target_bboxes_normalized, in_fmt="xyxy", out_fmt="xywh" - )[..., 2:].prod(2) - - # Normalize to mask size - mxyxy = target_bboxes_normalized * torch.tensor( - [mask_w, mask_h, mask_w, mask_h], device=proto.device + _, _, h, w = proto.shape + total_loss = 0 + bboxes_norm = bboxes / self.gt_bboxes_scale + bbox_area = box_convert(bboxes_norm, in_fmt="xyxy", out_fmt="xywh")[ + ..., 2: + ].prod(2) + bboxes_scaled = bboxes_norm * torch.tensor( + [w, h, w, h], device=proto.device ) - for i, single_i in enumerate( - zip(fg_mask, target_gt_idx, pred_masks, proto, mxyxy, marea, masks) + for img_idx, data in enumerate( + zip(fg_mask, gt_idx, pred_masks, proto, bboxes_scaled, bbox_area) ): - ( - fg_mask_i, - target_gt_idx_i, - pred_masks_i, - proto_i, - mxyxy_i, - marea_i, - masks_i, - ) = single_i - if fg_mask_i.any(): - mask_idx = target_gt_idx_i[fg_mask_i] - if overlap: - gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1) - gt_mask = gt_mask.float() - else: - gt_mask = masks[batch_idx.view(-1) == i][mask_idx] - - loss += self.single_mask_loss( - gt_mask, - pred_masks_i[fg_mask_i], - proto_i, - mxyxy_i[fg_mask_i], - marea_i[fg_mask_i], + fg, gt, pred, pr, bbox, area = data + if fg.any(): + mask_ids = gt[fg] + gt_mask = gt_masks[batch_ids.view(-1) == img_idx][mask_ids] + + # Compute individual image mask loss + pred_mask = torch.einsum("in,nhw->ihw", pred[fg], pr) + loss = F.binary_cross_entropy_with_logits( + pred_mask, gt_mask, reduction="none" ) - - # WARNING: lines below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove + total_loss += ( + apply_bounding_box_to_masks(loss, bbox[fg]).mean( + dim=(1, 2) + ) + / area[fg] + ).sum() else: - loss += (proto * 0).sum() + ( - pred_masks * 0 - ).sum() # inf sums may lead to nan loss - - return loss / fg_mask.sum() + total_loss += (proto * 0).sum() + (pred_masks * 0).sum() - # TODO: Modify after adding corect annotation loading - @staticmethod - def single_mask_loss( - gt_mask: torch.Tensor, - pred: torch.Tensor, - proto: torch.Tensor, - xyxy: torch.Tensor, - area: torch.Tensor, - ) -> torch.Tensor: - """Compute the instance segmentation loss for a single image. - - Args: - gt_mask (torch.Tensor): Ground truth mask of shape (n, H, W), where n is the number of objects. - pred (torch.Tensor): Predicted mask coefficients of shape (n, 32). - proto (torch.Tensor): Prototype masks of shape (32, H, W). - xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (n, 4). - area (torch.Tensor): Area of each ground truth bounding box of shape (n,). - - Returns: - (torch.Tensor): The calculated mask loss for a single image. - - Notes: - The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the - predicted masks from the prototype masks and predicted mask coefficients. - """ - pred_mask = torch.einsum( - "in,nhw->ihw", pred, proto - ) # (n, 32) @ (32, 80, 80) -> (n, 80, 80) - loss = F.binary_cross_entropy_with_logits( - pred_mask, gt_mask, reduction="none" - ) - return ( - apply_bounding_box_to_masks(loss, xyxy).mean(dim=(1, 2)) / area - ).sum() + return total_loss / fg_mask.sum() diff --git a/luxonis_train/attached_modules/metrics/mean_average_precision.py b/luxonis_train/attached_modules/metrics/mean_average_precision.py index 56937115..d0ed6b4c 100644 --- a/luxonis_train/attached_modules/metrics/mean_average_precision.py +++ b/luxonis_train/attached_modules/metrics/mean_average_precision.py @@ -1,5 +1,6 @@ from typing import Any +import torch import torchmetrics.detection as detection from torch import Tensor from torchvision.ops import box_convert @@ -14,18 +15,30 @@ class MeanAveragePrecision( BaseMetric[list[dict[str, Tensor]], list[dict[str, Tensor]]] ): """Compute the Mean-Average-Precision (mAP) and Mean-Average-Recall - (mAR) for object detection predictions. + (mAR) for object detection predictions and instance segmentation. Adapted from U{Mean-Average-Precision (mAP) and Mean-Average-Recall (mAR) }. """ - supported_tasks: list[TaskType] = [TaskType.BOUNDINGBOX] + supported_tasks: list[TaskType] = [ + TaskType.BOUNDINGBOX, + TaskType.INSTANCE_SEGMENTATION, + ] def __init__(self, **kwargs: Any): super().__init__(**kwargs) - self.metric = detection.MeanAveragePrecision() + self.is_segmentation = ( + TaskType.INSTANCE_SEGMENTATION in self.node.tasks + ) + + if self.is_segmentation: + iou_type = ("bbox", "segm") + else: + iou_type = "bbox" + + self.metric = detection.MeanAveragePrecision(iou_type=iou_type) def update( self, @@ -37,29 +50,51 @@ def update( def prepare( self, inputs: Packet[Tensor], labels: Labels ) -> tuple[list[dict[str, Tensor]], list[dict[str, Tensor]]]: - box_label = self.get_label(labels) - output_nms = self.get_input_tensors(inputs) - + box_label = self.get_label(labels, TaskType.BOUNDINGBOX) + mask_label = ( + self.get_label(labels, TaskType.INSTANCE_SEGMENTATION) + if self.is_segmentation + else None + ) + + output_nms_bboxes = self.get_input_tensors(inputs, "boundingbox") + output_nms_masks = ( + self.get_input_tensors(inputs, "instance_segmentation") + if self.is_segmentation + else None + ) image_size = self.original_in_shape[1:] output_list: list[dict[str, Tensor]] = [] label_list: list[dict[str, Tensor]] = [] - for i in range(len(output_nms)): - output_list.append( - { - "boxes": output_nms[i][:, :4], - "scores": output_nms[i][:, 4], - "labels": output_nms[i][:, 5].int(), - } - ) - + for i in range(len(output_nms_bboxes)): + # Prepare predictions + pred = { + "boxes": output_nms_bboxes[i][:, :4], + "scores": output_nms_bboxes[i][:, 4], + "labels": output_nms_bboxes[i][:, 5].int(), + } + if self.is_segmentation: + pred["masks"] = output_nms_masks[i].to( + dtype=torch.bool + ) # Predicted masks (M, H, W) + output_list.append(pred) + + # Prepare ground truth curr_label = box_label[box_label[:, 0] == i] curr_bboxs = box_convert(curr_label[:, 2:], "xywh", "xyxy") curr_bboxs[:, 0::2] *= image_size[1] curr_bboxs[:, 1::2] *= image_size[0] - label_list.append( - {"boxes": curr_bboxs, "labels": curr_label[:, 1].int()} - ) + + gt = { + "boxes": curr_bboxs, + "labels": curr_label[:, 1].int(), + } + if self.is_segmentation: + gt["masks"] = mask_label[box_label[:, 0] == i].to( + dtype=torch.bool + ) + label_list.append(gt) return output_list, label_list @@ -69,19 +104,48 @@ def reset(self) -> None: def compute(self) -> tuple[Tensor, dict[str, Tensor]]: metric_dict: dict[str, Tensor] = self.metric.compute() - del metric_dict["classes"] - del metric_dict["map_per_class"] - del metric_dict["mar_100_per_class"] - for key in list(metric_dict.keys()): - if "map" in key: - map = metric_dict[key] - mar_key = key.replace("map", "mar") - if mar_key in metric_dict: - mar = metric_dict[mar_key] - metric_dict[key.replace("map", "f1")] = ( - 2 * (map * mar) / (map + mar) - ) - - map = metric_dict.pop("map") - - return map, metric_dict + if self.is_segmentation: + keys_to_remove = [ + "classes", + "bbox_map_per_class", + "bbox_mar_100_per_class", + "segm_map_per_class", + "segm_mar_100_per_class", + ] + for key in keys_to_remove: + if key in metric_dict: + del metric_dict[key] + + for key in list(metric_dict.keys()): + if "map" in key: + map_metric = metric_dict[key] + mar_key = key.replace("map", "mar") + if mar_key in metric_dict: + mar_metric = metric_dict[mar_key] + metric_dict[key.replace("map", "f1")] = ( + 2 + * (map_metric * mar_metric) + / (map_metric + mar_metric) + ) + + scalar = metric_dict.get("segm_map", torch.tensor(0.0)) + else: + del metric_dict["classes"] + del metric_dict["map_per_class"] + del metric_dict["mar_100_per_class"] + + for key in list(metric_dict.keys()): + if "map" in key: + map_metric = metric_dict[key] + mar_key = key.replace("map", "mar") + if mar_key in metric_dict: + mar_metric = metric_dict[mar_key] + metric_dict[key.replace("map", "f1")] = ( + 2 + * (map_metric * mar_metric) + / (map_metric + mar_metric) + ) + + scalar = metric_dict.pop("map", torch.tensor(0.0)) + + return scalar, metric_dict diff --git a/luxonis_train/attached_modules/visualizers/instance_segmentation_visualizer.py b/luxonis_train/attached_modules/visualizers/instance_segmentation_visualizer.py index 63f8aa37..3f1c1ca1 100644 --- a/luxonis_train/attached_modules/visualizers/instance_segmentation_visualizer.py +++ b/luxonis_train/attached_modules/visualizers/instance_segmentation_visualizer.py @@ -21,7 +21,7 @@ class InstanceSegmentationVisualizer(BaseVisualizer[Tensor, Tensor]): """Visualizer for instance segmentation tasks, supporting the visualization of predicted and ground truth bounding boxes and - instance masks.""" + instance segmentation masks.""" supported_tasks: list[TaskType] = [ TaskType.INSTANCE_SEGMENTATION, @@ -40,18 +40,26 @@ def __init__( alpha: float = 0.6, **kwargs, ): - """Initialize the visualizer with customization options for - appearance. - - Parameters: - - labels: A dictionary or list mapping class indices to labels. Defaults to None. - - draw_labels: Whether to draw labels on bounding boxes. Defaults to True. - - colors: Colors for each class. Can be a dictionary or list. Defaults to None. - - fill: Whether to fill bounding boxes. Defaults to False. - - width: Line width for bounding boxes. Defaults to None (adaptive). - - font: Font to use for labels. Defaults to None. - - font_size: Font size for labels. Defaults to None. - - alpha: Transparency for instance masks. Defaults to 0.6. + """Visualizer for instance segmentation tasks. + + @type labels: dict[int, str] | list[str] | None + @param labels: Dictionary mapping class indices to class labels. + @type draw_labels: bool + @param draw_labels: Whether to draw class labels on the + visualizations. + @type colors: dict[str, L{Color}] | list[L{Color}] | None + @param colors: Dicionary mapping class labels to colors. + @type fill: bool | None + @param fill: Whether to fill the boundingbox with color. + @type width: int | None + @param width: Width of the bounding box Lines. + @type font: str | None + @param font: Font of the clas labels. + @type font_size: int | None + @param font_size: Font size of the class Labels. + @type alpha: float + @param alpha: Alpha value of the segmentation masks. Defaults to + C{0.6}. """ super().__init__(**kwargs) @@ -82,9 +90,7 @@ def __init__( def prepare( self, inputs: Packet[Tensor], labels: Labels | None ) -> tuple[Tensor, Tensor, list[Tensor], Tensor | None, Tensor | None]: - """ - TODO: Docstring - """ + # Override the prepare base method target_bboxes = labels["boundingbox"][0] target_masks = labels["instance_segmentation"][0] predicted_bboxes = inputs["boundingbox"] @@ -103,14 +109,13 @@ def draw_predictions( draw_labels: bool, alpha: float, ) -> Tensor: - """Draw predicted bounding boxes and masks on the canvas.""" viz = torch.zeros_like(canvas) for i in range(len(canvas)): viz[i] = canvas[i].clone() - prediction = pred_bboxes[i] - masks = pred_masks[i] - prediction_classes = prediction[..., 5].int() + image_bboxes = pred_bboxes[i] + image_masks = pred_masks[i] + prediction_classes = image_bboxes[..., 5].int() cls_labels = ( [label_dict[int(c)] for c in prediction_classes] @@ -127,18 +132,16 @@ def draw_predictions( width = width or max(1, int(min(H, W) / 100)) try: - for j, mask in enumerate(masks): - print(f"mask.sum(): {mask.sum()}") - viz[i] = draw_segmentation_labels( - viz[i], - mask.unsqueeze(0), - colors=[cls_colors[j]], - alpha=alpha, - ).to(canvas.device) + viz[i] = draw_segmentation_labels( + viz[i], + image_masks, + colors=cls_colors, + alpha=alpha, + ).to(canvas.device) viz[i] = draw_bounding_boxes( viz[i], - prediction[:, :4], + image_bboxes[:, :4], width=width, labels=cls_labels, colors=cls_colors, @@ -162,14 +165,13 @@ def draw_targets( draw_labels: bool, alpha: float, ) -> Tensor: - """Draw ground truth bounding boxes and masks on the canvas.""" viz = torch.zeros_like(canvas) for i in range(len(canvas)): viz[i] = canvas[i].clone() - image_targets = target_bboxes[target_bboxes[:, 0] == i] + image_bboxes = target_bboxes[target_bboxes[:, 0] == i] image_masks = target_masks[target_bboxes[:, 0] == i] - target_classes = image_targets[:, 1].int() + target_classes = image_bboxes[:, 1].int() cls_labels = ( [label_dict[int(c)] for c in target_classes] @@ -185,23 +187,19 @@ def draw_targets( *_, H, W = canvas.shape width = width or max(1, int(min(H, W) / 100)) - for j, (bbox, mask) in enumerate( - zip(image_targets[:, 2:], image_masks) - ): - print(f"sum(mask): {mask.sum()}") - viz[i] = draw_segmentation_labels( - viz[i], - mask.unsqueeze(0), - alpha=alpha, - colors=[cls_colors[j]], - ).to(canvas.device) - viz[i] = draw_bounding_box_labels( - viz[i], - bbox.unsqueeze(0), - width=width, - labels=[cls_labels[j]] if cls_labels else None, - colors=[cls_colors[j]], - ).to(canvas.device) + viz[i] = draw_segmentation_labels( + viz[i], + image_masks, + alpha=alpha, + colors=cls_colors, + ).to(canvas.device) + viz[i] = draw_bounding_box_labels( + viz[i], + image_bboxes[:, 2:], + width=width, + labels=cls_labels if cls_labels else None, + colors=cls_colors, + ).to(canvas.device) return viz @@ -214,7 +212,28 @@ def forward( predicted_bboxes: Tensor, predicted_masks: Tensor, ) -> tuple[Tensor, Tensor] | Tensor: - """Visualize predictions and ground truth.""" + """Creates visualizations of the predicted and target bounding + boxes and instance masks. + + @type label_canvas: Tensor + @param label_canvas: Tensor containing the target + visualizations. + @type prediction_canvas: Tensor + @param prediction_canvas: Tensor containing the predicted + visualizations. + @type target_bboxes: Tensor | None + @param target_bboxes: Tensor containing the target bounding + boxes. + @type target_masks: Tensor | None + @param target_masks: Tensor containing the target instance + masks. + @type predicted_bboxes: Tensor + @param predicted_bboxes: Tensor containing the predicted + bounding boxes. + @type predicted_masks: Tensor + @param predicted_masks: Tensor containing the predicted instance + masks. + """ predictions_viz = self.draw_predictions( prediction_canvas, predicted_bboxes, diff --git a/luxonis_train/attached_modules/visualizers/utils.py b/luxonis_train/attached_modules/visualizers/utils.py index d6d710c6..45ec454b 100644 --- a/luxonis_train/attached_modules/visualizers/utils.py +++ b/luxonis_train/attached_modules/visualizers/utils.py @@ -118,7 +118,6 @@ def draw_segmentation_labels( @rtype: Tensor @return: Image with segmentation labels drawn on. """ - print(f"sum(label): {label.sum()}") masks = label.bool() masks = masks.cpu() img = img.cpu() diff --git a/luxonis_train/nodes/heads/precision_bbox_head.py b/luxonis_train/nodes/heads/precision_bbox_head.py index bfc5f72d..8230217c 100644 --- a/luxonis_train/nodes/heads/precision_bbox_head.py +++ b/luxonis_train/nodes/heads/precision_bbox_head.py @@ -33,7 +33,9 @@ def __init__( ): """ Adapted from U{Real-Time Flying Object Detection with YOLOv8 - } + } and from U{YOLOv6: A Single-Stage Object Detection Framework + for Industrial Applications + }. @type ch: tuple[int] @param ch: Channels for each detection layer. diff --git a/luxonis_train/nodes/heads/precision_seg_bbox_head.py b/luxonis_train/nodes/heads/precision_seg_bbox_head.py index 3c869e6d..0f656ad8 100644 --- a/luxonis_train/nodes/heads/precision_seg_bbox_head.py +++ b/luxonis_train/nodes/heads/precision_seg_bbox_head.py @@ -34,7 +34,9 @@ def __init__( """ Head for instance segmentation and object detection. Adapted from U{Real-Time Flying Object Detection with YOLOv8 - } + } and from U{YOLOv6: A Single-Stage Object Detection Framework + for Industrial Applications + }. @type n_heads: Literal[2, 3, 4] @param n_heads: Number of output heads. Defaults to 3. From b6be8d45a39e4c5c852039945748661de1c1eac3 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Tue, 10 Dec 2024 14:40:23 +0100 Subject: [PATCH 04/31] remove loss scaling --- .../losses/efficient_keypoint_bbox_loss.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py b/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py index 5dc3e564..ad34bff7 100644 --- a/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py +++ b/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py @@ -103,7 +103,7 @@ def prepare( target_kpts = self.get_label(labels, TaskType.KEYPOINTS) target_bbox = self.get_label(labels, TaskType.BOUNDINGBOX) - self.batch_size = pred_scores.shape[0] + batch_size = pred_scores.shape[0] n_kpts = (target_kpts.shape[1] - 2) // 3 self._init_parameters(feats) @@ -112,16 +112,14 @@ def prepare( pred_kpts = self.dist2kpts_noscale( self.anchor_points_strided, pred_kpts.view( - self.batch_size, + batch_size, -1, n_kpts, 3, ), ) - target_bbox = self._preprocess_bbox_target( - target_bbox, self.batch_size - ) + target_bbox = self._preprocess_bbox_target(target_bbox, batch_size) gt_bbox_labels = target_bbox[:, :, :1] gt_xyxy = target_bbox[:, :, 1:] @@ -141,7 +139,7 @@ def prepare( ) batched_kpts = self._preprocess_kpts_target( - target_kpts, self.batch_size, self.gt_kpts_scale + target_kpts, batch_size, self.gt_kpts_scale ) assigned_gt_idx_expanded = assigned_gt_idx.unsqueeze(-1).unsqueeze(-1) selected_keypoints = batched_kpts.gather( @@ -234,7 +232,7 @@ def forward( "visibility": visibility_loss.detach(), } - return loss * self.batch_size, sub_losses + return loss, sub_losses def _preprocess_kpts_target( self, kpts_target: Tensor, batch_size: int, scale_tensor: Tensor From 6a6bb12cfb2699830b41e07c57a90587a5572fc5 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Tue, 10 Dec 2024 14:42:24 +0100 Subject: [PATCH 05/31] remove loss scaling --- .../losses/precision_dfl_detection_loss.py | 8 ++++---- .../losses/precision_dlf_segmentation_loss.py | 8 +++----- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py b/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py index 5351ec60..fb6b559f 100644 --- a/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py +++ b/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py @@ -76,15 +76,15 @@ def prepare( ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: feats = self.get_input_tensors(inputs, "features") self._init_parameters(feats) - self.batch_size = feats[0].shape[0] + batch_size = feats[0].shape[0] pred_distri, pred_scores = torch.cat( - [xi.view(self.batch_size, self.node.no, -1) for xi in feats], 2 + [xi.view(batch_size, self.node.no, -1) for xi in feats], 2 ).split((self.node.reg_max * 4, self.n_classes), 1) target = self.get_label(labels) pred_distri = pred_distri.permute(0, 2, 1).contiguous() pred_scores = pred_scores.permute(0, 2, 1).contiguous() - target = self._preprocess_bbox_target(target, self.batch_size) + target = self._preprocess_bbox_target(target, batch_size) pred_bboxes = self.decode_bbox(self.anchor_points_strided, pred_distri) @@ -148,7 +148,7 @@ def forward( "dfl": loss_dfl.detach(), } - return loss * self.batch_size, sub_losses + return loss, sub_losses def _preprocess_bbox_target( self, target: Tensor, batch_size: int diff --git a/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py b/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py index 8808dc2c..af777a80 100644 --- a/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py +++ b/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py @@ -68,9 +68,9 @@ def prepare( proto = self.get_input_tensors(inputs, "prototypes") pred_mask = self.get_input_tensors(inputs, "mask_coeficients") self._init_parameters(det_feats) - self.batch_size, _, mask_h, mask_w = proto.shape + batch_size, _, mask_h, mask_w = proto.shape pred_distri, pred_scores = torch.cat( - [xi.view(self.batch_size, self.node.no, -1) for xi in det_feats], 2 + [xi.view(batch_size, self.node.no, -1) for xi in det_feats], 2 ).split((self.node.reg_max * 4, self.n_classes), 1) target_bbox = self.get_label(labels, TaskType.BOUNDINGBOX) img_idx = target_bbox[:, 0].unsqueeze(-1) @@ -84,9 +84,7 @@ def prepare( pred_scores = pred_scores.permute(0, 2, 1).contiguous() pred_mask = pred_mask.permute(0, 2, 1).contiguous() - target_bbox = self._preprocess_bbox_target( - target_bbox, self.batch_size - ) + target_bbox = self._preprocess_bbox_target(target_bbox, batch_size) pred_bboxes = self.decode_bbox(self.anchor_points_strided, pred_distri) From 659768793909339f9a7a4c4932b5d2c2d52a319c Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Wed, 11 Dec 2024 08:00:11 +0100 Subject: [PATCH 06/31] fix: export --- .../nodes/heads/precision_bbox_head.py | 54 ++++++++++--------- .../nodes/heads/precision_seg_bbox_head.py | 26 +++++---- 2 files changed, 45 insertions(+), 35 deletions(-) diff --git a/luxonis_train/nodes/heads/precision_bbox_head.py b/luxonis_train/nodes/heads/precision_bbox_head.py index 8230217c..c2c1893a 100644 --- a/luxonis_train/nodes/heads/precision_bbox_head.py +++ b/luxonis_train/nodes/heads/precision_bbox_head.py @@ -136,11 +136,12 @@ def wrap(self, output: list[Tensor]) -> Packet[Tensor]: return { "features": output, } - y = self._inference(output) + if self.export: - return {self.task: y} + return {self.task: [self._export_bbox_output(output)]} + boxes = non_max_suppression( - y, + self._inference_bbox_output(output), n_classes=self.n_classes, conf_thres=self.conf_thres, iou_thres=self.iou_thres, @@ -166,25 +167,35 @@ def _fit_stride_to_n_heads(self): ) return stride - def _inference(self, x: list[Tensor], masks: Tensor | None = None): - """Decode predicted bounding boxes and class probabilities based - on multiple-level feature maps.""" + def _extract_cls_and_box(self, x: list[Tensor]): + """Extract classification and bounding box tensors.""" shape = x[0].shape x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2) - _, self.anchor_points, _, self.strides = anchors_for_fpn_features( + box, cls = x_cat.split((self.reg_max * 4, self.n_classes), 1) + return box, cls.sigmoid(), shape # Apply sigmoid to cls + + def _export_bbox_output(self, x: list[Tensor]): + """Prepare the output for export.""" + box, cls, _ = self._extract_cls_and_box(x) + box_dist = self.dfl(box) # Shape: [N, 4, N_anchors] + conf, _ = cls.max(1, keepdim=True) # Shape: [N, 1, N_anchors] + export_output = torch.cat( + [box_dist, conf, cls], dim=1 + ) # Shape: [N, 4 + 1 + num_classes, N_anchors] + return export_output + + def _inference_bbox_output(self, x: list[Tensor]): + """Perform inference on predicted bounding boxes and class + probabilities.""" + box, cls, shape = self._extract_cls_and_box(x) + box_dist = self.dfl(box) + + _, anchor_points, _, strides = anchors_for_fpn_features( x, self.stride, 0.5 ) - box, cls = x_cat.split((self.reg_max * 4, self.n_classes), 1) - pred_bboxes = self.decode_bboxes( - self.dfl(box), self.anchor_points.transpose(0, 1) - ) * self.strides.transpose(0, 1) - - if self.export: - return torch.cat( - (pred_bboxes.permute(0, 2, 1), cls.sigmoid().permute(0, 2, 1)), - 1, - ) - + pred_bboxes = dist2bbox( + box_dist, anchor_points.transpose(0, 1), out_format="xyxy", dim=1 + ) * strides.transpose(0, 1) base_output = [ pred_bboxes.permute(0, 2, 1), torch.ones( @@ -195,16 +206,9 @@ def _inference(self, x: list[Tensor], masks: Tensor | None = None): cls.permute(0, 2, 1), ] - if masks is not None: - base_output.append(masks.permute(0, 2, 1)) - output_merged = torch.cat(base_output, dim=-1) return output_merged - def decode_bboxes(self, bboxes: Tensor, anchors: Tensor) -> Tensor: - """Decode bounding boxes.""" - return dist2bbox(bboxes, anchors, out_format="xyxy", dim=1) - def bias_init(self): """Initialize biases for the detection heads. diff --git a/luxonis_train/nodes/heads/precision_seg_bbox_head.py b/luxonis_train/nodes/heads/precision_seg_bbox_head.py index 0f656ad8..05b4a70b 100644 --- a/luxonis_train/nodes/heads/precision_seg_bbox_head.py +++ b/luxonis_train/nodes/heads/precision_seg_bbox_head.py @@ -102,16 +102,24 @@ def wrap( "prototypes": prototypes, "mask_coeficients": mask_coefficients, } + if self.export: - { - self.task: ( - torch.cat([det_feats, mask_coefficients], 1), - prototypes, - ) + pred_bboxes = self._export_bbox_output(det_feats) + return { + TaskType.INSTANCE_SEGMENTATION: [ + torch.cat( + [pred_bboxes, mask_coefficients], 1 + ), # Shape: [N, 4 + 1 + num_classes + n_masks, N_anchors] + ], + "prototypes": [prototypes], # Shape: [N, n_masks, H, W] } - pred_bboxes = self._inference(det_feats, mask_coefficients) + + pred_bboxes = self._inference_bbox_output(det_feats) + preds_combined = torch.cat( + [pred_bboxes, mask_coefficients.permute(0, 2, 1)], dim=-1 + ) preds = non_max_suppression( - pred_bboxes, + preds_combined, n_classes=self.n_classes, conf_thres=self.conf_thres, iou_thres=self.iou_thres, @@ -128,9 +136,7 @@ def wrap( "instance_segmentation": [], } - for i, pred in enumerate( - preds - ): # TODO: Investigate low seg loss but wrong masks + for i, pred in enumerate(preds): results["instance_segmentation"].append( refine_and_apply_masks( prototypes[i], From bd7ff525ea0a18e60e05c43713c0273c4dff6f1f Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Wed, 11 Dec 2024 09:58:21 +0100 Subject: [PATCH 07/31] add docs --- .../attached_modules/losses/README.md | 28 +++++++++++++++ .../losses/precision_dfl_detection_loss.py | 7 ++-- .../losses/precision_dlf_segmentation_loss.py | 4 --- luxonis_train/nodes/README.md | 34 ++++++++++++++++++- .../nodes/heads/precision_bbox_head.py | 2 ++ 5 files changed, 65 insertions(+), 10 deletions(-) diff --git a/luxonis_train/attached_modules/losses/README.md b/luxonis_train/attached_modules/losses/README.md index 38f8b42f..a2f07106 100644 --- a/luxonis_train/attached_modules/losses/README.md +++ b/luxonis_train/attached_modules/losses/README.md @@ -12,6 +12,8 @@ List of all the available loss functions. - [`AdaptiveDetectionLoss`](#adaptivedetectionloss) - [`EfficientKeypointBBoxLoss`](#efficientkeypointbboxloss) - [`FOMOLocalizationLoss`](#fomolocalizationLoss) +- \[`PrecisionDFLDetectionLoss`\] (# precisiondfldetectionloss) +- \[`PrecisionDFLSegmentationLoss`\] (# precisiondflsegmentationloss) ## `CrossEntropyLoss` @@ -121,3 +123,29 @@ Adapted from [here](https://arxiv.org/abs/2108.07610). | Key | Type | Default value | Description | | --------------- | ------- | ------------- | ----------------------------------------------- | | `object_weight` | `float` | `1000` | Weight for the objects in the loss calculation. | + +## `PrecisionDFLDetectionLoss` + +Adapted from [here](https://arxiv.org/pdf/2207.02696.pdf) and [here](https://arxiv.org/pdf/2209.02976.pdf). + +**Parameters:** + +| Key | Type | Default value | Description | +| ------------------- | ------- | ------------- | ------------------------------------------ | +| `tal_topk` | `int` | `10` | Number of anchors considered in selection. | +| `class_loss_weight` | `float` | `0.5` | Weight for classification loss. | +| `bbox_loss_weight` | `float` | `7.5` | Weight for bbox loss. | +| `dfl_loss_weigth` | `float` | `1.5` | Weight for DFL loss. | + +## `PrecisionDFLSegmentationLoss` + +Adapted from [here](https://arxiv.org/pdf/2207.02696.pdf) and [here](https://arxiv.org/pdf/2209.02976.pdf). + +**Parameters:** + +| Key | Type | Default value | Description | +| ------------------- | ------- | ------------- | ------------------------------------------ | +| `tal_topk` | `int` | `10` | Number of anchors considered in selection. | +| `class_loss_weight` | `float` | `0.5` | Weight for classification loss. | +| `bbox_loss_weight` | `float` | `7.5` | Weight for bbox and segmentation loss. | +| `dfl_loss_weigth` | `float` | `1.5` | Weight for DFL loss. | diff --git a/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py b/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py index fb6b559f..cb80b105 100644 --- a/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py +++ b/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py @@ -31,7 +31,6 @@ class PrecisionDFLDetectionLoss( def __init__( self, - reg_max: int = 16, tal_topk: int = 10, class_loss_weight: float = 0.5, bbox_loss_weight: float = 7.5, @@ -43,8 +42,6 @@ def __init__( }. Code is adapted from U{https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/models}. - @type reg_max: int - @param reg_max: Maximum number of regression channels. Defaults to 16. @type tal_topk: int @param tal_topk: Number of anchors considered in selection. Defaults to 10. @type class_loss_weight: float @@ -67,8 +64,8 @@ def __init__( self.assigner = TaskAlignedAssigner( n_classes=self.n_classes, topk=tal_topk, alpha=0.5, beta=6.0 ) - self.bbox_loss = CustomBboxLoss(reg_max) - self.proj = torch.arange(reg_max, dtype=torch.float) + self.bbox_loss = CustomBboxLoss(self.node.reg_max) + self.proj = torch.arange(self.node.reg_max, dtype=torch.float) self.bce = nn.BCEWithLogitsLoss(reduction="none") def prepare( diff --git a/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py b/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py index af777a80..27f05809 100644 --- a/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py +++ b/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py @@ -29,7 +29,6 @@ class PrecisionDFLSegmentationLoss(PrecisionDFLDetectionLoss): def __init__( self, - reg_max: int = 16, tal_topk: int = 10, class_loss_weight: float = 0.5, bbox_loss_weight: float = 7.5, @@ -41,8 +40,6 @@ def __init__( }. Code is adapted from U{https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/models}. - @type reg_max: int - @param reg_max: Maximum number of regression channels. Defaults to 16. @type tal_topk: int @param tal_topk: Number of anchors considered in selection. Defaults to 10. @type class_loss_weight: float @@ -53,7 +50,6 @@ def __init__( @param dfl_loss_weight: Weight for DFL loss. Defaults to 1.5. For optimal results, multiply with accumulate_grad_batches. """ super().__init__( - reg_max=reg_max, tal_topk=tal_topk, class_loss_weight=class_loss_weight, bbox_loss_weight=bbox_loss_weight, diff --git a/luxonis_train/nodes/README.md b/luxonis_train/nodes/README.md index 31f1f6c2..92046745 100644 --- a/luxonis_train/nodes/README.md +++ b/luxonis_train/nodes/README.md @@ -28,6 +28,8 @@ arbitrarily as long as the two nodes are compatible with each other. We've group - [`DDRNetSegmentationHead`](#ddrnetsegmentationhead) - [`DiscSubNetHead`](#discsubnet) - [`FOMOHead`](#fomohead) + - [`PrecisionBBoxHead`](#precisionbboxhead) + - [`PrecisionSegmentBBoxHead`](#precisionsegmentbboxhead) Every node takes these parameters: | Key | Type | Default value | Description | @@ -222,7 +224,7 @@ Adapted from [here](https://arxiv.org/pdf/2209.02976.pdf). | Key | Type | Default value | Description | | -------------------- | ------- | ------------- | --------------------------------------------------------------------- | -| `n_heads` | `bool` | `3` | Number of output heads | +| `n_heads` | `int` | `3` | Number of output heads | | `conf_thres` | `float` | `0.25` | Confidence threshold for non-maxima-suppression (used for evaluation) | | `iou_thres` | `float` | `0.45` | `IoU` threshold for non-maxima-suppression (used for evaluation) | | `max_det` | `int` | `300` | Maximum number of detections retained after NMS | @@ -272,3 +274,33 @@ Adapted from [here](https://arxiv.org/abs/2108.07610). | ----------------- | ----- | ------------- | ------------------------------------------------------- | | `num_conv_layers` | `int` | `3` | Number of convolutional layers to use in the model. | | `conv_channels` | `int` | `16` | Number of output channels for each convolutional layer. | + +## `PrecisionBBoxHead` + +Adapted from [here](https://arxiv.org/pdf/2207.02696.pdf) and [here](https://arxiv.org/pdf/2209.02976.pdf). + +**Parameters:** + +| Key | Type | Default value | Description | +| ------------ | ------- | ------------- | ------------------------------------------------------------------------- | +| `reg_max` | `int` | `16` | Maximum number of regression channels | +| `n_heads` | `int` | `3` | Number of output heads | +| `conf_thres` | `float` | `0.25` | Confidence threshold for non-maxima-suppression (used for evaluation) | +| `iou_thres` | `float` | `0.45` | IoU threshold for non-maxima-suppression (used for evaluation) | +| `max_det` | `int` | `300` | Max number of detections for non-maxima-suppression (used for evaluation) | + +## `PrecisionSegmentBBoxHead` + +Adapted from [here](https://arxiv.org/pdf/2207.02696.pdf) and [here](https://arxiv.org/pdf/2209.02976.pdf). + +**Parameters:** + +| Key | Type | Default value | Description | +| ------------ | ------- | ------------- | -------------------------------------------------------------------------- | +| `reg_max` | `int` | `16` | Maximum number of regression channels. | +| `n_heads` | `int` | `3` | Number of output heads. | +| `conf_thres` | `float` | `0.25` | Confidence threshold for non-maxima-suppression (used for evaluation). | +| `iou_thres` | `float` | `0.45` | IoU threshold for non-maxima-suppression (used for evaluation). | +| `max_det` | `int` | `300` | Max number of detections for non-maxima-suppression (used for evaluation). | +| `n_masks` | `int` | `32` | Number of of output instance segmentation masks at the output. | +| `n_proto` | `int` | `256` | Number of prototypes generated from the prototype generator. | diff --git a/luxonis_train/nodes/heads/precision_bbox_head.py b/luxonis_train/nodes/heads/precision_bbox_head.py index c2c1893a..27c2fb9f 100644 --- a/luxonis_train/nodes/heads/precision_bbox_head.py +++ b/luxonis_train/nodes/heads/precision_bbox_head.py @@ -47,6 +47,8 @@ def __init__( @param conf_thres: Confidence threshold for NMS. @type iou_thres: float @param iou_thres: IoU threshold for NMS. + @type max_det: int + @param max_det: Maximum number of detections retained after NMS. """ super().__init__(**kwargs) self.reg_max = reg_max From 95ea9c23b2f5663d1773c75a843e161221f26424 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Wed, 11 Dec 2024 11:10:38 +0100 Subject: [PATCH 08/31] predefined instance segmentation model --- .../instance_segmentation_heavy_model.yaml | 45 ++++++ .../instance_segmentation_light_model.yaml | 45 ++++++ .../config/predefined_models/README.md | 34 +++- .../config/predefined_models/__init__.py | 2 + .../instance_segmentation_model.py | 153 ++++++++++++++++++ 5 files changed, 278 insertions(+), 1 deletion(-) create mode 100644 configs/instance_segmentation_heavy_model.yaml create mode 100644 configs/instance_segmentation_light_model.yaml create mode 100644 luxonis_train/config/predefined_models/instance_segmentation_model.py diff --git a/configs/instance_segmentation_heavy_model.yaml b/configs/instance_segmentation_heavy_model.yaml new file mode 100644 index 00000000..42cedd87 --- /dev/null +++ b/configs/instance_segmentation_heavy_model.yaml @@ -0,0 +1,45 @@ +# Example configuration for training a predefined heavy instance segmentation model + +model: + name: instance_segmentation_heavy + predefined_model: + name: InstanceSegmentationModel + params: + variant: heavy + +loader: + params: + dataset_name: coco_test + +trainer: + preprocessing: + train_image_size: [384, 512] + keep_aspect_ratio: true + normalize: + active: true + + batch_size: 8 + epochs: &epochs 200 + n_workers: 4 + validation_interval: 10 + n_log_images: 8 + + callbacks: + - name: ExportOnTrainEnd + - name: TestOnTrainEnd + + optimizer: + name: SGD + params: + lr: 0.01 + momentum: 0.937 + weight_decay: 0.0005 + dampening: 0.0 + nesterov: true + + scheduler: + name: CosineAnnealingLR + params: + T_max: *epochs + eta_min: 0.0001 + last_epoch: -1 diff --git a/configs/instance_segmentation_light_model.yaml b/configs/instance_segmentation_light_model.yaml new file mode 100644 index 00000000..24d764ed --- /dev/null +++ b/configs/instance_segmentation_light_model.yaml @@ -0,0 +1,45 @@ +# Example configuration for training a predefined light instance segmentation model + +model: + name: instance_segmentation_light + predefined_model: + name: InstanceSegmentationModel + params: + variant: light + +loader: + params: + dataset_name: coco_test + +trainer: + preprocessing: + train_image_size: [384, 512] + keep_aspect_ratio: true + normalize: + active: true + + batch_size: 8 + epochs: &epochs 200 + n_workers: 4 + validation_interval: 10 + n_log_images: 8 + + callbacks: + - name: ExportOnTrainEnd + - name: TestOnTrainEnd + + optimizer: + name: SGD + params: + lr: 0.01 + momentum: 0.937 + weight_decay: 0.0005 + dampening: 0.0 + nesterov: true + + scheduler: + name: CosineAnnealingLR + params: + T_max: *epochs + eta_min: 0.0001 + last_epoch: -1 diff --git a/luxonis_train/config/predefined_models/README.md b/luxonis_train/config/predefined_models/README.md index 0d81a0ea..124cba6e 100644 --- a/luxonis_train/config/predefined_models/README.md +++ b/luxonis_train/config/predefined_models/README.md @@ -10,6 +10,7 @@ models which can be used instead. - [`KeypointDetectionModel`](#keypointdetectionmodel) - [`ClassificationModel`](#classificationmodel) - [`FOMOModel`](#fomomodel) +- [`InstanceSegmentationModel`](#instancesegmentationmodel) **Parameters:** @@ -56,7 +57,7 @@ See an example configuration file using this predefined model [here](../../../co ## `DetectionModel` -The `DetectionModel` allows for both `"light"` and `"heavy"` variants, where the `"heavy"` variant is more accurate, and the `"light"` variant is faster. +The `DetectionModel` supports `"light"`, `"medium"`, and `"heavy"` variants, with `"light"` optimized for speed, `"heavy"` for accuracy, and `"medium"` offering a balance between the two. See an example configuration file using this predefined model [here](../../../configs/detection_light_model.yaml) for the `"light"` variant, and [here](../../../configs/detection_heavy_model.yaml) for the `"heavy"` variant. @@ -177,3 +178,34 @@ See an example configuration file using this predefined model [here](../../../co | `loss_params` | `dict` | `{}` | Additional parameters for the loss function. | | `visualizer_params` | `dict` | `{}` | Additional parameters for the visualizer. | | `task_name` | `str \| None` | `None` | Custom task name for the model head. | + +## `InstanceSegmentationModel` + +The `InstanceSegmentationModel` supports `"light"`, `"medium"`, and `"heavy"` variants, with `"light"` optimized for speed, `"heavy"` for accuracy, and `"medium"` offering a balance between the two. + +See an example configuration file using this predefined model [here](../../../configs/instance_segmentation_light_model.yaml) for the `"light"` variant, and [here](../../../configs/instance_segmentation_heavy_model.yaml) for the `"heavy"` variant. + +**Components:** + +| Name | Alias | Function | +| --------------------------------------------------------------------------------------------------------------- | ------------------------------------ | ---------------------------------------------------------------------------------------------------------------------------------------- | +| [`EfficientRep`](../../nodes/README.md#efficientrep) | `"instance_segmentation_backbone"` | Backbone of the model. Available variants: `"light"` (`EfficientRep-N`), `"medium"` (`EfficientRep-S`), and `"heavy"` (`EfficientRep-L`) | +| [`RepPANNeck`](../../nodes/README.md#reppanneck) | `"instance_segmentation_neck"` | Neck of the model | +| [`PrecisionSegmentBBoxHead`](../../nodes/README.md#precisionsegmentbboxhead) | `"instance_segmentation_head"` | Head of the model for instance segmentation | +| [`PrecisionDFLSegmentationLoss`](../../attached_modules/losses/README.md#precisiondflsegmentationloss) | `"instance_segmentation_loss"` | Loss function for training instance segmentation models | +| [`MeanAveragePrecision`](../../attached_modules/metrics/README.md#meanaverageprecision) | `"instance_segmentation_map"` | Main metric of the model, measuring mean average precision | +| [`InstanceSegmentationVisualizer`](../../attached_modules/visualizers/README.md#instancesegmentationvisualizer) | `"instance_segmentation_visualizer"` | Visualizer for displaying instance segmentation results | + +**Parameters:** + +| Key | Type | Default value | Description | +| ------------------- | ------------------------------------- | ---------------- | ------------------------------------------------------------------------------------------------------------------------------------ | +| `variant` | `Literal["light", "medium", "heavy"]` | `"light"` | Defines the variant of the model. `"light"` uses `EfficientRep-N`, `"medium"` uses `EfficientRep-S`, `"heavy"` uses `EfficientRep-L` | +| `use_neck` | `bool` | `True` | Whether to include the neck in the model | +| `backbone` | `str` | `"EfficientRep"` | Name of the node to be used as a backbone | +| `backbone_params` | `dict` | `{}` | Additional parameters to the backbone | +| `neck_params` | `dict` | `{}` | Additional parameters to the neck | +| `head_params` | `dict` | `{}` | Additional parameters to the head | +| `loss_params` | `dict` | `{}` | Additional parameters to the loss function | +| `visualizer_params` | `dict` | `{}` | Additional parameters to the visualizer | +| `task_name` | `str \| None` | `None` | Custom task name for the head | diff --git a/luxonis_train/config/predefined_models/__init__.py b/luxonis_train/config/predefined_models/__init__.py index a52db8bb..7bec15b0 100644 --- a/luxonis_train/config/predefined_models/__init__.py +++ b/luxonis_train/config/predefined_models/__init__.py @@ -3,6 +3,7 @@ from .classification_model import ClassificationModel from .detection_fomo_model import FOMOModel from .detection_model import DetectionModel +from .instance_segmentation_model import InstanceSegmentationModel from .keypoint_detection_model import KeypointDetectionModel from .segmentation_model import SegmentationModel @@ -14,4 +15,5 @@ "SegmentationModel", "AnomalyDetectionModel", "FOMOModel", + "InstanceSegmentationModel", ] diff --git a/luxonis_train/config/predefined_models/instance_segmentation_model.py b/luxonis_train/config/predefined_models/instance_segmentation_model.py new file mode 100644 index 00000000..28477572 --- /dev/null +++ b/luxonis_train/config/predefined_models/instance_segmentation_model.py @@ -0,0 +1,153 @@ +from typing import Literal, TypeAlias + +from pydantic import BaseModel + +from luxonis_train.config import ( + AttachedModuleConfig, + LossModuleConfig, + MetricModuleConfig, + ModelNodeConfig, + Params, +) + +from .base_predefined_model import BasePredefinedModel + +VariantLiteral: TypeAlias = Literal["light", "medium", "heavy"] + + +class DetectionVariant(BaseModel): + backbone: str + backbone_params: Params + neck_params: Params + + +def get_variant(variant: VariantLiteral) -> DetectionVariant: + """Returns the specific variant configuration for the + DetectionModel.""" + variants = { + "light": DetectionVariant( + backbone="EfficientRep", + backbone_params={"variant": "n"}, + neck_params={"variant": "n"}, + ), + "medium": DetectionVariant( + backbone="EfficientRep", + backbone_params={"variant": "s"}, + neck_params={"variant": "s"}, + ), + "heavy": DetectionVariant( + backbone="EfficientRep", + backbone_params={"variant": "l"}, + neck_params={"variant": "l"}, + ), + } + + if variant not in variants: + raise ValueError( + f"Detection variant should be one of {list(variants.keys())}, got '{variant}'." + ) + + return variants[variant] + + +class InstanceSegmentationModel(BasePredefinedModel): + def __init__( + self, + variant: VariantLiteral = "light", + use_neck: bool = True, + backbone: str | None = None, + backbone_params: Params | None = None, + neck_params: Params | None = None, + head_params: Params | None = None, + loss_params: Params | None = None, + visualizer_params: Params | None = None, + task_name: str | None = None, + ): + var_config = get_variant(variant) + + self.use_neck = use_neck + self.backbone_params = ( + backbone_params + if backbone is not None or backbone_params is not None + else var_config.backbone_params + ) or {} + self.backbone = backbone or var_config.backbone + self.neck_params = neck_params or var_config.neck_params + self.head_params = head_params or {} + self.loss_params = loss_params or {"n_warmup_epochs": 0} + self.visualizer_params = visualizer_params or {} + self.task_name = task_name or "instance_segmentation" + + @property + def nodes(self) -> list[ModelNodeConfig]: + """Defines the model nodes, including backbone, neck, and + head.""" + nodes = [ + ModelNodeConfig( + name=self.backbone, + alias=f"{self.backbone}-{self.task_name}", + freezing=self.backbone_params.pop("freezing", {}), + params=self.backbone_params, + ), + ] + if self.use_neck: + nodes.append( + ModelNodeConfig( + name="RepPANNeck", + alias=f"RepPANNeck-{self.task_name}", + inputs=[f"{self.backbone}-{self.task_name}"], + freezing=self.neck_params.pop("freezing", {}), + params=self.neck_params, + ) + ) + + nodes.append( + ModelNodeConfig( + name="PrecisionSegmentBBoxHead", + alias=f"PrecisionSegmentBBoxHead-{self.task_name}", + freezing=self.head_params.pop("freezing", {}), + inputs=[f"RepPANNeck-{self.task_name}"] + if self.use_neck + else [f"{self.backbone}-{self.task_name}"], + params=self.head_params, + task=self.task_name, + ) + ) + return nodes + + @property + def losses(self) -> list[LossModuleConfig]: + """Defines the loss module for the detection task.""" + return [ + LossModuleConfig( + name="PrecisionDFLSegmentationLoss", + alias=f"PrecisionDFLSegmentationLoss-{self.task_name}", + attached_to=f"PrecisionSegmentBBoxHead-{self.task_name}", + params=self.loss_params, + weight=1.0, + ) + ] + + @property + def metrics(self) -> list[MetricModuleConfig]: + """Defines the metrics used for evaluation.""" + return [ + MetricModuleConfig( + name="MeanAveragePrecision", + alias=f"MeanAveragePrecision-{self.task_name}", + attached_to=f"PrecisionSegmentBBoxHead-{self.task_name}", + is_main_metric=True, + ), + ] + + @property + def visualizers(self) -> list[AttachedModuleConfig]: + """Defines the visualizer used for the detection task.""" + return [ + AttachedModuleConfig( + name="InstanceSegmentationVisualizer", + alias=f"InstanceSegmentationVisualizer-{self.task_name}", + attached_to=f"PrecisionSegmentBBoxHead-{self.task_name}", + params=self.visualizer_params, + ) + ] From a28f0c75ad106ce316b9591572bfeba7cd40efb7 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Thu, 12 Dec 2024 15:33:12 +0100 Subject: [PATCH 09/31] fix: export --- luxonis_train/nodes/blocks/blocks.py | 41 ++++---- .../nodes/heads/precision_bbox_head.py | 98 +++++++++++++------ .../nodes/heads/precision_seg_bbox_head.py | 47 +++++---- 3 files changed, 111 insertions(+), 75 deletions(-) diff --git a/luxonis_train/nodes/blocks/blocks.py b/luxonis_train/nodes/blocks/blocks.py index 29a2fa9b..870d78b3 100644 --- a/luxonis_train/nodes/blocks/blocks.py +++ b/luxonis_train/nodes/blocks/blocks.py @@ -138,31 +138,26 @@ def forward(self, x): class DFL(nn.Module): - def __init__(self, channels: int = 16): - """ - Constructs the module with a convolutional layer using the specified input channels. - Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391 - - @type channels: int - @param channels: Number of input channels. Defaults to 16. - + def __init__(self, reg_max: int = 16): + """The DFL (Distribution Focal Loss) module processes input + tensors by applying softmax over a specified dimension and + projecting the resulting tensor to produce output logits. + + @type reg_max: int + @param reg_max: Maximum number of regression outputs. Defaults + to 16. """ super().__init__() - self.transform = nn.Conv2d( - channels, 1, kernel_size=1, bias=False - ).requires_grad_(False) - weights = torch.arange(channels, dtype=torch.float32) - self.transform.weight.data.copy_(weights.view(1, channels, 1, 1)) - self.num_channels = channels - - def forward(self, input: Tensor): - """Transforms the input tensor and returns the processed - output.""" - batch_size, _, anchors = input.size() - reshaped = input.view(batch_size, 4, self.num_channels, anchors) - softmaxed = reshaped.transpose(2, 1).softmax(dim=1) - processed = self.transform(softmaxed) - return processed.view(batch_size, 4, anchors) + self.proj_conv = nn.Conv2d(reg_max, 1, kernel_size=1, bias=False) + self.proj_conv.weight.data.copy_( + torch.arange(reg_max, dtype=torch.float32).view(1, reg_max, 1, 1) + ) + self.proj_conv.requires_grad_(False) + + def forward(self, x: Tensor) -> Tensor: + bs, _, h, w = x.size() + x = F.softmax(x.view(bs, 4, -1, h * w).permute(0, 2, 1, 3), dim=1) + return self.proj_conv(x)[:, 0].view(bs, 4, h, w) class ConvModule(nn.Sequential): diff --git a/luxonis_train/nodes/heads/precision_bbox_head.py b/luxonis_train/nodes/heads/precision_bbox_head.py index 27c2fb9f..0e466359 100644 --- a/luxonis_train/nodes/heads/precision_bbox_head.py +++ b/luxonis_train/nodes/heads/precision_bbox_head.py @@ -126,24 +126,38 @@ def __init__( self.bias_init() self.initialize_weights() - def forward(self, x: list[Tensor]) -> list[Tensor]: + def forward(self, x: list[Tensor]) -> tuple[list[Tensor], list[Tensor]]: + cls_outputs = [] + reg_outputs = [] for i in range(self.n_heads): reg_output = self.detection_heads[i][0](x[i]) cls_output = self.detection_heads[i][1](x[i]) - x[i] = torch.cat((reg_output, cls_output), 1) - return x + reg_outputs.append(reg_output) + cls_outputs.append(cls_output) + return reg_outputs, cls_outputs - def wrap(self, output: list[Tensor]) -> Packet[Tensor]: + def wrap( + self, output: tuple[list[Tensor], list[Tensor]] + ) -> Packet[Tensor]: + reg_outputs, cls_outputs = ( + output # ([bs, 4*reg_max, h_f, w_f]), ([bs, n_classes, h_f, w_f]) + ) + features = [ + torch.cat((reg, cls), dim=1) + for reg, cls in zip(reg_outputs, cls_outputs) + ] if self.training: return { - "features": output, + "features": features, } if self.export: - return {self.task: [self._export_bbox_output(output)]} + return { + self.task: self._prepare_bbox_export(reg_outputs, cls_outputs) + } boxes = non_max_suppression( - self._inference_bbox_output(output), + self._prepare_bbox_inference_output(reg_outputs, cls_outputs), n_classes=self.n_classes, conf_thres=self.conf_thres, iou_thres=self.iou_thres, @@ -153,7 +167,7 @@ def wrap(self, output: list[Tensor]) -> Packet[Tensor]: ) return { - "features": output, + "features": features, "boundingbox": boxes, } @@ -169,46 +183,68 @@ def _fit_stride_to_n_heads(self): ) return stride - def _extract_cls_and_box(self, x: list[Tensor]): + def _prepare_bbox_and_cls( + self, reg_outputs: list[Tensor], cls_outputs: list[Tensor] + ) -> list[Tensor]: """Extract classification and bounding box tensors.""" - shape = x[0].shape - x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2) - box, cls = x_cat.split((self.reg_max * 4, self.n_classes), 1) - return box, cls.sigmoid(), shape # Apply sigmoid to cls + output = [] + for i in range(self.n_heads): + box = self.dfl(reg_outputs[i]) + cls = cls_outputs[i].sigmoid() + conf = cls.max(1, keepdim=True)[0] + output.append( + torch.cat([box, conf, cls], dim=1) + ) # [bs, 4 + 1 + n_classes, h_f, w_f] + return output - def _export_bbox_output(self, x: list[Tensor]): + def _prepare_bbox_export( + self, reg_outputs: list[Tensor], cls_outputs: list[Tensor] + ) -> Tensor: """Prepare the output for export.""" - box, cls, _ = self._extract_cls_and_box(x) - box_dist = self.dfl(box) # Shape: [N, 4, N_anchors] - conf, _ = cls.max(1, keepdim=True) # Shape: [N, 1, N_anchors] - export_output = torch.cat( - [box_dist, conf, cls], dim=1 - ) # Shape: [N, 4 + 1 + num_classes, N_anchors] - return export_output - - def _inference_bbox_output(self, x: list[Tensor]): + return self._prepare_bbox_and_cls(reg_outputs, cls_outputs) + + def _prepare_bbox_inference_output( + self, reg_outputs: list[Tensor], cls_outputs: list[Tensor] + ): """Perform inference on predicted bounding boxes and class probabilities.""" - box, cls, shape = self._extract_cls_and_box(x) - box_dist = self.dfl(box) + processed_outputs = self._prepare_bbox_and_cls( + reg_outputs, cls_outputs + ) + box_dists = [] + class_probs = [] + for feature in processed_outputs: + bs, _, h, w = feature.size() + reshaped = feature.view(bs, -1, h * w) + box_dist = reshaped[:, :4, :] + cls = reshaped[:, 5:, :] + box_dists.append(box_dist) + class_probs.append(cls) + + box_dists = torch.cat(box_dists, dim=2) + class_probs = torch.cat(class_probs, dim=2) _, anchor_points, _, strides = anchors_for_fpn_features( - x, self.stride, 0.5 + processed_outputs, self.stride, 0.5 ) + pred_bboxes = dist2bbox( - box_dist, anchor_points.transpose(0, 1), out_format="xyxy", dim=1 + box_dists, anchor_points.transpose(0, 1), out_format="xyxy", dim=1 ) * strides.transpose(0, 1) + base_output = [ - pred_bboxes.permute(0, 2, 1), + pred_bboxes.permute(0, 2, 1), # [BS, H*W, 4] torch.ones( - (shape[0], pred_bboxes.shape[2], 1), + (box_dists.shape[0], pred_bboxes.shape[2], 1), dtype=pred_bboxes.dtype, device=pred_bboxes.device, ), - cls.permute(0, 2, 1), + class_probs.permute(0, 2, 1), # [BS, H*W, n_classes] ] - output_merged = torch.cat(base_output, dim=-1) + output_merged = torch.cat( + base_output, dim=-1 + ) # [BS, H*W, 4 + 1 + n_classes] return output_merged def bias_init(self): diff --git a/luxonis_train/nodes/heads/precision_seg_bbox_head.py b/luxonis_train/nodes/heads/precision_seg_bbox_head.py index 05b4a70b..56c95061 100644 --- a/luxonis_train/nodes/heads/precision_seg_bbox_head.py +++ b/luxonis_train/nodes/heads/precision_seg_bbox_head.py @@ -80,14 +80,10 @@ def forward( self, inputs: list[Tensor] ) -> tuple[list[Tensor], list[Tensor], list[Tensor], list[Tensor]]: prototypes = self.proto(inputs[0]) - bs = prototypes.shape[0] - mask_coefficients = torch.cat( - [ - self.mask_layers[i](inputs[i]).view(bs, self.n_masks, -1) - for i in range(self.n_heads) - ], - dim=2, - ) + mask_coefficients = [ + self.mask_layers[i](inputs[i]) for i in range(self.n_heads) + ] + det_outs = super().forward(inputs) return det_outs, prototypes, mask_coefficients @@ -96,25 +92,34 @@ def wrap( self, output: tuple[list[Tensor], Tensor, Tensor] ) -> Packet[Tensor]: det_feats, prototypes, mask_coefficients = output - if self.training: + + if self.export: + pred_bboxes = self._prepare_bbox_export(*det_feats) return { - "features": det_feats, + "boundingbox": pred_bboxes, + "masks": mask_coefficients, "prototypes": prototypes, - "mask_coeficients": mask_coefficients, } - if self.export: - pred_bboxes = self._export_bbox_output(det_feats) + det_feats_combined = [ + torch.cat((reg, cls), dim=1) for reg, cls in zip(*det_feats) + ] + mask_coefficients = torch.cat( + [ + coef.view(coef.size(0), self.n_masks, -1) + for coef in mask_coefficients + ], + dim=2, + ) + + if self.training: return { - TaskType.INSTANCE_SEGMENTATION: [ - torch.cat( - [pred_bboxes, mask_coefficients], 1 - ), # Shape: [N, 4 + 1 + num_classes + n_masks, N_anchors] - ], - "prototypes": [prototypes], # Shape: [N, n_masks, H, W] + "features": det_feats_combined, + "prototypes": prototypes, + "mask_coeficients": mask_coefficients, } - pred_bboxes = self._inference_bbox_output(det_feats) + pred_bboxes = self._prepare_bbox_inference_output(*det_feats) preds_combined = torch.cat( [pred_bboxes, mask_coefficients.permute(0, 2, 1)], dim=-1 ) @@ -129,7 +134,7 @@ def wrap( ) results = { - "features": det_feats, + "features": det_feats_combined, "prototypes": prototypes, "mask_coeficients": mask_coefficients, "boundingbox": [], From df89eef7382c21afdfeaadcdc044f27b6683ec38 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Sat, 11 Jan 2025 12:04:32 +0100 Subject: [PATCH 10/31] initial labels refactor support --- .../attached_modules/base_attached_module.py | 28 +++-- luxonis_train/config/config.py | 17 +-- .../predefined_models/classification_model.py | 2 +- .../predefined_models/detection_fomo_model.py | 2 +- .../predefined_models/detection_model.py | 4 +- .../keypoint_detection_model.py | 2 +- .../predefined_models/segmentation_model.py | 7 +- luxonis_train/core/core.py | 36 ++---- luxonis_train/enums.py | 1 + luxonis_train/loaders/base_loader.py | 3 - luxonis_train/loaders/luxonis_loader_torch.py | 43 +++---- luxonis_train/loaders/utils.py | 20 ++-- luxonis_train/models/luxonis_lightning.py | 29 ++--- luxonis_train/nodes/base_node.py | 112 ++++++------------ .../nodes/heads/ddrnet_segmentation_head.py | 2 +- .../nodes/heads/efficient_bbox_head.py | 4 +- luxonis_train/utils/dataset_metadata.py | 11 ++ luxonis_train/utils/types.py | 7 +- 18 files changed, 132 insertions(+), 198 deletions(-) diff --git a/luxonis_train/attached_modules/base_attached_module.py b/luxonis_train/attached_modules/base_attached_module.py index c84d8575..99461c96 100644 --- a/luxonis_train/attached_modules/base_attached_module.py +++ b/luxonis_train/attached_modules/base_attached_module.py @@ -73,13 +73,13 @@ def __init__(self, *, node: BaseNode | None = None): for label in self.supported_tasks ] module_supported = f"[{', '.join(module_supported)}]" - if not self.node.tasks: + if not self.node.task_types: raise IncompatibleException( f"Module '{self.name}' requires one of the following " f"labels or combinations of labels: {module_supported}, " f"but is connected to node '{self.node.name}' which does not specify any tasks." ) - node_tasks = set(self.node.tasks) + node_tasks = set(self.node.task_types) for required_labels in self.supported_tasks: if isinstance(required_labels, TaskType): required_labels = [required_labels] @@ -89,7 +89,7 @@ def __init__(self, *, node: BaseNode | None = None): self.required_labels = required_labels break else: - node_supported = [task.value for task in self.node.tasks] + node_supported = [task.value for task in self.node.task_types] raise IncompatibleException( f"Module '{self.name}' requires one of the following labels or combinations of labels: {module_supported}, " f"but is connected to node '{self.node.name}' which does not support any of them. " @@ -159,18 +159,18 @@ def class_names(self) -> list[str]: return self.node.class_names @property - def node_tasks(self) -> dict[TaskType, str]: + def node_tasks(self) -> list[TaskType]: """Getter for the tasks of the attached node. @type: dict[TaskType, str] @raises RuntimeError: If the node does not have the C{tasks} attribute set. """ - if self.node._tasks is None: + if self.node.task_types is None: raise RuntimeError( "Node must have the `tasks` attribute specified." ) - return self.node._tasks + return self.node.task_types def get_label( self, labels: Labels, task_type: TaskType | None = None @@ -210,13 +210,14 @@ def _get_label( if len(self.required_labels) == 1: task_type = self.required_labels[0] - if task_type is not None: - task_name = self.node.get_task_name(task_type) - if task_name not in labels: + if task_type is not None and self.node.task_name is not None: + task_name = self.node.task_name + task = f"{task_name}/{task_type.value}" + if task not in labels: raise IncompatibleException.from_missing_task( task_type.value, list(labels.keys()), self.name ) - return labels[task_name] + return labels[task], task_type raise ValueError( f"{self.name} requires multiple labels. You must provide the " @@ -260,7 +261,7 @@ def get_input_tensors( f"Task {task_type.value} is not supported by the node " f"{self.node.name}." ) - return inputs[self.node_tasks[task_type]] + return inputs[f"{self.node.task_name}/{task_type.value}"] else: if task_type not in inputs: raise IncompatibleException( @@ -273,7 +274,8 @@ def get_input_tensors( f"{self.name} requires multiple labels, " "you must provide the `task_type` argument to extract the desired input." ) - return inputs[self.node_tasks[self.required_labels[0]]] + task_type = self.node_tasks[0].value + return inputs[f"{self.node.task_name}/{task_type}"] def prepare( self, inputs: Packet[Tensor], labels: Labels | None @@ -305,7 +307,7 @@ def prepare( @raises RuntimeError: If the C{tasks} attribute is not set on the node. @raises RuntimeError: If the C{supported_tasks} attribute is not set on the module. """ - if self.node._tasks is None: + if self.node.task_types is None: raise RuntimeError( f"{self.node.name} must have the `tasks` attribute specified " f"for {self.name} to make use of the default `prepare` method." diff --git a/luxonis_train/config/config.py b/luxonis_train/config/config.py index fcbbf24a..f2a49f85 100644 --- a/luxonis_train/config/config.py +++ b/luxonis_train/config/config.py @@ -1,7 +1,7 @@ import logging import sys import warnings -from typing import Annotated, Any, Literal, TypeAlias +from typing import Annotated, Any, Literal, NamedTuple, TypeAlias from luxonis_ml.enums import DatasetType from luxonis_ml.utils import ( @@ -19,13 +19,16 @@ ) from typing_extensions import Self -from luxonis_train.enums import TaskType - logger = logging.getLogger(__name__) Params: TypeAlias = dict[str, Any] +class ImageSize(NamedTuple): + height: int + width: int + + class AttachedModuleConfig(BaseModelExtraForbid): name: str attached_to: str @@ -62,7 +65,7 @@ class ModelNodeConfig(BaseModelExtraForbid): input_sources: list[str] = [] # From data loader freezing: FreezingConfig = FreezingConfig() remove_on_export: bool = False - task: str | dict[TaskType, str] | None = None + task_name: str | None = None params: Params = {} @@ -303,10 +306,10 @@ class AugmentationConfig(BaseModelExtraForbid): class PreprocessingConfig(BaseModelExtraForbid): train_image_size: Annotated[ - list[int], Field(default=[256, 256], min_length=2, max_length=2) - ] = [256, 256] + ImageSize, Field(default=[256, 256], min_length=2, max_length=2) + ] = ImageSize(256, 256) keep_aspect_ratio: bool = True - train_rgb: bool = True + color_format: Literal["RGB", "BGR"] = "RGB" normalize: NormalizeAugmentationConfig = NormalizeAugmentationConfig() augmentations: list[AugmentationConfig] = [] diff --git a/luxonis_train/config/predefined_models/classification_model.py b/luxonis_train/config/predefined_models/classification_model.py index e028bba5..0d749b5a 100644 --- a/luxonis_train/config/predefined_models/classification_model.py +++ b/luxonis_train/config/predefined_models/classification_model.py @@ -84,7 +84,7 @@ def nodes(self) -> list[ModelNodeConfig]: inputs=[f"{self.backbone}-{self.task_name}"], freezing=self.head_params.pop("freezing", {}), params=self.head_params, - task=self.task_name, + task_name=self.task_name, ), ] diff --git a/luxonis_train/config/predefined_models/detection_fomo_model.py b/luxonis_train/config/predefined_models/detection_fomo_model.py index d9702ece..1a21a5a3 100644 --- a/luxonis_train/config/predefined_models/detection_fomo_model.py +++ b/luxonis_train/config/predefined_models/detection_fomo_model.py @@ -80,7 +80,7 @@ def nodes(self) -> list[ModelNodeConfig]: alias=f"FOMOHead-{self.kpt_task_name}", inputs=[f"{self.backbone}-{self.kpt_task_name}"], params=self.head_params, - task={ + task_name={ TaskType.BOUNDINGBOX: self.bbox_task_name, TaskType.KEYPOINTS: self.kpt_task_name, }, diff --git a/luxonis_train/config/predefined_models/detection_model.py b/luxonis_train/config/predefined_models/detection_model.py index dbbc8886..d1498845 100644 --- a/luxonis_train/config/predefined_models/detection_model.py +++ b/luxonis_train/config/predefined_models/detection_model.py @@ -80,7 +80,7 @@ def __init__( self.head_params = head_params or var_config.head_params self.loss_params = loss_params or {"n_warmup_epochs": 0} self.visualizer_params = visualizer_params or {} - self.task_name = task_name or "boundingbox" + self.task_name = task_name @property def nodes(self) -> list[ModelNodeConfig]: @@ -114,7 +114,7 @@ def nodes(self) -> list[ModelNodeConfig]: if self.use_neck else [f"{self.backbone}-{self.task_name}"], params=self.head_params, - task=self.task_name, + task_name=self.task_name, ) ) return nodes diff --git a/luxonis_train/config/predefined_models/keypoint_detection_model.py b/luxonis_train/config/predefined_models/keypoint_detection_model.py index 51d790a7..8882f338 100644 --- a/luxonis_train/config/predefined_models/keypoint_detection_model.py +++ b/luxonis_train/config/predefined_models/keypoint_detection_model.py @@ -122,7 +122,7 @@ def nodes(self) -> list[ModelNodeConfig]: ), freezing=self.head_params.pop("freezing", {}), params=self.head_params, - task=task, + task_name=task, ) ) return nodes diff --git a/luxonis_train/config/predefined_models/segmentation_model.py b/luxonis_train/config/predefined_models/segmentation_model.py index eff4fd02..0260e843 100644 --- a/luxonis_train/config/predefined_models/segmentation_model.py +++ b/luxonis_train/config/predefined_models/segmentation_model.py @@ -71,7 +71,7 @@ def __init__( self.loss_params = loss_params or {} self.visualizer_params = visualizer_params or {} self.task = task - self.task_name = task_name or "segmentation" + self.task_name = task_name @property def nodes(self) -> list[ModelNodeConfig]: @@ -85,6 +85,7 @@ def nodes(self) -> list[ModelNodeConfig]: alias=f"{self.backbone}-{self.task_name}", freezing=self.backbone_params.pop("freezing", {}), params=self.backbone_params, + task_name=self.task_name, ), ModelNodeConfig( name="DDRNetSegmentationHead", @@ -92,7 +93,7 @@ def nodes(self) -> list[ModelNodeConfig]: inputs=[f"{self.backbone}-{self.task_name}"], freezing=self.head_params.pop("freezing", {}), params=self.head_params, - task=self.task_name, + task_name=self.task_name, ), ] if self.backbone_params.get("use_aux_heads", True): @@ -103,7 +104,7 @@ def nodes(self) -> list[ModelNodeConfig]: inputs=[f"{self.backbone}-{self.task_name}"], freezing=self.aux_head_params.pop("freezing", {}), params=self.aux_head_params, - task=self.task_name, + task_name=self.task_name, remove_on_export=self.aux_head_params.pop( "remove_on_export", True ), diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index 03ff5189..fe959349 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -13,7 +13,6 @@ import torch.utils.data as torch_data import yaml from lightning.pytorch.utilities import rank_zero_only -from luxonis_ml.data import Augmentations from luxonis_ml.nn_archive import ArchiveGenerator from luxonis_ml.nn_archive.config import CONFIG_VERSION from luxonis_ml.utils import LuxonisFileSystem, reset_logging, setup_logging @@ -113,26 +112,6 @@ def __init__( precision=self.cfg.trainer.precision, ) - self.train_augmentations = Augmentations( - image_size=self.cfg.trainer.preprocessing.train_image_size, - augmentations=[ - i.model_dump() - for i in self.cfg.trainer.preprocessing.get_active_augmentations() - ], - train_rgb=self.cfg.trainer.preprocessing.train_rgb, - keep_aspect_ratio=self.cfg.trainer.preprocessing.keep_aspect_ratio, - ) - self.val_augmentations = Augmentations( - image_size=self.cfg.trainer.preprocessing.train_image_size, - augmentations=[ - i.model_dump() - for i in self.cfg.trainer.preprocessing.get_active_augmentations() - ], - train_rgb=self.cfg.trainer.preprocessing.train_rgb, - keep_aspect_ratio=self.cfg.trainer.preprocessing.keep_aspect_ratio, - only_normalize=True, - ) - self.loaders: dict[str, BaseLoaderTorch] = {} for view in ["train", "val", "test"]: loader_name = self.cfg.loader.name @@ -141,17 +120,20 @@ def __init__( self.cfg.loader.params["delete_existing"] = False self.loaders[view] = Loader( - augmentations=( - self.train_augmentations - if view == "train" - else self.val_augmentations - ), view={ "train": self.cfg.loader.train_view, "val": self.cfg.loader.val_view, "test": self.cfg.loader.test_view, }[view], image_source=self.cfg.loader.image_source, + height=self.cfg.trainer.preprocessing.train_image_size.height, + width=self.cfg.trainer.preprocessing.train_image_size.width, + augmentation_config=[ + i.model_dump() + for i in self.cfg.trainer.preprocessing.get_active_augmentations() + ], + out_image_format=self.cfg.trainer.preprocessing.color_format, + keep_aspect_ratio=self.cfg.trainer.preprocessing.keep_aspect_ratio, **self.cfg.loader.params, ) @@ -739,7 +721,7 @@ def _mult(lst: list[float | int]) -> list[float]: self.cfg.trainer.preprocessing.normalize.params["std"] ), "dai_type": "RGB888p" - if self.cfg.trainer.preprocessing.train_rgb + if self.cfg.trainer.preprocessing.out_image_format else "BGR888p", } diff --git a/luxonis_train/enums.py b/luxonis_train/enums.py index b024d6a9..ea719e1c 100644 --- a/luxonis_train/enums.py +++ b/luxonis_train/enums.py @@ -6,6 +6,7 @@ class TaskType(str, Enum): CLASSIFICATION = "classification" SEGMENTATION = "segmentation" + INSTANCE_SEGMENTATION = "instance_segmentation" BOUNDINGBOX = "boundingbox" KEYPOINTS = "keypoints" LABEL = "label" diff --git a/luxonis_train/loaders/base_loader.py b/luxonis_train/loaders/base_loader.py index 0c056d98..3e3589dd 100644 --- a/luxonis_train/loaders/base_loader.py +++ b/luxonis_train/loaders/base_loader.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod -from luxonis_ml.data import Augmentations from luxonis_ml.utils.registry import AutoRegisterMeta from torch import Size from torch.utils.data import Dataset @@ -23,11 +22,9 @@ class BaseLoaderTorch( def __init__( self, view: str | list[str], - augmentations: Augmentations | None = None, image_source: str | None = None, ): self.view = view if isinstance(view, list) else [view] - self.augmentations = augmentations self._image_source = image_source @property diff --git a/luxonis_train/loaders/luxonis_loader_torch.py b/luxonis_train/loaders/luxonis_loader_torch.py index 230128b5..445fe641 100644 --- a/luxonis_train/loaders/luxonis_loader_torch.py +++ b/luxonis_train/loaders/luxonis_loader_torch.py @@ -1,9 +1,9 @@ import logging -from typing import Literal +from pathlib import Path +from typing import Any, Literal import numpy as np from luxonis_ml.data import ( - Augmentations, BucketStorage, BucketType, LuxonisDataset, @@ -14,8 +14,6 @@ from torch import Size, Tensor from typeguard import typechecked -from luxonis_train.enums import TaskType - from .base_loader import BaseLoaderTorch, LuxonisLoaderTorchOutput logger = logging.getLogger(__name__) @@ -31,10 +29,15 @@ def __init__( team_id: str | None = None, bucket_type: Literal["internal", "external"] = "internal", bucket_storage: Literal["local", "s3", "gcs", "azure"] = "local", - stream: bool = False, delete_existing: bool = True, view: str | list[str] = "train", - augmentations: Augmentations | None = None, + augmentation_engine: str + | Literal["albumentations"] = "albumentations", + augmentation_config: Path | str | list[dict[str, Any]] | None = None, + height: int | None = None, + width: int | None = None, + keep_aspect_ratio: bool = True, + out_image_format: Literal["RGB", "BGR"] = "RGB", **kwargs, ): """Torch-compatible loader for Luxonis datasets. @@ -61,8 +64,6 @@ def __init__( Defaults to 'internal'. @type bucket_storage: Literal["local", "s3", "gcs", "azure"] @param bucket_storage: Type of the bucket storage. Defaults to 'local'. - @type stream: bool - @param stream: Flag for data streaming. Defaults to C{False}. @type delete_existing: bool @param delete_existing: Only relevant when C{dataset_dir} is provided. By default, the dataset is parsed again every time the loader is created @@ -74,10 +75,8 @@ def __init__( view of the dataset. Each split is a string that represents a subset of the dataset. The available splits depend on the dataset, but usually include 'train', 'val', and 'test'. Defaults to 'train'. - @type augmentations: Augmentations | None - @param augmentations: Augmentations to apply to the data. Defaults to C{None}. """ - super().__init__(view=view, augmentations=augmentations, **kwargs) + super().__init__(view=view, **kwargs) if dataset_dir is not None: self.dataset = self._parse_dataset( dataset_dir, dataset_name, dataset_type, delete_existing @@ -93,15 +92,19 @@ def __init__( bucket_type=BucketType(bucket_type), bucket_storage=BucketStorage(bucket_storage), ) - self.base_loader = LuxonisLoader( + self.loader = LuxonisLoader( dataset=self.dataset, - view=self.view, - stream=stream, - augmentations=self.augmentations, + view=view, + augmentation_engine=augmentation_engine, + augmentation_config=augmentation_config, + height=height, + width=width, + keep_aspect_ratio=keep_aspect_ratio, + out_image_format=out_image_format, ) def __len__(self) -> int: - return len(self.base_loader) + return len(self.loader) @property def input_shapes(self) -> dict[str, Size]: @@ -109,13 +112,13 @@ def input_shapes(self) -> dict[str, Size]: return {self.image_source: img.shape} def __getitem__(self, idx: int) -> LuxonisLoaderTorchOutput: - img, labels = self.base_loader[idx] + img, labels = self.loader[idx] img = np.transpose(img, (2, 0, 1)) # HWC to CHW tensor_img = Tensor(img) - tensor_labels: dict[str, tuple[Tensor, TaskType]] = {} - for task, (array, label_type) in labels.items(): - tensor_labels[task] = (Tensor(array), TaskType(label_type.value)) + tensor_labels: dict[str, Tensor] = {} + for task, array in labels.items(): + tensor_labels[task] = Tensor(array) return {self.image_source: tensor_img}, tensor_labels diff --git a/luxonis_train/loaders/utils.py b/luxonis_train/loaders/utils.py index b030e218..10b4d17a 100644 --- a/luxonis_train/loaders/utils.py +++ b/luxonis_train/loaders/utils.py @@ -1,7 +1,7 @@ import torch +from luxonis_ml.data.utils import get_task_type from torch import Tensor -from luxonis_train.enums import TaskType from luxonis_train.utils.types import Labels LuxonisLoaderTorchOutput = tuple[dict[str, Tensor], Labels] @@ -32,22 +32,18 @@ def collate_fn( out_labels: Labels = {} for task in labels[0].keys(): - task_type = labels[0][task][1] - annos = [label[task][0] for label in labels] - if task_type in [ - TaskType.CLASSIFICATION, - TaskType.SEGMENTATION, - TaskType.ARRAY, - ]: - out_labels[task] = torch.stack(annos, 0), task_type - - elif task_type in [TaskType.KEYPOINTS, TaskType.BOUNDINGBOX]: + task_type = get_task_type(task) + annos = [label[task] for label in labels] + + if task_type in {"keypoints", "boundingbox"}: label_box: list[Tensor] = [] for i, box in enumerate(annos): l_box = torch.zeros((box.shape[0], box.shape[1] + 1)) l_box[:, 0] = i # add target image index for build_targets() l_box[:, 1:] = box label_box.append(l_box) - out_labels[task] = torch.cat(label_box, 0), task_type + out_labels[task] = torch.cat(label_box, 0) + else: + out_labels[task] = torch.stack(annos, 0) return out_inputs, out_labels diff --git a/luxonis_train/models/luxonis_lightning.py b/luxonis_train/models/luxonis_lightning.py index dae683ce..cebce43f 100644 --- a/luxonis_train/models/luxonis_lightning.py +++ b/luxonis_train/models/luxonis_lightning.py @@ -181,32 +181,17 @@ def __init__( ) frozen_nodes.append((node_name, unfreeze_after)) - if node_cfg.task is not None: - if Node.tasks is None: - raise ValueError( - f"Cannot define tasks for node {node_name}." - "This node doesn't specify any tasks." - ) - if isinstance(node_cfg.task, str): - assert Node.tasks - if len(Node.tasks) > 1: - raise ValueError( - f"Node {node_name} specifies multiple tasks, " - "but only one task is specified in the config. " - "Specify the tasks as a dictionary instead." - ) + if node_cfg.task_name is not None and Node.task_types is None: + raise ValueError( + f"Cannot define tasks for node {node_name}." + "This node doesn't specify any tasks." + ) - node_cfg.task = {next(iter(Node.tasks)): node_cfg.task} - else: - node_cfg.task = { - **Node._process_tasks(Node.tasks), - **node_cfg.task, - } nodes[node_name] = ( Node, { **node_cfg.params, - "_tasks": node_cfg.task, + "task_name": node_cfg.task_name, "remove_on_export": node_cfg.remove_on_export, }, ) @@ -1000,7 +985,7 @@ def _init_attached_module( loader = self._core.loaders["train"] dataset = getattr(loader, "dataset", None) if isinstance(dataset, LuxonisDataset): - n_classes = len(dataset.get_classes()[1][node.task]) + n_classes = len(dataset.get_classes()[1][node.task_name]) if n_classes == 1: cfg.params["task"] = "binary" else: diff --git a/luxonis_train/nodes/base_node.py b/luxonis_train/nodes/base_node.py index 748742dd..eff2c2d1 100644 --- a/luxonis_train/nodes/base_node.py +++ b/luxonis_train/nodes/base_node.py @@ -109,7 +109,7 @@ def wrap(output: Tensor) -> Packet[Tensor]: """ attach_index: AttachIndexType - tasks: list[TaskType] | dict[TaskType, str] | None = None + task_types: list[TaskType] | None = None def __init__( self, @@ -123,7 +123,7 @@ def __init__( remove_on_export: bool = False, export_output_names: list[str] | None = None, attach_index: AttachIndexType | None = None, - _tasks: dict[TaskType, str] | None = None, + task_name: str | None = None, ): """Constructor for the C{BaseNode}. @@ -168,11 +168,16 @@ def __init__( "Make sure this is intended." ) self.attach_index = attach_index - self._tasks = None - if _tasks is not None: - self._tasks = _tasks - elif self.tasks is not None: - self._tasks = self._process_tasks(self.tasks) + + self.task_name = task_name + if task_name is None and dataset_metadata is not None: + if len(dataset_metadata.task_names) == 1: + self.task_name = next(iter(dataset_metadata.task_names)) + else: + raise ValueError( + f"Dataset contain multiple tasks, but the `task_name` " + f"argument for node '{self.name}' was not provided." + ) if getattr(self, "attach_index", None) is None: parameters = inspect.signature(self.forward).parameters @@ -200,15 +205,6 @@ def __init__( self._check_type_overrides() - @staticmethod - def _process_tasks( - tasks: dict[TaskType, str] | list[TaskType], - ) -> dict[TaskType, str]: - if isinstance(tasks, dict): - return tasks - else: - return {task: task.value for task in tasks} - def _check_type_overrides(self) -> None: properties = [] for name, value in inspect.getmembers(self.__class__): @@ -228,67 +224,28 @@ def _check_type_overrides(self) -> None: "not compatible with its predecessor." ) from e - def get_task_name(self, task: TaskType) -> str: - """Gets the name of a task for a particular C{TaskType}. - - @type task: TaskType - @param task: Task to get the name for. - @rtype: str - @return: Name of the task. - @raises RuntimeError: If the node does not define any tasks. - @raises ValueError: If the task is not supported by the node. - """ - if not self._tasks: - raise RuntimeError(f"Node '{self.name}' does not define any task.") - - if task not in self._tasks: - raise ValueError( - f"Node '{self.name}' does not support the '{task.value}' task." - ) - return self._tasks[task] - @property def name(self) -> str: return self.__class__.__name__ @property - def task(self) -> str: - """Getter for the task. + def task_type(self) -> str: + """Getter for the task type. @type: str @raises RuntimeError: If the node doesn't define any task. @raises ValueError: If the node defines more than one task. In that case, use the L{get_task_name} method instead. """ - if not self._tasks: + if not self.task_types: raise RuntimeError(f"{self.name} does not define any task.") - if len(self._tasks) > 1: + if len(self.task_types) > 1: raise ValueError( f"Node {self.name} has multiple tasks defined. " "Use the `get_task_name` method instead." ) - return next(iter(self._tasks.values())) - - def get_n_classes(self, task: TaskType) -> int: - """Gets the number of classes for a particular task. - - @type task: TaskType - @param task: Task to get the number of classes for. - @rtype: int - @return: Number of classes for the task. - """ - return self.dataset_metadata.n_classes(self.get_task_name(task)) - - def get_class_names(self, task: TaskType) -> list[str]: - """Gets the class names for a particular task. - - @type task: TaskType - @param task: Task to get the class names for. - @rtype: list[str] - @return: Class names for the task. - """ - return self.dataset_metadata.classes(self.get_task_name(task)) + return self.task_types[0].value @property def n_keypoints(self) -> int: @@ -301,12 +258,10 @@ def n_keypoints(self) -> int: if self._n_keypoints is not None: return self._n_keypoints - if self._tasks: - if TaskType.KEYPOINTS not in self._tasks: + if self.task_types: + if TaskType.KEYPOINTS not in self.task_types: raise ValueError(f"{self.name} does not support keypoints.") - return self.dataset_metadata.n_keypoints( - self.get_task_name(TaskType.KEYPOINTS) - ) + return self.dataset_metadata.n_keypoints(self.task_name) raise RuntimeError( f"{self.name} does not have any tasks defined, " @@ -329,7 +284,7 @@ def n_classes(self) -> int: if self._n_classes is not None: return self._n_classes - if not self._tasks: + if not self.task_types: raise RuntimeError( f"{self.name} does not have any tasks defined, " "`BaseNode.n_classes` property cannot be used. " @@ -337,12 +292,12 @@ def n_classes(self) -> int: "pass the `n_classes` attribute to the constructor or call " "the `BaseNode.dataset_metadata.n_classes` method manually." ) - elif len(self._tasks) == 1: - return self.dataset_metadata.n_classes(self.task) + elif len(self.task_types) == 1: + return self.dataset_metadata.n_classes(self.task_name) else: n_classes = [ - self.dataset_metadata.n_classes(self.get_task_name(task)) - for task in self._tasks + self.dataset_metadata.n_classes(self.task_name) + for task in self.task_types ] if len(set(n_classes)) == 1: return n_classes[0] @@ -362,7 +317,7 @@ def class_names(self) -> list[str]: different tasks. In that case, use the L{get_class_names} method. """ - if not self._tasks: + if not self.task_types: raise RuntimeError( f"{self.name} does not have any tasks defined, " "`BaseNode.class_names` property cannot be used. " @@ -370,12 +325,12 @@ def class_names(self) -> list[str]: "pass the `n_classes` attribute to the constructor or call " "the `BaseNode.dataset_metadata.class_names` method manually." ) - elif len(self._tasks) == 1: - return self.dataset_metadata.classes(self.task) + elif len(self.task_types) == 1: + return self.dataset_metadata.classes(self.task_name) else: class_names = [ - self.dataset_metadata.classes(self.get_task_name(task)) - for task in self._tasks + self.dataset_metadata.classes(self.task_name) + for task in self.task_types ] if all(set(names) == set(class_names[0]) for names in class_names): return class_names[0] @@ -633,7 +588,7 @@ def wrap(self, output: ForwardOutputT) -> Packet[Tensor]: "Default `wrap` expects a single tensor or a list of tensors." ) try: - task = self.task + task = f"{self.task_name}/{self.task_type}" except RuntimeError: task = "features" return {task: outputs} @@ -654,11 +609,12 @@ def run(self, inputs: list[Packet[Tensor]]) -> Packet[Tensor]: unwrapped = self.unwrap(inputs) outputs = self(unwrapped) wrapped = self.wrap(outputs) - str_tasks = [task.value for task in self._tasks] if self._tasks else [] + str_tasks = [task.value for task in self.task_types or []] for key in list(wrapped.keys()): if key in str_tasks: + assert self.task_name is not None value = wrapped.pop(key) - wrapped[self.get_task_name(TaskType(key))] = value + wrapped[f"{self.task_name}/{key}"] = value return wrapped T = TypeVar("T", Tensor, Size) diff --git a/luxonis_train/nodes/heads/ddrnet_segmentation_head.py b/luxonis_train/nodes/heads/ddrnet_segmentation_head.py index 2b313ab6..35c5e69c 100644 --- a/luxonis_train/nodes/heads/ddrnet_segmentation_head.py +++ b/luxonis_train/nodes/heads/ddrnet_segmentation_head.py @@ -18,7 +18,7 @@ class DDRNetSegmentationHead(BaseHead[Tensor, Tensor]): in_width: int in_channels: int - tasks: list[TaskType] = [TaskType.SEGMENTATION] + task_types: list[TaskType] = [TaskType.SEGMENTATION] parser: str = "SegmentationParser" def __init__( diff --git a/luxonis_train/nodes/heads/efficient_bbox_head.py b/luxonis_train/nodes/heads/efficient_bbox_head.py index 76eb2e5a..2eb57dfb 100644 --- a/luxonis_train/nodes/heads/efficient_bbox_head.py +++ b/luxonis_train/nodes/heads/efficient_bbox_head.py @@ -21,7 +21,7 @@ class EfficientBBoxHead( BaseHead[list[Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]], ): in_channels: list[int] - tasks: list[TaskType] = [TaskType.BOUNDINGBOX] + task_types: list[TaskType] = [TaskType.BOUNDINGBOX] parser = "YOLO" def __init__( @@ -171,7 +171,7 @@ def wrap( conf, _ = out_cls.max(1, keepdim=True) out = torch.cat([out_reg, conf, out_cls], dim=1) outputs.append(out) - return {self.task: outputs} + return {self.task_type: outputs} cls_tensor = torch.cat( [cls_score_list[i].flatten(2) for i in range(len(cls_score_list))], diff --git a/luxonis_train/utils/dataset_metadata.py b/luxonis_train/utils/dataset_metadata.py index 3a9cecdf..fdbec775 100644 --- a/luxonis_train/utils/dataset_metadata.py +++ b/luxonis_train/utils/dataset_metadata.py @@ -1,3 +1,5 @@ +from typing import Set + from luxonis_train.loaders import BaseLoaderTorch @@ -28,6 +30,15 @@ def __init__( self._n_keypoints = n_keypoints or {} self._loader = loader + @property + def task_names(self) -> Set[str]: + """Gets the names of the tasks present in the dataset. + + @rtype: set[str] + @return: Names of the tasks present in the dataset. + """ + return set(self._classes.keys()) + def n_classes(self, task: str | None = None) -> int: """Gets the number of classes for the specified task. diff --git a/luxonis_train/utils/types.py b/luxonis_train/utils/types.py index 8666751b..f1d8e6b5 100644 --- a/luxonis_train/utils/types.py +++ b/luxonis_train/utils/types.py @@ -2,14 +2,11 @@ from torch import Size, Tensor -from luxonis_train.enums import TaskType - Kwargs = dict[str, Any] """Kwargs is a dictionary containing keyword arguments.""" -Labels = dict[str, tuple[Tensor, TaskType]] -"""Labels is a dictionary containing a tuple of tensors and their -corresponding task type.""" +Labels = dict[str, Tensor] +"""Labels is a dictionary mapping task names to tensors.""" AttachIndexType = Literal["all"] | int | tuple[int, int] | tuple[int, int, int] """AttachIndexType is used to specify to which output of the prevoius From d01816bc5b551f0408036a38890f7c0360e20f98 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 14 Jan 2025 07:14:04 -0600 Subject: [PATCH 11/31] updated docs --- configs/README.md | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/configs/README.md b/configs/README.md index 69d77243..37d150f9 100644 --- a/configs/README.md +++ b/configs/README.md @@ -280,14 +280,14 @@ We use [`Albumentations`](https://albumentations.ai/docs/) library for `augmenta Additionally, we support `Mosaic4` and `MixUp` batch augmentations and letterbox resizing if `keep_aspect_ratio: true`. -| Key | Type | Default value | Description | -| ------------------- | ------------ | ------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `train_image_size` | `list[int]` | `[256, 256]` | Image size used for training as `[height, width]` | -| `keep_aspect_ratio` | `bool` | `True` | Whether to keep the aspect ratio while resizing | -| `train_rgb` | `bool` | `True` | Whether to train on RGB or BGR images | -| `normalize.active` | `bool` | `True` | Whether to use normalization | -| `normalize.params` | `dict` | `{}` | Parameters for normalization, see [Normalize](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Normalize) | -| `augmentations` | `list[dict]` | `[]` | List of `Albumentations` augmentations | +| Key | Type | Default value | Description | +| ------------------- | ----------------------- | ------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `train_image_size` | `list[int]` | `[256, 256]` | Image size used for training as `[height, width]` | +| `keep_aspect_ratio` | `bool` | `True` | Whether to keep the aspect ratio while resizing | +| `color_format` | `Literal["RGB", "BGR"]` | `"RGB"` | Whether to train on RGB or BGR images | +| `normalize.active` | `bool` | `True` | Whether to use normalization | +| `normalize.params` | `dict` | `{}` | Parameters for normalization, see [Normalize](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Normalize) | +| `augmentations` | `list[dict]` | `[]` | List of `Albumentations` augmentations | #### Augmentations @@ -306,7 +306,7 @@ trainer: # using YAML capture to reuse the image size train_image_size: [&height 384, &width 384] keep_aspect_ratio: true - train_rgb: true + color_format: "RGB" normalize: active: true augmentations: @@ -418,7 +418,7 @@ Each training strategy is a dictionary with the following fields: ```yaml training_strategy: name: "TripleLRSGDStrategy" - params: + params: warmup_epochs: 3 warmup_bias_lr: 0.1 warmup_momentum: 0.8 From e34e893b83b5e03b2fbead9b1eb4e45b8d564cd1 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 14 Jan 2025 07:43:38 -0600 Subject: [PATCH 12/31] updated predefined models --- configs/complex_model.yaml | 2 +- luxonis_train/config/config.py | 39 ++++++++++++------- .../anomaly_detection_model.py | 19 ++++----- .../predefined_models/classification_model.py | 23 +++++------ .../predefined_models/detection_fomo_model.py | 29 +++++--------- .../predefined_models/detection_model.py | 23 +++++------ .../keypoint_detection_model.py | 38 +++++++----------- .../predefined_models/segmentation_model.py | 35 ++++++----------- 8 files changed, 85 insertions(+), 123 deletions(-) diff --git a/configs/complex_model.yaml b/configs/complex_model.yaml index 149530ad..fed25c23 100644 --- a/configs/complex_model.yaml +++ b/configs/complex_model.yaml @@ -100,7 +100,7 @@ trainer: preprocessing: train_image_size: [&height 384, &width 384] keep_aspect_ratio: true - train_rgb: true + color_format: RGB normalize: active: true augmentations: diff --git a/luxonis_train/config/config.py b/luxonis_train/config/config.py index f2a49f85..014caaf2 100644 --- a/luxonis_train/config/config.py +++ b/luxonis_train/config/config.py @@ -65,7 +65,7 @@ class ModelNodeConfig(BaseModelExtraForbid): input_sources: list[str] = [] # From data loader freezing: FreezingConfig = FreezingConfig() remove_on_export: bool = False - task_name: str | None = None + task_name: str = "" params: Params = {} @@ -105,7 +105,7 @@ def validate_nodes(cls, nodes: Any) -> Any: if "Head" in name and last_body_index is None: last_body_index = i - 1 names.append(name) - if i > 0 and "inputs" not in node: + if i > 0 and "inputs" not in node and "input_sources" not in node: if last_body_index is not None: prev_name = names[last_body_index] else: @@ -202,23 +202,32 @@ def check_graph(self) -> Self: @model_validator(mode="after") def check_unique_names(self) -> Self: - for section, objects in [ - ("nodes", self.nodes), - ("losses", self.losses), - ("metrics", self.metrics), - ("visualizers", self.visualizers), + for modules in [ + self.nodes, + self.losses, + self.metrics, + self.visualizers, ]: names: set[str] = set() - for obj in objects: - obj: AttachedModuleConfig - name = obj.alias or obj.name + node_index = 0 + for module in modules: + module: AttachedModuleConfig | ModelNodeConfig + name = module.alias or module.name if name in names: - if obj.alias is None: - obj.alias = f"{name}_{obj.attached_to}" - if obj.alias in names: - raise ValueError( - f"Duplicate name `{name}` in `{section}` section." + if module.alias is None: + if isinstance(module, ModelNodeConfig): + module.alias = module.name + else: + module.alias = f"{name}_{module.attached_to}" + + if module.alias in names: + new_alias = f"{module.alias}_{node_index}" + logger.warning( + f"Duplicate name: {module.alias}. Renaming to {new_alias}." ) + module.alias = new_alias + node_index += 1 + names.add(name) return self diff --git a/luxonis_train/config/predefined_models/anomaly_detection_model.py b/luxonis_train/config/predefined_models/anomaly_detection_model.py index 6dbb1a1d..5dff9b93 100644 --- a/luxonis_train/config/predefined_models/anomaly_detection_model.py +++ b/luxonis_train/config/predefined_models/anomaly_detection_model.py @@ -5,7 +5,7 @@ from luxonis_train.config import ( AttachedModuleConfig, LossModuleConfig, - MetricModuleConfig, # Metrics support added + MetricModuleConfig, ModelNodeConfig, Params, ) @@ -51,7 +51,7 @@ def __init__( disc_subnet_params: Params | None = None, loss_params: Params | None = None, visualizer_params: Params | None = None, - task_name: str | None = None, + task_name: str = "", ): var_config = get_variant(variant) @@ -73,13 +73,13 @@ def nodes(self) -> list[ModelNodeConfig]: return [ ModelNodeConfig( name=self.backbone, - alias=f"{self.backbone}-{self.task_name}", + alias=f"{self.task_name}/{self.backbone}", params=self.backbone_params, ), ModelNodeConfig( name="DiscSubNetHead", - alias=f"DiscSubNetHead-{self.task_name}", - inputs=[f"{self.backbone}-{self.task_name}"], + alias=f"{self.task_name}/DiscSubNetHead", + inputs=[f"{self.task_name}/{self.backbone}"], params=self.disc_subnet_params, ), ] @@ -90,8 +90,7 @@ def losses(self) -> list[LossModuleConfig]: return [ LossModuleConfig( name="ReconstructionSegmentationLoss", - alias=f"ReconstructionSegmentationLoss-{self.task_name}", - attached_to=f"DiscSubNetHead-{self.task_name}", + attached_to=f"{self.task_name}/DiscSubNetHead", params=self.loss_params, weight=1.0, ) @@ -103,8 +102,7 @@ def metrics(self) -> list[MetricModuleConfig]: return [ MetricModuleConfig( name="JaccardIndex", - alias=f"JaccardIndex-{self.task_name}", - attached_to=f"DiscSubNetHead-{self.task_name}", + attached_to=f"{self.task_name}/DiscSubNetHead", params={"num_classes": 2, "task": "multiclass"}, is_main_metric=True, ), @@ -117,8 +115,7 @@ def visualizers(self) -> list[AttachedModuleConfig]: return [ AttachedModuleConfig( name="SegmentationVisualizer", - alias=f"SegmentationVisualizer-{self.task_name}", - attached_to=f"DiscSubNetHead-{self.task_name}", + attached_to=f"{self.task_name}/DiscSubNetHead", params=self.visualizer_params, ) ] diff --git a/luxonis_train/config/predefined_models/classification_model.py b/luxonis_train/config/predefined_models/classification_model.py index 0d749b5a..86964e0a 100644 --- a/luxonis_train/config/predefined_models/classification_model.py +++ b/luxonis_train/config/predefined_models/classification_model.py @@ -52,7 +52,7 @@ def __init__( loss_params: Params | None = None, visualizer_params: Params | None = None, task: Literal["multiclass", "multilabel"] = "multiclass", - task_name: str | None = None, + task_name: str = "", ): var_config = get_variant(variant) @@ -74,14 +74,14 @@ def nodes(self) -> list[ModelNodeConfig]: return [ ModelNodeConfig( name=self.backbone, - alias=f"{self.backbone}-{self.task_name}", + alias=f"{self.task_name}/{self.backbone}", freezing=self.backbone_params.pop("freezing", {}), params=self.backbone_params, ), ModelNodeConfig( name="ClassificationHead", - alias=f"ClassificationHead-{self.task_name}", - inputs=[f"{self.backbone}-{self.task_name}"], + alias=f"{self.task_name}/ClassificationHead", + inputs=[f"{self.task_name}/{self.backbone}"], freezing=self.head_params.pop("freezing", {}), params=self.head_params, task_name=self.task_name, @@ -94,8 +94,7 @@ def losses(self) -> list[LossModuleConfig]: return [ LossModuleConfig( name="CrossEntropyLoss", - alias=f"CrossEntropyLoss-{self.task_name}", - attached_to=f"ClassificationHead-{self.task_name}", + attached_to=f"{self.task_name}/ClassificationHead", params=self.loss_params, weight=1.0, ) @@ -107,21 +106,18 @@ def metrics(self) -> list[MetricModuleConfig]: return [ MetricModuleConfig( name="F1Score", - alias=f"F1Score-{self.task_name}", is_main_metric=True, - attached_to=f"ClassificationHead-{self.task_name}", + attached_to=f"{self.task_name}/ClassificationHead", params={"task": self.task}, ), MetricModuleConfig( name="Accuracy", - alias=f"Accuracy-{self.task_name}", - attached_to=f"ClassificationHead-{self.task_name}", + attached_to=f"{self.task_name}/ClassificationHead", params={"task": self.task}, ), MetricModuleConfig( name="Recall", - alias=f"Recall-{self.task_name}", - attached_to=f"ClassificationHead-{self.task_name}", + attached_to=f"{self.task_name}/ClassificationHead", params={"task": self.task}, ), ] @@ -132,8 +128,7 @@ def visualizers(self) -> list[AttachedModuleConfig]: return [ AttachedModuleConfig( name="ClassificationVisualizer", - alias=f"ClassificationVisualizer-{self.task_name}", - attached_to=f"ClassificationHead-{self.task_name}", + attached_to=f"{self.task_name}/ClassificationHead", params=self.visualizer_params, ) ] diff --git a/luxonis_train/config/predefined_models/detection_fomo_model.py b/luxonis_train/config/predefined_models/detection_fomo_model.py index 1a21a5a3..309a4572 100644 --- a/luxonis_train/config/predefined_models/detection_fomo_model.py +++ b/luxonis_train/config/predefined_models/detection_fomo_model.py @@ -9,7 +9,6 @@ ModelNodeConfig, Params, ) -from luxonis_train.enums import TaskType from .base_predefined_model import BasePredefinedModel @@ -51,8 +50,7 @@ def __init__( head_params: Params | None = None, loss_params: Params | None = None, kpt_visualizer_params: Params | None = None, - bbox_task_name: str | None = None, - kpt_task_name: str | None = None, + task_name: str = "", ): var_config = get_variant(variant) @@ -61,29 +59,23 @@ def __init__( self.head_params = head_params or var_config.head_params self.loss_params = loss_params or {} self.kpt_visualizer_params = kpt_visualizer_params or {} - self.bbox_task_name = ( - bbox_task_name or "boundingbox" - ) # Needed for OKS calculation - self.kpt_task_name = kpt_task_name or "keypoints" + self.task_name = task_name @property def nodes(self) -> list[ModelNodeConfig]: nodes = [ ModelNodeConfig( name=self.backbone, - alias=f"{self.backbone}-{self.kpt_task_name}", + alias=f"{self.task_name}/{self.backbone}", freezing=self.backbone_params.pop("freezing", {}), params=self.backbone_params, ), ModelNodeConfig( name="FOMOHead", - alias=f"FOMOHead-{self.kpt_task_name}", - inputs=[f"{self.backbone}-{self.kpt_task_name}"], + alias=f"{self.task_name}/FOMOHead", + inputs=[f"{self.task_name}/{self.backbone}"], params=self.head_params, - task_name={ - TaskType.BOUNDINGBOX: self.bbox_task_name, - TaskType.KEYPOINTS: self.kpt_task_name, - }, + task_name=self.task_name, ), ] return nodes @@ -93,8 +85,7 @@ def losses(self) -> list[LossModuleConfig]: return [ LossModuleConfig( name="FOMOLocalizationLoss", - alias=f"FOMOLocalizationLoss-{self.kpt_task_name}", - attached_to=f"FOMOHead-{self.kpt_task_name}", + attached_to=f"{self.task_name}/FOMOHead", params=self.loss_params, weight=1.0, ) @@ -105,8 +96,7 @@ def metrics(self) -> list[MetricModuleConfig]: return [ MetricModuleConfig( name="ObjectKeypointSimilarity", - alias=f"ObjectKeypointSimilarity-{self.kpt_task_name}", - attached_to=f"FOMOHead-{self.kpt_task_name}", + attached_to=f"{self.task_name}/FOMOHead", is_main_metric=True, ), ] @@ -116,8 +106,7 @@ def visualizers(self) -> list[AttachedModuleConfig]: return [ AttachedModuleConfig( name="MultiVisualizer", - alias=f"MultiVisualizer-{self.kpt_task_name}", - attached_to=f"FOMOHead-{self.kpt_task_name}", + attached_to=f"{self.task_name}/FOMOHead", params={ "visualizers": [ { diff --git a/luxonis_train/config/predefined_models/detection_model.py b/luxonis_train/config/predefined_models/detection_model.py index d1498845..cda9c503 100644 --- a/luxonis_train/config/predefined_models/detection_model.py +++ b/luxonis_train/config/predefined_models/detection_model.py @@ -65,7 +65,7 @@ def __init__( head_params: Params | None = None, loss_params: Params | None = None, visualizer_params: Params | None = None, - task_name: str | None = None, + task_name: str = "", ): var_config = get_variant(variant) @@ -89,7 +89,7 @@ def nodes(self) -> list[ModelNodeConfig]: nodes = [ ModelNodeConfig( name=self.backbone, - alias=f"{self.backbone}-{self.task_name}", + alias=f"{self.task_name}/{self.backbone}", freezing=self.backbone_params.pop("freezing", {}), params=self.backbone_params, ), @@ -98,8 +98,8 @@ def nodes(self) -> list[ModelNodeConfig]: nodes.append( ModelNodeConfig( name="RepPANNeck", - alias=f"RepPANNeck-{self.task_name}", - inputs=[f"{self.backbone}-{self.task_name}"], + alias=f"{self.task_name}/RepPANNeck", + inputs=[f"{self.task_name}/{self.backbone}"], freezing=self.neck_params.pop("freezing", {}), params=self.neck_params, ) @@ -108,11 +108,11 @@ def nodes(self) -> list[ModelNodeConfig]: nodes.append( ModelNodeConfig( name="EfficientBBoxHead", - alias=f"EfficientBBoxHead-{self.task_name}", + alias=f"{self.task_name}/EfficientBBoxHead", freezing=self.head_params.pop("freezing", {}), - inputs=[f"RepPANNeck-{self.task_name}"] + inputs=[f"{self.task_name}/RepPANNeck"] if self.use_neck - else [f"{self.backbone}-{self.task_name}"], + else [f"{self.task_name}/{self.backbone}"], params=self.head_params, task_name=self.task_name, ) @@ -125,8 +125,7 @@ def losses(self) -> list[LossModuleConfig]: return [ LossModuleConfig( name="AdaptiveDetectionLoss", - alias=f"AdaptiveDetectionLoss-{self.task_name}", - attached_to=f"EfficientBBoxHead-{self.task_name}", + attached_to=f"{self.task_name}/EfficientBBoxHead", params=self.loss_params, weight=1.0, ) @@ -138,8 +137,7 @@ def metrics(self) -> list[MetricModuleConfig]: return [ MetricModuleConfig( name="MeanAveragePrecision", - alias=f"MeanAveragePrecision-{self.task_name}", - attached_to=f"EfficientBBoxHead-{self.task_name}", + attached_to=f"{self.task_name}/EfficientBBoxHead", is_main_metric=True, ), ] @@ -150,8 +148,7 @@ def visualizers(self) -> list[AttachedModuleConfig]: return [ AttachedModuleConfig( name="BBoxVisualizer", - alias=f"BBoxVisualizer-{self.task_name}", - attached_to=f"EfficientBBoxHead-{self.task_name}", + attached_to=f"{self.task_name}/EfficientBBoxHead", params=self.visualizer_params, ) ] diff --git a/luxonis_train/config/predefined_models/keypoint_detection_model.py b/luxonis_train/config/predefined_models/keypoint_detection_model.py index 8882f338..88c7aa63 100644 --- a/luxonis_train/config/predefined_models/keypoint_detection_model.py +++ b/luxonis_train/config/predefined_models/keypoint_detection_model.py @@ -62,8 +62,7 @@ def __init__( loss_params: Params | None = None, kpt_visualizer_params: Params | None = None, bbox_visualizer_params: Params | None = None, - bbox_task_name: str | None = None, - kpt_task_name: str | None = None, + task_name: str = "", ): var_config = get_variant(variant) @@ -79,8 +78,7 @@ def __init__( self.loss_params = loss_params or {"n_warmup_epochs": 0} self.kpt_visualizer_params = kpt_visualizer_params or {} self.bbox_visualizer_params = bbox_visualizer_params or {} - self.bbox_task_name = bbox_task_name or "boundingbox" - self.kpt_task_name = kpt_task_name or "keypoints" + self.task_name = task_name @property def nodes(self) -> list[ModelNodeConfig]: @@ -89,7 +87,7 @@ def nodes(self) -> list[ModelNodeConfig]: nodes = [ ModelNodeConfig( name=self.backbone, - alias=f"{self.backbone}-{self.kpt_task_name}", + alias=f"{self.task_name}/{self.backbone}", freezing=self.backbone_params.pop("freezing", {}), params=self.backbone_params, ), @@ -98,31 +96,25 @@ def nodes(self) -> list[ModelNodeConfig]: nodes.append( ModelNodeConfig( name="RepPANNeck", - alias=f"RepPANNeck-{self.kpt_task_name}", - inputs=[f"{self.backbone}-{self.kpt_task_name}"], + alias=f"{self.task_name}/RepPANNeck", + inputs=[f"{self.task_name}/{self.backbone}"], freezing=self.neck_params.pop("freezing", {}), params=self.neck_params, ) ) - task = {} - if self.bbox_task_name is not None: - task["boundingbox"] = self.bbox_task_name - if self.kpt_task_name is not None: - task["keypoints"] = self.kpt_task_name - nodes.append( ModelNodeConfig( name="EfficientKeypointBBoxHead", - alias=f"EfficientKeypointBBoxHead-{self.kpt_task_name}", + alias=f"{self.task_name}/EfficientKeypointBBoxHead", inputs=( - [f"RepPANNeck-{self.kpt_task_name}"] + [f"{self.task_name}/RepPANNeck"] if self.use_neck - else [f"{self.backbone}-{self.kpt_task_name}"] + else [f"{self.task_name}/{self.backbone}"] ), freezing=self.head_params.pop("freezing", {}), params=self.head_params, - task_name=task, + task_name=self.task_name, ) ) return nodes @@ -133,8 +125,7 @@ def losses(self) -> list[LossModuleConfig]: return [ LossModuleConfig( name="EfficientKeypointBBoxLoss", - alias=f"EfficientKeypointBBoxLoss-{self.kpt_task_name}", - attached_to=f"EfficientKeypointBBoxHead-{self.kpt_task_name}", + attached_to=f"{self.task_name}/EfficientKeypointBBoxHead", params=self.loss_params, weight=1.0, ) @@ -146,14 +137,12 @@ def metrics(self) -> list[MetricModuleConfig]: return [ MetricModuleConfig( name="ObjectKeypointSimilarity", - alias=f"ObjectKeypointSimilarity-{self.kpt_task_name}", - attached_to=f"EfficientKeypointBBoxHead-{self.kpt_task_name}", + attached_to=f"{self.task_name}/EfficientKeypointBBoxHead", is_main_metric=True, ), MetricModuleConfig( name="MeanAveragePrecisionKeypoints", - alias=f"MeanAveragePrecisionKeypoints-{self.kpt_task_name}", - attached_to=f"EfficientKeypointBBoxHead-{self.kpt_task_name}", + attached_to=f"{self.task_name}/EfficientKeypointBBoxHead", ), ] @@ -164,8 +153,7 @@ def visualizers(self) -> list[AttachedModuleConfig]: return [ AttachedModuleConfig( name="MultiVisualizer", - alias=f"MultiVisualizer-{self.kpt_task_name}", - attached_to=f"EfficientKeypointBBoxHead-{self.kpt_task_name}", + attached_to=f"{self.task_name}/EfficientKeypointBBoxHead", params={ "visualizers": [ { diff --git a/luxonis_train/config/predefined_models/segmentation_model.py b/luxonis_train/config/predefined_models/segmentation_model.py index 0260e843..fb04d0f1 100644 --- a/luxonis_train/config/predefined_models/segmentation_model.py +++ b/luxonis_train/config/predefined_models/segmentation_model.py @@ -56,7 +56,7 @@ def __init__( loss_params: Params | None = None, visualizer_params: Params | None = None, task: Literal["binary", "multiclass"] = "binary", - task_name: str | None = None, + task_name: str = "", ): var_config = get_variant(variant) @@ -82,15 +82,15 @@ def nodes(self) -> list[ModelNodeConfig]: node_list = [ ModelNodeConfig( name=self.backbone, - alias=f"{self.backbone}-{self.task_name}", + alias=f"{self.task_name}/{self.backbone}", freezing=self.backbone_params.pop("freezing", {}), params=self.backbone_params, task_name=self.task_name, ), ModelNodeConfig( name="DDRNetSegmentationHead", - alias=f"DDRNetSegmentationHead-{self.task_name}", - inputs=[f"{self.backbone}-{self.task_name}"], + alias=f"{self.task_name}/DDRNetSegmentationHead", + inputs=[f"{self.task_name}/{self.backbone}"], freezing=self.head_params.pop("freezing", {}), params=self.head_params, task_name=self.task_name, @@ -100,8 +100,8 @@ def nodes(self) -> list[ModelNodeConfig]: node_list.append( ModelNodeConfig( name="DDRNetSegmentationHead", - alias=f"DDRNetSegmentationHead_aux-{self.task_name}", - inputs=[f"{self.backbone}-{self.task_name}"], + alias=f"{self.task_name}/DDRNetSegmentationHead_aux", + inputs=[f"{self.task_name}/{self.backbone}"], freezing=self.aux_head_params.pop("freezing", {}), params=self.aux_head_params, task_name=self.task_name, @@ -122,12 +122,7 @@ def losses(self) -> list[LossModuleConfig]: if self.task == "binary" else "OHEMCrossEntropyLoss" ), - alias=( - f"OHEMBCEWithLogitsLoss-{self.task_name}" - if self.task == "binary" - else f"OHEMCrossEntropyLoss-{self.task_name}" - ), - attached_to=f"DDRNetSegmentationHead-{self.task_name}", + attached_to=f"{self.task_name}/DDRNetSegmentationHead", params=self.loss_params, weight=1.0, ), @@ -140,12 +135,7 @@ def losses(self) -> list[LossModuleConfig]: if self.task == "binary" else "OHEMCrossEntropyLoss" ), - alias=( - f"OHEMBCEWithLogitsLoss_aux-{self.task_name}" - if self.task == "binary" - else f"OHEMCrossEntropyLoss_aux-{self.task_name}" - ), - attached_to=f"DDRNetSegmentationHead_aux-{self.task_name}", + attached_to=f"{self.task_name}/DDRNetSegmentationHead_aux", params=self.loss_params, weight=0.4, ) @@ -158,15 +148,13 @@ def metrics(self) -> list[MetricModuleConfig]: return [ MetricModuleConfig( name="JaccardIndex", - alias=f"JaccardIndex-{self.task_name}", - attached_to=f"DDRNetSegmentationHead-{self.task_name}", + attached_to=f"{self.task_name}/DDRNetSegmentationHead", is_main_metric=True, params={"task": self.task}, ), MetricModuleConfig( name="F1Score", - alias=f"F1Score-{self.task_name}", - attached_to=f"DDRNetSegmentationHead-{self.task_name}", + attached_to=f"{self.task_name}/DDRNetSegmentationHead", params={"task": self.task}, ), ] @@ -177,8 +165,7 @@ def visualizers(self) -> list[AttachedModuleConfig]: return [ AttachedModuleConfig( name="SegmentationVisualizer", - alias=f"SegmentationVisualizer-{self.task_name}", - attached_to=f"DDRNetSegmentationHead-{self.task_name}", + attached_to=f"{self.task_name}/DDRNetSegmentationHead", params=self.visualizer_params, ) ] From 82abeae938fc771ac534f8ad9349b5e9d6b817af Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 14 Jan 2025 07:43:51 -0600 Subject: [PATCH 13/31] updated attached modules --- .../attached_modules/base_attached_module.py | 34 ++++++++++++++----- .../losses/efficient_keypoint_bbox_loss.py | 4 ++- .../losses/fomo_localization_loss.py | 3 ++ .../mean_average_precision_keypoints.py | 22 ++++++------ .../metrics/object_keypoint_similarity.py | 2 ++ .../attached_modules/visualizers/utils.py | 2 +- 6 files changed, 46 insertions(+), 21 deletions(-) diff --git a/luxonis_train/attached_modules/base_attached_module.py b/luxonis_train/attached_modules/base_attached_module.py index 99461c96..a5f14761 100644 --- a/luxonis_train/attached_modules/base_attached_module.py +++ b/luxonis_train/attached_modules/base_attached_module.py @@ -73,13 +73,13 @@ def __init__(self, *, node: BaseNode | None = None): for label in self.supported_tasks ] module_supported = f"[{', '.join(module_supported)}]" - if not self.node.task_types: + if not self.node.tasks: raise IncompatibleException( f"Module '{self.name}' requires one of the following " f"labels or combinations of labels: {module_supported}, " f"but is connected to node '{self.node.name}' which does not specify any tasks." ) - node_tasks = set(self.node.task_types) + node_tasks = set(self.node.tasks) for required_labels in self.supported_tasks: if isinstance(required_labels, TaskType): required_labels = [required_labels] @@ -89,7 +89,7 @@ def __init__(self, *, node: BaseNode | None = None): self.required_labels = required_labels break else: - node_supported = [task.value for task in self.node.task_types] + node_supported = [task.value for task in self.node.tasks] raise IncompatibleException( f"Module '{self.name}' requires one of the following labels or combinations of labels: {module_supported}, " f"but is connected to node '{self.node.name}' which does not support any of them. " @@ -166,11 +166,11 @@ def node_tasks(self) -> list[TaskType]: @raises RuntimeError: If the node does not have the C{tasks} attribute set. """ - if self.node.task_types is None: + if self.node.tasks is None: raise RuntimeError( "Node must have the `tasks` attribute specified." ) - return self.node.task_types + return self.node.tasks def get_label( self, labels: Labels, task_type: TaskType | None = None @@ -214,8 +214,12 @@ def _get_label( task_name = self.node.task_name task = f"{task_name}/{task_type.value}" if task not in labels: - raise IncompatibleException.from_missing_task( - task_type.value, list(labels.keys()), self.name + raise IncompatibleException( + f"Module '{self.name}' requires label of type " + f"'{task_type.value}' assigned to task '{task_name}', " + "but the label is missing from the dataset. " + f"Available labels: {list(labels.keys())}. " + f"Missing label: '{task}'." ) return labels[task], task_type @@ -274,7 +278,19 @@ def get_input_tensors( f"{self.name} requires multiple labels, " "you must provide the `task_type` argument to extract the desired input." ) - task_type = self.node_tasks[0].value + if len(self.node_tasks) == 1: + task_type = self.node_tasks[0].value + else: + required_label = self.required_labels[0] + for task in self.node_tasks: + if task.value == required_label: + task_type = task.value + break + else: + raise IncompatibleException( + f"Task {required_label} is not supported by the node " + f"{self.node.name}." + ) return inputs[f"{self.node.task_name}/{task_type}"] def prepare( @@ -307,7 +323,7 @@ def prepare( @raises RuntimeError: If the C{tasks} attribute is not set on the node. @raises RuntimeError: If the C{supported_tasks} attribute is not set on the module. """ - if self.node.task_types is None: + if self.node.tasks is None: raise RuntimeError( f"{self.node.name} must have the `tasks` attribute specified " f"for {self.name} to make use of the default `prepare` method." diff --git a/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py b/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py index 701a3c72..d9a191e9 100644 --- a/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py +++ b/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py @@ -16,6 +16,7 @@ get_with_default, ) from luxonis_train.utils.boundingbox import IoUType +from luxonis_train.utils.keypoints import insert_class from .bce_with_logits import BCEWithLogitsLoss @@ -100,8 +101,9 @@ def prepare( pred_distri = self.get_input_tensors(inputs, "distributions")[0] pred_kpts = self.get_input_tensors(inputs, "keypoints_raw")[0] - target_kpts = self.get_label(labels, TaskType.KEYPOINTS) target_bbox = self.get_label(labels, TaskType.BOUNDINGBOX) + target_kpts = self.get_label(labels, TaskType.KEYPOINTS) + target_kpts = insert_class(target_kpts, target_bbox) batch_size = pred_scores.shape[0] n_kpts = (target_kpts.shape[1] - 2) // 3 diff --git a/luxonis_train/attached_modules/losses/fomo_localization_loss.py b/luxonis_train/attached_modules/losses/fomo_localization_loss.py index 0ad1ea60..1b181077 100644 --- a/luxonis_train/attached_modules/losses/fomo_localization_loss.py +++ b/luxonis_train/attached_modules/losses/fomo_localization_loss.py @@ -8,6 +8,7 @@ from luxonis_train.enums import TaskType from luxonis_train.nodes import FOMOHead from luxonis_train.utils import Labels, Packet +from luxonis_train.utils.keypoints import insert_class from .base_loss import BaseLoss @@ -37,6 +38,8 @@ def prepare( ) -> tuple[Tensor, Tensor]: heatmap = self.get_input_tensors(inputs, "features")[0] target_kpts = self.get_label(labels, TaskType.KEYPOINTS) + target_bbox = self.get_label(labels, TaskType.BOUNDINGBOX) + target_kpts = insert_class(target_kpts, target_bbox) batch_size, num_classes, height, width = heatmap.shape target_heatmap = torch.zeros( (batch_size, num_classes, height, width), device=heatmap.device diff --git a/luxonis_train/attached_modules/metrics/mean_average_precision_keypoints.py b/luxonis_train/attached_modules/metrics/mean_average_precision_keypoints.py index 6a6440df..6be5f2ec 100644 --- a/luxonis_train/attached_modules/metrics/mean_average_precision_keypoints.py +++ b/luxonis_train/attached_modules/metrics/mean_average_precision_keypoints.py @@ -15,6 +15,7 @@ get_sigmas, get_with_default, ) +from luxonis_train.utils.keypoints import insert_class from .base_metric import BaseMetric @@ -107,16 +108,17 @@ def prepare( self, inputs: Packet[Tensor], labels: Labels ) -> tuple[list[dict[str, Tensor]], list[dict[str, Tensor]]]: assert self.node.tasks is not None - kpts = self.get_label(labels, TaskType.KEYPOINTS) - boxes = self.get_label(labels, TaskType.BOUNDINGBOX) - - nkpts = (kpts.shape[1] - 2) // 3 - label = torch.zeros((len(boxes), nkpts * 3 + 6)) - label[:, :2] = boxes[:, :2] - label[:, 2:6] = box_convert(boxes[:, 2:], "xywh", "xyxy") - label[:, 6::3] = kpts[:, 2::3] # x - label[:, 7::3] = kpts[:, 3::3] # y - label[:, 8::3] = kpts[:, 4::3] # visiblity + kpts_labels = self.get_label(labels, TaskType.KEYPOINTS) + bbox_labels = self.get_label(labels, TaskType.BOUNDINGBOX) + kpts_labels = insert_class(kpts_labels, bbox_labels) + + n_kpts = (kpts_labels.shape[1] - 2) // 3 + label = torch.zeros((len(bbox_labels), n_kpts * 3 + 6)) + label[:, :2] = bbox_labels[:, :2] + label[:, 2:6] = box_convert(bbox_labels[:, 2:], "xywh", "xyxy") + label[:, 6::3] = kpts_labels[:, 2::3] # x + label[:, 7::3] = kpts_labels[:, 3::3] # y + label[:, 8::3] = kpts_labels[:, 4::3] # visiblity output_list_kpt_map: list[dict[str, Tensor]] = [] label_list_kpt_map: list[dict[str, Tensor]] = [] diff --git a/luxonis_train/attached_modules/metrics/object_keypoint_similarity.py b/luxonis_train/attached_modules/metrics/object_keypoint_similarity.py index f32051b3..ec7b930d 100644 --- a/luxonis_train/attached_modules/metrics/object_keypoint_similarity.py +++ b/luxonis_train/attached_modules/metrics/object_keypoint_similarity.py @@ -13,6 +13,7 @@ get_sigmas, get_with_default, ) +from luxonis_train.utils.keypoints import insert_class from .base_metric import BaseMetric @@ -77,6 +78,7 @@ def prepare( ) -> tuple[list[dict[str, Tensor]], list[dict[str, Tensor]]]: kpts_labels = self.get_label(labels, TaskType.KEYPOINTS) bbox_labels = self.get_label(labels, TaskType.BOUNDINGBOX) + kpts_labels = insert_class(kpts_labels, bbox_labels) n_keypoints = (kpts_labels.shape[1] - 2) // 3 label = torch.zeros((len(bbox_labels), n_keypoints * 3 + 6)) label[:, :2] = bbox_labels[:, :2] diff --git a/luxonis_train/attached_modules/visualizers/utils.py b/luxonis_train/attached_modules/visualizers/utils.py index 45ec454b..1a571eca 100644 --- a/luxonis_train/attached_modules/visualizers/utils.py +++ b/luxonis_train/attached_modules/visualizers/utils.py @@ -160,7 +160,7 @@ def draw_keypoint_labels(img: Tensor, label: Tensor, **kwargs) -> Tensor: @return: Image with keypoint labels drawn on. """ _, H, W = img.shape - keypoints_unflat = label[:, 1:].reshape(-1, 3) + keypoints_unflat = label.reshape(-1, 3) keypoints_points = keypoints_unflat[:, :2] keypoints_points[:, 0] *= W keypoints_points[:, 1] *= H From f48622bf5a92288181b8d60d4111d52b234a4446 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 14 Jan 2025 07:44:18 -0600 Subject: [PATCH 14/31] small changes --- luxonis_train/core/core.py | 2 +- luxonis_train/enums.py | 1 - luxonis_train/models/luxonis_lightning.py | 21 ++-- luxonis_train/nodes/base_node.py | 113 +++--------------- .../nodes/heads/ddrnet_segmentation_head.py | 2 +- .../nodes/heads/efficient_bbox_head.py | 4 +- luxonis_train/utils/__init__.py | 3 +- luxonis_train/utils/exceptions.py | 9 +- luxonis_train/utils/keypoints.py | 21 ++++ 9 files changed, 59 insertions(+), 117 deletions(-) diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index fe959349..2fb59457 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -721,7 +721,7 @@ def _mult(lst: list[float | int]) -> list[float]: self.cfg.trainer.preprocessing.normalize.params["std"] ), "dai_type": "RGB888p" - if self.cfg.trainer.preprocessing.out_image_format + if self.cfg.trainer.preprocessing.color_format == "RGB" else "BGR888p", } diff --git a/luxonis_train/enums.py b/luxonis_train/enums.py index ea719e1c..09d38fb2 100644 --- a/luxonis_train/enums.py +++ b/luxonis_train/enums.py @@ -9,5 +9,4 @@ class TaskType(str, Enum): INSTANCE_SEGMENTATION = "instance_segmentation" BOUNDINGBOX = "boundingbox" KEYPOINTS = "keypoints" - LABEL = "label" ARRAY = "array" diff --git a/luxonis_train/models/luxonis_lightning.py b/luxonis_train/models/luxonis_lightning.py index cebce43f..afd422ae 100644 --- a/luxonis_train/models/luxonis_lightning.py +++ b/luxonis_train/models/luxonis_lightning.py @@ -180,12 +180,16 @@ def __init__( node_cfg.freezing.unfreeze_after * epochs ) frozen_nodes.append((node_name, unfreeze_after)) - - if node_cfg.task_name is not None and Node.task_types is None: - raise ValueError( - f"Cannot define tasks for node {node_name}." - "This node doesn't specify any tasks." - ) + task_names = list(self.dataset_metadata.task_names) + if not node_cfg.task_name: + if len(task_names) == 1: + node_cfg.task_name = task_names[0] + elif issubclass(Node, BaseHead): + raise ValueError( + f"Dataset contains multiple tasks: {task_names}. " + f"Node {node_name} does not have the `task_name` parameter set. " + "Please specify the `task_name` parameter for each head node. " + ) nodes[node_name] = ( Node, @@ -1030,7 +1034,10 @@ def _print_results( ) if self.main_metric is not None: - main_metric_node, main_metric_name = self.main_metric.split("/") + print(self.main_metric) + *main_metric_node, main_metric_name = self.main_metric.split("/") + main_metric_node = "/".join(main_metric_node) + main_metric = metrics[main_metric_node][main_metric_name] logger.info( f"{stage} main metric ({self.main_metric}): {main_metric:.4f}" diff --git a/luxonis_train/nodes/base_node.py b/luxonis_train/nodes/base_node.py index eff2c2d1..19ec33a7 100644 --- a/luxonis_train/nodes/base_node.py +++ b/luxonis_train/nodes/base_node.py @@ -109,7 +109,7 @@ def wrap(output: Tensor) -> Packet[Tensor]: """ attach_index: AttachIndexType - task_types: list[TaskType] | None = None + tasks: list[TaskType] | None = None def __init__( self, @@ -154,10 +154,6 @@ def __init__( of outputs or C{"all"} to specify all outputs. Defaults to "all". Python indexing conventions apply. If provided as a constructor argument, overrides the class attribute. - @type _tasks: dict[TaskType, str] | None - @param _tasks: Dictionary of tasks that the node supports. - Overrides the class L{tasks} attribute. Shouldn't be - provided by the user in most cases. """ super().__init__() @@ -169,15 +165,15 @@ def __init__( ) self.attach_index = attach_index - self.task_name = task_name if task_name is None and dataset_metadata is not None: if len(dataset_metadata.task_names) == 1: - self.task_name = next(iter(dataset_metadata.task_names)) + task_name = next(iter(dataset_metadata.task_names)) else: raise ValueError( f"Dataset contain multiple tasks, but the `task_name` " f"argument for node '{self.name}' was not provided." ) + self.task_name = task_name or "" if getattr(self, "attach_index", None) is None: parameters = inspect.signature(self.forward).parameters @@ -228,117 +224,38 @@ def _check_type_overrides(self) -> None: def name(self) -> str: return self.__class__.__name__ - @property - def task_type(self) -> str: - """Getter for the task type. - - @type: str - @raises RuntimeError: If the node doesn't define any task. - @raises ValueError: If the node defines more than one task. In - that case, use the L{get_task_name} method instead. - """ - if not self.task_types: - raise RuntimeError(f"{self.name} does not define any task.") - - if len(self.task_types) > 1: - raise ValueError( - f"Node {self.name} has multiple tasks defined. " - "Use the `get_task_name` method instead." - ) - return self.task_types[0].value - @property def n_keypoints(self) -> int: """Getter for the number of keypoints. @type: int @raises ValueError: If the node does not support keypoints. - @raises RuntimeError: If the node doesn't define any task. """ if self._n_keypoints is not None: return self._n_keypoints - if self.task_types: - if TaskType.KEYPOINTS not in self.task_types: - raise ValueError(f"{self.name} does not support keypoints.") - return self.dataset_metadata.n_keypoints(self.task_name) - - raise RuntimeError( - f"{self.name} does not have any tasks defined, " - "`BaseNode.n_keypoints` property cannot be used. " - "Either override the `tasks` class attribute, " - "pass the `n_keypoints` attribute to the constructor or call " - "the `BaseNode.dataset_metadata.get_n_keypoints` method manually." - ) + if TaskType.KEYPOINTS not in (self.tasks or []): + raise ValueError(f"{self.name} does not support keypoints.") + return self.dataset_metadata.n_keypoints(self.task_name) @property def n_classes(self) -> int: """Getter for the number of classes. @type: int - @raises RuntimeError: If the node doesn't define any task. - @raises ValueError: If the number of classes is different for - different tasks. In that case, use the L{get_n_classes} - method. """ if self._n_classes is not None: return self._n_classes - if not self.task_types: - raise RuntimeError( - f"{self.name} does not have any tasks defined, " - "`BaseNode.n_classes` property cannot be used. " - "Either override the `tasks` class attribute, " - "pass the `n_classes` attribute to the constructor or call " - "the `BaseNode.dataset_metadata.n_classes` method manually." - ) - elif len(self.task_types) == 1: - return self.dataset_metadata.n_classes(self.task_name) - else: - n_classes = [ - self.dataset_metadata.n_classes(self.task_name) - for task in self.task_types - ] - if len(set(n_classes)) == 1: - return n_classes[0] - raise ValueError( - "Node defines multiple tasks but they have different number of classes. " - "This is likely an error, as the number of classes should be the same." - "If it is intended, use `BaseNode.get_n_classes` instead." - ) + return self.dataset_metadata.n_classes(self.task_name) @property def class_names(self) -> list[str]: """Getter for the class names. @type: list[str] - @raises RuntimeError: If the node doesn't define any task. - @raises ValueError: If the class names are different for - different tasks. In that case, use the L{get_class_names} - method. """ - if not self.task_types: - raise RuntimeError( - f"{self.name} does not have any tasks defined, " - "`BaseNode.class_names` property cannot be used. " - "Either override the `tasks` class attribute, " - "pass the `n_classes` attribute to the constructor or call " - "the `BaseNode.dataset_metadata.class_names` method manually." - ) - elif len(self.task_types) == 1: - return self.dataset_metadata.classes(self.task_name) - else: - class_names = [ - self.dataset_metadata.classes(self.task_name) - for task in self.task_types - ] - if all(set(names) == set(class_names[0]) for names in class_names): - return class_names[0] - raise ValueError( - "Node defines multiple tasks but they have different class names. " - "This is likely an error, as the class names should be the same. " - "If it is intended, use `BaseNode.get_class_names` instead." - ) + return self.dataset_metadata.classes(self.task_name) @property def input_shapes(self) -> list[Packet[Size]]: @@ -587,10 +504,14 @@ def wrap(self, output: ForwardOutputT) -> Packet[Tensor]: raise ValueError( "Default `wrap` expects a single tensor or a list of tensors." ) - try: - task = f"{self.task_name}/{self.task_type}" - except RuntimeError: - task = "features" + if not self.tasks: + return {"features": outputs} + if len(self.tasks) > 1: + raise RuntimeError( + f"Node {self.name} defines multiple tasks. " + "The `wrap` method should be overridden." + ) + task = f"{self.task_name or ''}/{self.tasks[0].value}" return {task: outputs} def run(self, inputs: list[Packet[Tensor]]) -> Packet[Tensor]: @@ -609,7 +530,7 @@ def run(self, inputs: list[Packet[Tensor]]) -> Packet[Tensor]: unwrapped = self.unwrap(inputs) outputs = self(unwrapped) wrapped = self.wrap(outputs) - str_tasks = [task.value for task in self.task_types or []] + str_tasks = [task.value for task in self.tasks or []] for key in list(wrapped.keys()): if key in str_tasks: assert self.task_name is not None diff --git a/luxonis_train/nodes/heads/ddrnet_segmentation_head.py b/luxonis_train/nodes/heads/ddrnet_segmentation_head.py index 35c5e69c..2b313ab6 100644 --- a/luxonis_train/nodes/heads/ddrnet_segmentation_head.py +++ b/luxonis_train/nodes/heads/ddrnet_segmentation_head.py @@ -18,7 +18,7 @@ class DDRNetSegmentationHead(BaseHead[Tensor, Tensor]): in_width: int in_channels: int - task_types: list[TaskType] = [TaskType.SEGMENTATION] + tasks: list[TaskType] = [TaskType.SEGMENTATION] parser: str = "SegmentationParser" def __init__( diff --git a/luxonis_train/nodes/heads/efficient_bbox_head.py b/luxonis_train/nodes/heads/efficient_bbox_head.py index 2eb57dfb..8d1a55f4 100644 --- a/luxonis_train/nodes/heads/efficient_bbox_head.py +++ b/luxonis_train/nodes/heads/efficient_bbox_head.py @@ -21,7 +21,7 @@ class EfficientBBoxHead( BaseHead[list[Tensor], tuple[list[Tensor], list[Tensor], list[Tensor]]], ): in_channels: list[int] - task_types: list[TaskType] = [TaskType.BOUNDINGBOX] + tasks: list[TaskType] = [TaskType.BOUNDINGBOX] parser = "YOLO" def __init__( @@ -171,7 +171,7 @@ def wrap( conf, _ = out_cls.max(1, keepdim=True) out = torch.cat([out_reg, conf, out_cls], dim=1) outputs.append(out) - return {self.task_type: outputs} + return {"boundingbox": outputs} cls_tensor = torch.cat( [cls_score_list[i].flatten(2) for i in range(len(cls_score_list))], diff --git a/luxonis_train/utils/__init__.py b/luxonis_train/utils/__init__.py index 2944dfde..d7c0be9f 100644 --- a/luxonis_train/utils/__init__.py +++ b/luxonis_train/utils/__init__.py @@ -16,7 +16,7 @@ to_shape_packet, ) from .graph import traverse_graph -from .keypoints import get_sigmas +from .keypoints import get_sigmas, insert_class from .tracker import LuxonisTrackerPL from .types import AttachIndexType, Kwargs, Labels, Packet @@ -41,4 +41,5 @@ "compute_iou_loss", "get_sigmas", "traverse_graph", + "insert_class", ] diff --git a/luxonis_train/utils/exceptions.py b/luxonis_train/utils/exceptions.py index bab8c1aa..811b1a29 100644 --- a/luxonis_train/utils/exceptions.py +++ b/luxonis_train/utils/exceptions.py @@ -2,11 +2,4 @@ class IncompatibleException(Exception): """Raised when two parts of the model are incompatible with each other.""" - @classmethod - def from_missing_task( - cls, task: str, present_tasks: list[str], class_name: str - ): - return cls( - f"{class_name} requires '{task}' label, but it was not found in " - f"the label dictionary. Available labels: {present_tasks}." - ) + pass diff --git a/luxonis_train/utils/keypoints.py b/luxonis_train/utils/keypoints.py index 8073c399..7eaac550 100644 --- a/luxonis_train/utils/keypoints.py +++ b/luxonis_train/utils/keypoints.py @@ -65,3 +65,24 @@ def get_sigmas( msg = f"[{caller_name}] {msg}" logger.info(msg) return torch.tensor([0.04] * n_keypoints, dtype=torch.float32) + + +def insert_class(keypoints: Tensor, bboxes: Tensor) -> Tensor: + """Insert class index into keypoints tensor. + + @type keypoints: Tensor + @param keypoints: Tensor of keypoints. + @type bboxes: Tensor + @param bboxes: Tensor of bounding boxes with class index. + @rtype: Tensor + @return: Tensor of keypoints with class index. + """ + classes = bboxes[:, 1] + return torch.cat( + ( + keypoints[:, :1], + classes.unsqueeze(-1), + keypoints[:, 1:], + ), + dim=-1, + ) From 7c244af626ac3160bc6e65b43e50da0eff84a254 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 14 Jan 2025 07:44:25 -0600 Subject: [PATCH 15/31] updated tests --- tests/configs/cli_commands.yaml | 2 +- tests/configs/multi_input.yaml | 78 +++----- tests/configs/parking_lot_config.yaml | 127 +++++-------- tests/integration/conftest.py | 164 +---------------- tests/integration/multi_input_modules.py | 20 +- tests/integration/parking_lot.json | 171 +++++++----------- tests/integration/test_detection.py | 7 +- .../test_fixed_validation_batch_limit.py | 6 +- tests/integration/test_fomo_detection.py | 22 +-- tests/integration/test_segmentation.py | 12 +- .../test_unsupervised_anomaly_detection.py | 20 +- tests/unittests/test_base_attached_module.py | 74 +++++--- tests/unittests/test_base_node.py | 61 +------ .../test_loaders/test_base_loader.py | 45 ++--- 14 files changed, 251 insertions(+), 558 deletions(-) diff --git a/tests/configs/cli_commands.yaml b/tests/configs/cli_commands.yaml index 56f77ef9..7123534c 100644 --- a/tests/configs/cli_commands.yaml +++ b/tests/configs/cli_commands.yaml @@ -51,7 +51,7 @@ trainer: preprocessing: train_image_size: [256, 320] keep_aspect_ratio: true - train_rgb: true + color_format: RGB normalize: active: true diff --git a/tests/configs/multi_input.yaml b/tests/configs/multi_input.yaml index 7db03d90..dcf92cdd 100644 --- a/tests/configs/multi_input.yaml +++ b/tests/configs/multi_input.yaml @@ -12,83 +12,57 @@ model: name: example_multi_input nodes: - name: FullBackbone - alias: full_backbone - name: RGBDBackbone - alias: rgbd_backbone input_sources: - left - right - disparity - name: PointcloudBackbone - alias: pointcloud_backbone input_sources: - pointcloud - name: FusionNeck - alias: fusion_neck inputs: - - rgbd_backbone - - pointcloud_backbone + - RGBDBackbone + - PointcloudBackbone input_sources: - disparity - name: FusionNeck2 - alias: fusion_neck_2 inputs: - - rgbd_backbone - - pointcloud_backbone - - full_backbone + - RGBDBackbone + - PointcloudBackbone + - FullBackbone - name: CustomSegHead1 - alias: head_1 inputs: - - fusion_neck + - FusionNeck + losses: + - name: BCEWithLogitsLoss + metrics: + - name: JaccardIndex + is_main_metric: true + params: + task: binary + visualizers: + - name: SegmentationVisualizer - name: CustomSegHead2 - alias: head_2 inputs: - - fusion_neck - - fusion_neck_2 + - FusionNeck + - FusionNeck2 input_sources: - disparity - - losses: - - name: BCEWithLogitsLoss - alias: loss_1 - attached_to: head_1 - - - name: CrossEntropyLoss - alias: loss_2 - attached_to: head_2 - - metrics: - - name: JaccardIndex - alias: jaccard_index_1 - attached_to: head_1 - is_main_metric: True - params: - task: binary - - - name: JaccardIndex - alias: jaccard_index_2 - attached_to: head_2 - params: - task: binary - - visualizers: - - name: SegmentationVisualizer - alias: seg_vis_1 - attached_to: head_1 - params: - colors: "#FF5055" - - - name: SegmentationVisualizer - alias: seg_vis_2 - attached_to: head_2 - params: - colors: "#55AAFF" + losses: + - name: CrossEntropyLoss + metrics: + - name: JaccardIndex + params: + task: binary + visualizers: + - name: SegmentationVisualizer tracker: project_name: multi_input_example @@ -111,4 +85,4 @@ trainer: exporter: onnx: - opset_version: 11 \ No newline at end of file + opset_version: 11 diff --git a/tests/configs/parking_lot_config.yaml b/tests/configs/parking_lot_config.yaml index 5cda65c1..2754197e 100644 --- a/tests/configs/parking_lot_config.yaml +++ b/tests/configs/parking_lot_config.yaml @@ -4,103 +4,58 @@ model: nodes: - name: EfficientRep - alias: backbone - name: RepPANNeck - alias: neck - inputs: - - backbone - name: EfficientBBoxHead - alias: bbox-head - inputs: - - neck + task_name: vehicle_type + losses: + - name: AdaptiveDetectionLoss + metrics: + - name: MeanAveragePrecision + is_main_metric: true + visualizers: + - name: BBoxVisualizer - name: EfficientKeypointBBoxHead - alias: motorbike-detection-head - task: - keypoints: motorbike-keypoints - boundingbox: motorbike-boundingbox - inputs: - - neck + task_name: motorbike + losses: + - name: EfficientKeypointBBoxLoss + metrics: + - name: MeanAveragePrecisionKeypoints + visualizers: + - name: MultiVisualizer + params: + visualizers: + - name: KeypointVisualizer + - name: BBoxVisualizer - name: SegmentationHead - alias: color-segmentation-head - task: color-segmentation - inputs: - - neck - - - name: SegmentationHead - alias: any-vehicle-segmentation-head - task: vehicle-segmentation - inputs: - - neck + task_name: color + losses: + - name: CrossEntropyLoss + metrics: + - name: JaccardIndex + visualizers: + - name: SegmentationVisualizer - name: BiSeNetHead - alias: brand-segmentation-head - task: brand-segmentation - inputs: - - neck + task_name: brand + losses: + - name: CrossEntropyLoss + metrics: + - name: Precision + visualizers: + - name: SegmentationVisualizer - name: BiSeNetHead - alias: vehicle-type-segmentation-head - task: vehicle_type-segmentation - inputs: - - neck - - losses: - - name: AdaptiveDetectionLoss - attached_to: bbox-head - - name: BCEWithLogitsLoss - attached_to: any-vehicle-segmentation-head - - name: CrossEntropyLoss - attached_to: vehicle-type-segmentation-head - - name: CrossEntropyLoss - attached_to: color-segmentation-head - - name: EfficientKeypointBBoxLoss - attached_to: motorbike-detection-head - - metrics: - - name: MeanAveragePrecisionKeypoints - attached_to: motorbike-detection-head - - name: MeanAveragePrecision - attached_to: bbox-head - is_main_metric: true - - name: F1Score - attached_to: any-vehicle-segmentation-head - - name: JaccardIndex - attached_to: color-segmentation-head - - name: Accuracy - attached_to: vehicle-type-segmentation-head - - name: Precision - attached_to: brand-segmentation-head - - visualizers: - - name: MultiVisualizer - alias: multi-visualizer-motorbike - attached_to: motorbike-detection-head - params: - visualizers: - - name: KeypointVisualizer - params: - nonvisible_color: blue - - name: BBoxVisualizer - - - name: SegmentationVisualizer - alias: color-segmentation-visualizer - attached_to: color-segmentation-head - - name: SegmentationVisualizer - alias: vehicle-type-segmentation-visualizer - attached_to: vehicle-type-segmentation-head - - name: SegmentationVisualizer - alias: vehicle-segmentation-visualizer - attached_to: any-vehicle-segmentation-head - - name: SegmentationVisualizer - alias: brand-segmentation-visualizer - attached_to: brand-segmentation-head - - name: BBoxVisualizer - alias: bbox-visualizer - attached_to: bbox-head + task_name: vehicle_type + losses: + - name: CrossEntropyLoss + metrics: + - name: Accuracy + visualizers: + - name: SegmentationVisualizer tracker: project_name: Parking_Lot @@ -132,7 +87,7 @@ trainer: preprocessing: train_image_size: [256, 320] keep_aspect_ratio: false - train_rgb: true + color_format: RGB normalize: active: true augmentations: diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 97189476..ab2fb1e8 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,22 +1,17 @@ -import json import multiprocessing as mp import os import shutil -from collections import defaultdict from pathlib import Path from typing import Any -import cv2 import gdown -import numpy as np import pytest import torchvision from luxonis_ml.data import LuxonisDataset from luxonis_ml.data.parsers import LuxonisParser -from luxonis_ml.data.utils.data_utils import rgb_to_bool_masks -from luxonis_ml.utils import LuxonisFileSystem, environ +from luxonis_ml.utils import environ -WORK_DIR = Path("tests", "data") +WORK_DIR = Path("tests", "data").absolute() @pytest.fixture(scope="session") @@ -40,152 +35,14 @@ def train_overfit() -> bool: @pytest.fixture(scope="session") def parking_lot_dataset() -> LuxonisDataset: - url = "gs://luxonis-test-bucket/luxonis-ml-test-data/D1_ParkingSlotTest" - base_path = WORK_DIR / "D1_ParkingSlotTest" - if not base_path.exists(): - base_path = LuxonisFileSystem.download(url, WORK_DIR) - - mask_brand_path = base_path / "mask_brand" - mask_color_path = base_path / "mask_color" - kpt_mask_path = base_path / "keypoints_mask_vehicle" - - def generator(): - filenames: dict[int, Path] = {} - for base_path in [kpt_mask_path, mask_brand_path, mask_color_path]: - for sequence_path in sorted(list(base_path.glob("sequence.*"))): - frame_data = sequence_path / "step0.frame_data.json" - with open(frame_data) as f: - data = json.load(f)["captures"][0] - frame_data = data["annotations"] - sequence_num = int(sequence_path.suffix[1:]) - filename = data["filename"] - if filename is not None: - filename = sequence_path / filename - filenames[sequence_num] = filename - else: - filename = filenames[sequence_num] - W, H = data["dimension"] - - annotations = { - anno["@type"].split(".")[-1]: anno for anno in frame_data - } - - bbox_classes = {} - bboxes = {} - - for bbox_annotation in annotations.get( - "BoundingBox2DAnnotation", defaultdict(list) - )["values"]: - class_ = ( - bbox_annotation["labelName"].split("-")[-1].lower() - ) - if class_ == "motorbiek": - class_ = "motorbike" - x, y = bbox_annotation["origin"] - w, h = bbox_annotation["dimension"] - instance_id = bbox_annotation["instanceId"] - bbox_classes[instance_id] = class_ - bboxes[instance_id] = [x / W, y / H, w / W, h / H] - yield { - "file": filename, - "annotation": { - "type": "boundingbox", - "class": class_, - "x": x / W, - "y": y / H, - "w": w / W, - "h": h / H, - "instance_id": instance_id, - }, - } - - for kpt_annotation in annotations.get( - "KeypointAnnotation", defaultdict(list) - )["values"]: - keypoints = kpt_annotation["keypoints"] - instance_id = kpt_annotation["instanceId"] - class_ = bbox_classes[instance_id] - bbox = bboxes[instance_id] - kpts = [] - - if class_ == "motorbike": - keypoints = keypoints[:3] - else: - keypoints = keypoints[3:] - - for kp in keypoints: - x, y = kp["location"] - kpts.append([x / W, y / H, kp["state"]]) - - yield { - "file": filename, - "annotation": { - "type": "detection", - "class": class_, - "task": class_, - "keypoints": kpts, - "instance_id": instance_id, - "boundingbox": { - "x": bbox[0], - "y": bbox[1], - "w": bbox[2], - "h": bbox[3], - }, - }, - } - - vehicle_type_segmentation = annotations[ - "SemanticSegmentationAnnotation" - ] - mask = cv2.cvtColor( - cv2.imread( - str( - sequence_path - / vehicle_type_segmentation["filename"] - ) - ), - cv2.COLOR_BGR2RGB, - ) - classes = { - inst["labelName"]: inst["pixelValue"][:3] - for inst in vehicle_type_segmentation["instances"] - } - if base_path == kpt_mask_path: - task = "vehicle_type-segmentation" - elif base_path == mask_brand_path: - task = "brand-segmentation" - else: - task = "color-segmentation" - for class_, mask_ in rgb_to_bool_masks( - mask, classes, add_background_class=True - ): - yield { - "file": filename, - "annotation": { - "type": "mask", - "class": class_, - "task": task, - "mask": mask_, - }, - } - if base_path == mask_color_path: - yield { - "file": filename, - "annotation": { - "type": "mask", - "class": "vehicle", - "task": "vehicle-segmentation", - "mask": mask.astype(bool)[..., 0] - | mask.astype(bool)[..., 1] - | mask.astype(bool)[..., 2], - }, - } - - dataset = LuxonisDataset("_ParkingLot", delete_existing=True) - dataset.add(generator()) - np.random.seed(42) - dataset.make_splits() - return dataset + url = "gs://luxonis-test-bucket/luxonis-ml-test-data/D1_ParkingLot_Native.zip" + parser = LuxonisParser( + url, + dataset_name="_D1_ParkingLot", + delete_existing=True, + save_dir=WORK_DIR, + ) + return parser.parse(random_split=True) @pytest.fixture(scope="session") @@ -236,7 +93,6 @@ def CIFAR10_subset_generator(): yield { "file": path, "annotation": { - "type": "classification", "class": classes[label], }, } diff --git a/tests/integration/multi_input_modules.py b/tests/integration/multi_input_modules.py index 31db4e2f..3db26f5d 100644 --- a/tests/integration/multi_input_modules.py +++ b/tests/integration/multi_input_modules.py @@ -1,5 +1,6 @@ import torch from torch import Tensor, nn +from typing_extensions import override from luxonis_train.enums import TaskType from luxonis_train.loaders import BaseLoaderTorch @@ -8,8 +9,10 @@ class CustomMultiInputLoader(BaseLoaderTorch): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__( + self, view: str | list[str], image_source: str | None = None, **_ + ): + super().__init__(view=view, image_source=image_source) @property def input_shapes(self): @@ -36,17 +39,16 @@ def __getitem__(self, _): # pragma: no cover # Fake labels segmap = torch.zeros(1, 224, 224, dtype=torch.float32) segmap[0, 100:150, 100:150] = 1 - labels = { - "segmentation": (segmap, TaskType.SEGMENTATION), - } + labels = {"/segmentation": segmap} return inputs, labels def __len__(self): return 10 - def get_classes(self) -> dict[TaskType, list[str]]: - return {TaskType.SEGMENTATION: ["square"]} + @override + def get_classes(self) -> dict[str, list[str]]: + return {"": ["square"]} class MultiInputTestBaseNode(BaseNode): @@ -77,7 +79,7 @@ class FusionNeck2(MultiInputTestBaseNode): ... class CustomSegHead1(MultiInputTestBaseNode): - tasks = {TaskType.SEGMENTATION: "segmentation"} + tasks = [TaskType.SEGMENTATION] def __init__(self, **kwargs): super().__init__(**kwargs) @@ -92,7 +94,7 @@ def forward(self, inputs: Tensor): class CustomSegHead2(MultiInputTestBaseNode): - tasks = {TaskType.SEGMENTATION: "segmentation"} + tasks = [TaskType.SEGMENTATION] def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/tests/integration/parking_lot.json b/tests/integration/parking_lot.json index 4a3f4f6e..66ce962d 100644 --- a/tests/integration/parking_lot.json +++ b/tests/integration/parking_lot.json @@ -37,62 +37,51 @@ ], "outputs": [ { - "name": "any-vehicle-segmentation-head/vehicle-segmentation/0", + "name": "BiSeNetHead/brand/segmentation/0", "dtype": "float32", "shape": [ 1, - 1, + 23, 256, 320 ], "layout": "NCHW" }, { - "name": "output1_yolov6r2", + "name": "EfficientKeypointBBoxHead/outputs/0", "dtype": "float32", "shape": [ 1, - 7, + 14, 32, 40 ], "layout": "NCHW" }, { - "name": "output2_yolov6r2", + "name": "EfficientKeypointBBoxHead/outputs/1", "dtype": "float32", "shape": [ 1, - 7, + 14, 16, 20 ], "layout": "NCHW" }, { - "name": "output3_yolov6r2", + "name": "EfficientKeypointBBoxHead/outputs/2", "dtype": "float32", "shape": [ 1, - 7, + 14, 8, 10 ], - "layout": "NCHW" - }, - { - "name": "brand-segmentation-head/brand-segmentation/0", - "dtype": "float32", - "shape": [ - 1, - 23, - 256, - 320 - ], - "layout": "NCHW" + "layout": "NCDE" }, { - "name": "color-segmentation-head/color-segmentation/0", + "name": "SegmentationHead/color/segmentation/0", "dtype": "float32", "shape": [ 1, @@ -103,91 +92,100 @@ "layout": "NCHW" }, { - "name": "motorbike-detection-head/outputs/0", + "name": "output1_yolov6r2", "dtype": "float32", "shape": [ 1, - 14, + 7, 32, 40 ], "layout": "NCHW" }, { - "name": "motorbike-detection-head/outputs/1", + "name": "output2_yolov6r2", "dtype": "float32", "shape": [ 1, - 14, + 7, 16, 20 ], "layout": "NCHW" }, { - "name": "motorbike-detection-head/outputs/2", + "name": "output3_yolov6r2", "dtype": "float32", "shape": [ 1, - 14, + 7, 8, 10 ], - "layout": "NCDE" - }, - { - "name": "vehicle-type-segmentation-head/vehicle_type-segmentation/0", - "dtype": "float32", - "shape": [ - 1, - 3, - 256, - 320 - ], "layout": "NCHW" } ], "heads": [ { - "name": "vehicle-type-segmentation-head", + "name": "BiSeNetHead", "parser": "SegmentationParser", "metadata": { "postprocessor_path": null, "classes": [ "background", - "car", - "motorbike" + "alfa-romeo", + "buick", + "ducati", + "harley", + "ferrari", + "infiniti", + "jeep", + "land-rover", + "roll-royce", + "yamaha", + "aprilia", + "bmw", + "dodge", + "honda", + "moto", + "piaggio", + "isuzu", + "Kawasaki", + "truimph", + "pontiac", + "saab", + "chrysler" ], - "n_classes": 3, + "n_classes": 23, "is_softmax": false }, "outputs": [ - "vehicle-type-segmentation-head/vehicle_type-segmentation/0" + "BiSeNetHead/brand/segmentation/0" ] }, { - "name": "any-vehicle-segmentation-head", + "name": "BiSeNetHead_0", "parser": "SegmentationParser", "metadata": { "postprocessor_path": null, "classes": [ - "vehicle" + "background", + "car", + "motorbike" ], - "n_classes": 1, + "n_classes": 3, "is_softmax": false }, - "outputs": [ - "any-vehicle-segmentation-head/vehicle-segmentation/0" - ] + "outputs": [] }, { - "name": "bbox-head", + "name": "EfficientBBoxHead", "parser": "YOLO", "metadata": { "postprocessor_path": null, "classes": [ - "car", - "motorbike" + "motorbike", + "car" ], "n_classes": 2, "iou_threshold": 0.45, @@ -203,44 +201,29 @@ ] }, { - "name": "brand-segmentation-head", - "parser": "SegmentationParser", + "name": "EfficientKeypointBBoxHead", + "parser": "YOLOExtendedParser", "metadata": { "postprocessor_path": null, "classes": [ - "Kawasaki", - "alfa-romeo", - "aprilia", - "background", - "bmw", - "buick", - "chrysler", - "dodge", - "ducati", - "ferrari", - "harley", - "honda", - "infiniti", - "isuzu", - "jeep", - "land-rover", - "moto", - "piaggio", - "pontiac", - "roll-royce", - "saab", - "truimph", - "yamaha" + "motorbike" ], - "n_classes": 23, - "is_softmax": false + "n_classes": 1, + "iou_threshold": 0.45, + "conf_threshold": 0.25, + "max_det": 300, + "anchors": null, + "subtype": "yolov6", + "n_keypoints": 3 }, "outputs": [ - "brand-segmentation-head/brand-segmentation/0" + "EfficientKeypointBBoxHead/outputs/0", + "EfficientKeypointBBoxHead/outputs/1", + "EfficientKeypointBBoxHead/outputs/2" ] }, { - "name": "color-segmentation-head", + "name": "SegmentationHead", "parser": "SegmentationParser", "metadata": { "postprocessor_path": null, @@ -254,31 +237,9 @@ "is_softmax": false }, "outputs": [ - "color-segmentation-head/color-segmentation/0" - ] - }, - { - "name": "motorbike-detection-head", - "parser": "YOLOExtendedParser", - "metadata": { - "postprocessor_path": null, - "classes": [ - "motorbike" - ], - "n_classes": 1, - "iou_threshold": 0.45, - "conf_threshold": 0.25, - "max_det": 300, - "anchors": null, - "subtype": "yolov6", - "n_keypoints": 3 - }, - "outputs": [ - "motorbike-detection-head/outputs/0", - "motorbike-detection-head/outputs/1", - "motorbike-detection-head/outputs/2" + "SegmentationHead/color/segmentation/0" ] } ] } -} \ No newline at end of file +} diff --git a/tests/integration/test_detection.py b/tests/integration/test_detection.py index 45e83f0a..6360ec79 100644 --- a/tests/integration/test_detection.py +++ b/tests/integration/test_detection.py @@ -16,14 +16,12 @@ def get_opts_backbone(backbone: str) -> dict[str, Any]: }, { "name": "EfficientBBoxHead", + "task_name": "vehicle_type", "inputs": [backbone], }, { "name": "EfficientKeypointBBoxHead", - "task": { - "keypoints": "car-keypoints", - "boundingbox": "car-boundingbox", - }, + "task_name": "car", "inputs": [backbone], }, ], @@ -70,6 +68,7 @@ def get_opts_variant(variant: str) -> dict[str, Any]: }, { "name": "EfficientBBoxHead", + "task_name": "motorbike", "inputs": ["neck"], }, ], diff --git a/tests/integration/test_fixed_validation_batch_limit.py b/tests/integration/test_fixed_validation_batch_limit.py index 25794cec..5a7a398c 100644 --- a/tests/integration/test_fixed_validation_batch_limit.py +++ b/tests/integration/test_fixed_validation_batch_limit.py @@ -19,7 +19,11 @@ def get_config() -> dict[str, Any]: "alias": "neck", "inputs": ["backbone"], }, - {"name": "EfficientBBoxHead", "inputs": ["neck"]}, + { + "name": "EfficientBBoxHead", + "task_name": "motorbike", + "inputs": ["neck"], + }, ], "losses": [ { diff --git a/tests/integration/test_fomo_detection.py b/tests/integration/test_fomo_detection.py index cd1d89fd..49a81da8 100644 --- a/tests/integration/test_fomo_detection.py +++ b/tests/integration/test_fomo_detection.py @@ -50,26 +50,18 @@ def dummy_generator(image_paths: List[Path]): ) for i, bbox in enumerate(bboxes): - # Generate bounding box annotation yield { "file": path, "annotation": { - "type": "boundingbox", - "instance_id": i, "class": "object", - "x": bbox["x"], - "y": bbox["y"], - "w": bbox["w"], - "h": bbox["h"], - }, - } - yield { - "file": path, - "annotation": { - "type": "keypoints", "instance_id": i, - "class": "object", - "keypoints": [keypoints[i]], + "boundingbox": { + "x": bbox["x"], + "y": bbox["y"], + "w": bbox["w"], + "h": bbox["h"], + }, + "keypoints": {"keypoints": [keypoints[i]]}, }, } diff --git a/tests/integration/test_segmentation.py b/tests/integration/test_segmentation.py index a8b4df91..1c79dc29 100644 --- a/tests/integration/test_segmentation.py +++ b/tests/integration/test_segmentation.py @@ -17,37 +17,37 @@ def get_opts(backbone: str) -> dict[str, Any]: { "name": "SegmentationHead", "alias": "seg-color-segmentation", - "task": "color-segmentation", + "task_name": "color", "inputs": [backbone], }, { "name": "BiSeNetHead", "alias": "bi-color-segmentation", - "task": "color-segmentation", + "task_name": "color", "inputs": [backbone], }, { "name": "SegmentationHead", "alias": "seg-vehicle-segmentation", - "task": "vehicle-segmentation", + "task_name": "vehicles", "inputs": [backbone], }, { "name": "BiSeNetHead", "alias": "bi-vehicle-segmentation", - "task": "vehicle-segmentation", + "task_name": "vehicles", "inputs": [backbone], }, { "name": "SegmentationHead", "alias": "seg-vehicle-segmentation-2", - "task": "vehicle-segmentation", + "task_name": "vehicles", "inputs": [backbone], }, { "name": "SegmentationHead", "alias": "seg-vehicle-segmentation-3", - "task": "vehicle-segmentation", + "task_name": "vehicles", "inputs": [backbone], }, ], diff --git a/tests/integration/test_unsupervised_anomaly_detection.py b/tests/integration/test_unsupervised_anomaly_detection.py index a98faa27..10c0a6f9 100644 --- a/tests/integration/test_unsupervised_anomaly_detection.py +++ b/tests/integration/test_unsupervised_anomaly_detection.py @@ -70,11 +70,12 @@ def dummy_generator( yield { "file": path, "annotation": { - "type": "rle", "class": "object", - "height": 256, - "width": 256, - "counts": "0" * (256 * 256), + "segmentation": { + "height": 256, + "width": 256, + "counts": "0" * (256 * 256), + }, }, } @@ -94,11 +95,14 @@ def dummy_generator( yield { "file": path, "annotation": { - "type": "polyline", "class": "object", - "points": [ - pt for segment in poly_normalized for pt in segment - ], + "segmentation": { + "height": img_h, + "width": img_w, + "points": [ + pt for segment in poly_normalized for pt in segment + ], + }, }, } diff --git a/tests/unittests/test_base_attached_module.py b/tests/unittests/test_base_attached_module.py index 77cab14b..c7cd1508 100644 --- a/tests/unittests/test_base_attached_module.py +++ b/tests/unittests/test_base_attached_module.py @@ -1,8 +1,17 @@ import pytest +import torch +from torch import Tensor from luxonis_train import BaseLoss, BaseNode from luxonis_train.enums import TaskType from luxonis_train.utils.exceptions import IncompatibleException +from luxonis_train.utils.types import Labels, Packet + +SEGMENTATION_ARRAY = torch.tensor([0]) +KEYPOINT_ARRAY = torch.tensor([1]) +BOUNDINGBOX_ARRAY = torch.tensor([2]) +CLASSIFICATION_ARRAY = torch.tensor([3]) +FEATURES_ARRAY = torch.tensor([4]) class DummyBackbone(BaseNode): @@ -41,20 +50,20 @@ def forward(self, _): ... @pytest.fixture -def labels(): +def labels() -> Labels: return { - "segmentation": ("segmentation", TaskType.SEGMENTATION), - "keypoints": ("keypoints", TaskType.KEYPOINTS), - "boundingbox": ("boundingbox", TaskType.BOUNDINGBOX), - "classification": ("classification", TaskType.CLASSIFICATION), + "/segmentation": SEGMENTATION_ARRAY, + "/keypoints": KEYPOINT_ARRAY, + "/boundingbox": BOUNDINGBOX_ARRAY, + "/classification": CLASSIFICATION_ARRAY, } @pytest.fixture -def inputs(): +def inputs() -> Packet[Tensor]: return { - "features": ["features"], - "segmentation": ["segmentation"], + "features": [FEATURES_ARRAY], + "/segmentation": [SEGMENTATION_ARRAY], } @@ -63,10 +72,10 @@ def test_valid_properties(): loss = DummyLoss(node=head) no_labels_loss = NoLabelLoss(node=head) assert loss.node == head - assert loss.node_tasks == {TaskType.SEGMENTATION: "segmentation"} + assert loss.node_tasks == [TaskType.SEGMENTATION] assert loss.required_labels == [TaskType.SEGMENTATION] assert no_labels_loss.node == head - assert no_labels_loss.node_tasks == {TaskType.SEGMENTATION: "segmentation"} + assert no_labels_loss.node_tasks == [TaskType.SEGMENTATION] assert no_labels_loss.required_labels == [] @@ -82,45 +91,49 @@ def test_invalid_properties(): _ = NoLabelLoss(node=backbone).node_tasks -def test_get_label(labels): +def test_get_label(labels: Labels): seg_head = DummySegmentationHead() det_head = DummyDetectionHead() seg_loss = DummyLoss(node=seg_head) - assert seg_loss.get_label(labels) == "segmentation" - assert seg_loss.get_label(labels, TaskType.SEGMENTATION) == "segmentation" + assert seg_loss.get_label(labels) == SEGMENTATION_ARRAY + assert ( + seg_loss.get_label(labels, TaskType.SEGMENTATION) == SEGMENTATION_ARRAY + ) - del labels["segmentation"] - labels["segmentation-task"] = ("segmentation", TaskType.SEGMENTATION) + del labels["/segmentation"] + labels["task/segmentation"] = SEGMENTATION_ARRAY with pytest.raises(IncompatibleException): seg_loss.get_label(labels) det_loss = DummyLoss(node=det_head) - assert det_loss.get_label(labels, TaskType.KEYPOINTS) == "keypoints" - assert det_loss.get_label(labels, TaskType.BOUNDINGBOX) == "boundingbox" + assert det_loss.get_label(labels, TaskType.KEYPOINTS) == KEYPOINT_ARRAY + assert ( + det_loss.get_label(labels, TaskType.BOUNDINGBOX) == BOUNDINGBOX_ARRAY + ) with pytest.raises(ValueError): det_loss.get_label(labels) - with pytest.raises(ValueError): + with pytest.raises(IncompatibleException): det_loss.get_label(labels, TaskType.SEGMENTATION) -def test_input_tensors(inputs): +def test_input_tensors(inputs: Packet[Tensor]): seg_head = DummySegmentationHead() seg_loss = DummyLoss(node=seg_head) - assert seg_loss.get_input_tensors(inputs) == ["segmentation"] - assert seg_loss.get_input_tensors(inputs, "segmentation") == [ - "segmentation" + assert seg_loss.get_input_tensors(inputs) == [SEGMENTATION_ARRAY] + assert seg_loss.get_input_tensors(inputs, "/segmentation") == [ + SEGMENTATION_ARRAY ] assert seg_loss.get_input_tensors(inputs, TaskType.SEGMENTATION) == [ - "segmentation" + SEGMENTATION_ARRAY ] with pytest.raises(IncompatibleException): seg_loss.get_input_tensors(inputs, TaskType.KEYPOINTS) with pytest.raises(IncompatibleException): - seg_loss.get_input_tensors(inputs, "keypoints") + seg_loss.get_input_tensors(inputs, "/keypoints") det_head = DummyDetectionHead() det_loss = DummyLoss(node=det_head) @@ -128,17 +141,20 @@ def test_input_tensors(inputs): det_loss.get_input_tensors(inputs) -def test_prepare(inputs, labels): +def test_prepare(inputs: Packet[Tensor], labels: Labels): backbone = DummyBackbone() seg_head = DummySegmentationHead() seg_loss = DummyLoss(node=seg_head) det_head = DummyDetectionHead() - assert seg_loss.prepare(inputs, labels) == ("segmentation", "segmentation") - inputs["segmentation"].append("segmentation2") assert seg_loss.prepare(inputs, labels) == ( - "segmentation2", - "segmentation", + SEGMENTATION_ARRAY, + SEGMENTATION_ARRAY, + ) + inputs["/segmentation"].append(FEATURES_ARRAY) + assert seg_loss.prepare(inputs, labels) == ( + FEATURES_ARRAY, + SEGMENTATION_ARRAY, ) with pytest.raises(RuntimeError): diff --git a/tests/unittests/test_base_node.py b/tests/unittests/test_base_node.py index 3ed284c3..e6cfd21d 100644 --- a/tests/unittests/test_base_node.py +++ b/tests/unittests/test_base_node.py @@ -2,9 +2,8 @@ import torch from torch import Size, Tensor -from luxonis_train.enums import TaskType from luxonis_train.nodes import AttachIndexType, BaseNode -from luxonis_train.utils import DatasetMetadata, Packet +from luxonis_train.utils import Packet from luxonis_train.utils.exceptions import IncompatibleException @@ -100,61 +99,3 @@ def forward(self, _): ... {"features": [Size((3, 224, 224)) for _ in range(3)]} ] ) - - -def test_tasks(): - class DummyHead(DummyNode): - tasks = [TaskType.CLASSIFICATION] - - class DummyMultiHead(DummyNode): - tasks = [TaskType.CLASSIFICATION, TaskType.SEGMENTATION] - - dummy_head = DummyHead() - dummy_node = DummyNode() - dummy_multi_head = DummyMultiHead(n_keypoints=4) - assert ( - dummy_head.get_task_name(TaskType.CLASSIFICATION) == "classification" - ) - assert dummy_head.task == "classification" - with pytest.raises(ValueError): - dummy_head.get_task_name(TaskType.SEGMENTATION) - - with pytest.raises(RuntimeError): - dummy_node.get_task_name(TaskType.SEGMENTATION) - - with pytest.raises(RuntimeError): - _ = dummy_node.task - - with pytest.raises(ValueError): - _ = dummy_multi_head.task - - metadata = DatasetMetadata( - classes={ - "segmentation": ["car", "person", "dog"], - "classification": ["car-class", "person-class"], - }, - n_keypoints={"color-segmentation": 0, "detection": 0}, - ) - - dummy_multi_head._dataset_metadata = metadata - assert dummy_multi_head.get_class_names(TaskType.SEGMENTATION) == [ - "car", - "person", - "dog", - ] - assert dummy_multi_head.get_class_names(TaskType.CLASSIFICATION) == [ - "car-class", - "person-class", - ] - assert dummy_multi_head.get_n_classes(TaskType.SEGMENTATION) == 3 - assert dummy_multi_head.get_n_classes(TaskType.CLASSIFICATION) == 2 - assert dummy_multi_head.n_keypoints == 4 - with pytest.raises(ValueError): - _ = dummy_head.n_keypoints - with pytest.raises(RuntimeError): - _ = dummy_node.n_keypoints - - dummy_head = DummyHead(n_classes=5) - assert dummy_head.n_classes == 5 - with pytest.raises(ValueError): - _ = dummy_multi_head.n_classes diff --git a/tests/unittests/test_loaders/test_base_loader.py b/tests/unittests/test_loaders/test_base_loader.py index 293b3c10..3cc9231f 100644 --- a/tests/unittests/test_loaders/test_base_loader.py +++ b/tests/unittests/test_loaders/test_base_loader.py @@ -2,7 +2,6 @@ import torch from torch import Size -from luxonis_train.enums import TaskType from luxonis_train.loaders import collate_fn @@ -40,22 +39,12 @@ def build_batch_element(): inputs[name] = torch.rand(shape, dtype=torch.float32) labels = { - "classification": ( - torch.randint(0, 2, (2,), dtype=torch.int64), - TaskType.CLASSIFICATION, - ), - "segmentation": ( - torch.randint(0, 2, (1, 224, 224), dtype=torch.int64), - TaskType.SEGMENTATION, - ), - "keypoints": ( - torch.rand(1, 52, dtype=torch.float32), - TaskType.KEYPOINTS, - ), - "boundingbox": ( - torch.rand(1, 5, dtype=torch.float32), - TaskType.BOUNDINGBOX, + "/classification": (torch.randint(0, 2, (2,), dtype=torch.int64)), + "/segmentation": ( + torch.randint(0, 2, (1, 224, 224), dtype=torch.int64) ), + "/keypoints": (torch.rand(1, 52, dtype=torch.float32)), + "/boundingbox": (torch.rand(1, 5, dtype=torch.float32)), } return inputs, labels @@ -69,26 +58,26 @@ def build_batch_element(): assert inputs["features"].dtype == torch.float32 with subtests.test("classification"): - assert "classification" in annotations - assert annotations["classification"][0].shape == (batch_size, 2) - assert annotations["classification"][0].dtype == torch.int64 + assert "/classification" in annotations + assert annotations["/classification"].shape == (batch_size, 2) + assert annotations["/classification"].dtype == torch.int64 with subtests.test("segmentation"): - assert "segmentation" in annotations - assert annotations["segmentation"][0].shape == ( + assert "/segmentation" in annotations + assert annotations["/segmentation"].shape == ( batch_size, 1, 224, 224, ) - assert annotations["segmentation"][0].dtype == torch.int64 + assert annotations["/segmentation"].dtype == torch.int64 with subtests.test("keypoints"): - assert "keypoints" in annotations - assert annotations["keypoints"][0].shape == (batch_size, 53) - assert annotations["keypoints"][0].dtype == torch.float32 + assert "/keypoints" in annotations + assert annotations["/keypoints"].shape == (batch_size, 53) + assert annotations["/keypoints"].dtype == torch.float32 with subtests.test("boundingbox"): - assert "boundingbox" in annotations - assert annotations["boundingbox"][0].shape == (batch_size, 6) - assert annotations["boundingbox"][0].dtype == torch.float32 + assert "/boundingbox" in annotations + assert annotations["/boundingbox"].shape == (batch_size, 6) + assert annotations["/boundingbox"].dtype == torch.float32 From 1de6f7413fc617c10df15696fffaf75dd16bc105 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 14 Jan 2025 07:47:55 -0600 Subject: [PATCH 16/31] fixed predefined classification --- luxonis_train/config/predefined_models/classification_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/luxonis_train/config/predefined_models/classification_model.py b/luxonis_train/config/predefined_models/classification_model.py index 86964e0a..dc67a96e 100644 --- a/luxonis_train/config/predefined_models/classification_model.py +++ b/luxonis_train/config/predefined_models/classification_model.py @@ -66,7 +66,7 @@ def __init__( self.loss_params = loss_params or {} self.visualizer_params = visualizer_params or {} self.task = task - self.task_name = task_name or "classification" + self.task_name = task_name or "" @property def nodes(self) -> list[ModelNodeConfig]: From 8c320146c86cf755aa031a16a281431e2c04c9b8 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 14 Jan 2025 07:55:18 -0600 Subject: [PATCH 17/31] docs --- luxonis_train/nodes/README.md | 2 +- luxonis_train/nodes/base_node.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/luxonis_train/nodes/README.md b/luxonis_train/nodes/README.md index 31f1f6c2..92f1969e 100644 --- a/luxonis_train/nodes/README.md +++ b/luxonis_train/nodes/README.md @@ -40,7 +40,7 @@ In addition, the following class attributes can be overridden: | Key | Type | Default value | Description | | -------------- | ----------------------------------------------------------------- | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | `attach_index` | `int \| "all" \| tuple[int, int] \| tuple[int, int, int] \| None` | `None` | Index of previous output that the head attaches to. Each node has a sensible default. Usually should not be manually set in most cases. Can be either a single index, a slice (negative indexing is also supported), or `"all"` | -| `tasks` | `list[TaskType] \| Dict[TaskType, str] \| None` | `None` | Tasks supported by the node. Should be overridden for head nodes. Either a list of tasks or a dictionary mapping tasks to their default names | +| `tasks` | `list[TaskType] \| None` | `None` | List of tasks types supported by the node. Should be overridden for head nodes. | Additional parameters for specific nodes are listed below. diff --git a/luxonis_train/nodes/base_node.py b/luxonis_train/nodes/base_node.py index 19ec33a7..7dcdbcf3 100644 --- a/luxonis_train/nodes/base_node.py +++ b/luxonis_train/nodes/base_node.py @@ -93,18 +93,16 @@ def wrap(output: Tensor) -> Packet[Tensor]: # by the attached modules. return {"classification": [output]} - @type attach_index: AttachIndexType @ivar attach_index: Index of previous output that this node attaches to. Can be a single integer to specify a single output, a tuple of two or three integers to specify a range of outputs or C{"all"} to specify all outputs. Defaults to "all". Python indexing conventions apply. - @type tasks: list[TaskType] | dict[TaskType, str] | None - @ivar tasks: Dictionary of tasks that the node supports. Should be defined - by the user as a class attribute. The key is the task type and the value - is the name of the task. For example: - C{{TaskType.CLASSIFICATION: "classification"}}. + @type tasks: list[TaskType] | None + @ivar tasks: List of task types that the node supports. + Should be defined as a class attribute by the user. + For example C{[TaskType.CLASSIFICATION]}. Only needs to be defined for head nodes. """ @@ -154,6 +152,10 @@ def __init__( of outputs or C{"all"} to specify all outputs. Defaults to "all". Python indexing conventions apply. If provided as a constructor argument, overrides the class attribute. + @type task_name: str | None + @param task_name: Specifies which task group from the dataset to use + in case the dataset contains multiple tasks. Otherwise, the + task group is inferred from the dataset metadata. """ super().__init__() From 8d7685bb45cc7773e2a958f3894ba8225267cab4 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 14 Jan 2025 08:59:53 -0600 Subject: [PATCH 18/31] fix inspect --- luxonis_train/__main__.py | 55 ++++++++------------------------------- 1 file changed, 11 insertions(+), 44 deletions(-) diff --git a/luxonis_train/__main__.py b/luxonis_train/__main__.py index c0aae2dc..5fd11b1e 100644 --- a/luxonis_train/__main__.py +++ b/luxonis_train/__main__.py @@ -3,11 +3,10 @@ from pathlib import Path from typing import Annotated +import numpy as np import typer from luxonis_ml.utils import setup_logging -from luxonis_train.config import Config - setup_logging(use_rich=True) @@ -175,47 +174,16 @@ def inspect( To close the window press 'q' or 'Esc'. """ import cv2 - from luxonis_ml.data import Augmentations, LabelType from luxonis_ml.data.utils.visualizations import visualize - from luxonis_train.utils.registry import LOADERS - - cfg = Config.get_config(config, opts) - train_augmentations = Augmentations( - image_size=cfg.trainer.preprocessing.train_image_size, - augmentations=[ - i.model_dump() - for i in cfg.trainer.preprocessing.get_active_augmentations() - if i.name != "Normalize" - ], - train_rgb=cfg.trainer.preprocessing.train_rgb, - keep_aspect_ratio=cfg.trainer.preprocessing.keep_aspect_ratio, - ) - val_augmentations = Augmentations( - image_size=cfg.trainer.preprocessing.train_image_size, - augmentations=[ - i.model_dump() - for i in cfg.trainer.preprocessing.get_active_augmentations() - ], - train_rgb=cfg.trainer.preprocessing.train_rgb, - keep_aspect_ratio=cfg.trainer.preprocessing.keep_aspect_ratio, - only_normalize=True, - ) + from luxonis_train.core import LuxonisModel - Loader = LOADERS.get(cfg.loader.name) - loader = Loader( - augmentations=( - train_augmentations if view == "train" else val_augmentations - ), - view={ - "train": cfg.loader.train_view, - "val": cfg.loader.val_view, - "test": cfg.loader.test_view, - }[view], - image_source=cfg.loader.image_source, - **cfg.loader.params, - ) + opts = opts or [] + opts.extend(["trainer.preprocessing.normalize.active", "False"]) + + model = LuxonisModel(config, opts) + loader = model.loaders[view.value] for images, labels in loader: for img in images.values(): if len(img.shape) != 3: @@ -226,11 +194,10 @@ def inspect( k: v.numpy().transpose(1, 2, 0) for k, v in images.items() } main_image = np_images[loader.image_source] - main_image = cv2.cvtColor(main_image, cv2.COLOR_RGB2BGR) - np_labels = { - task: (label.numpy(), LabelType(task_type)) - for task, (label, task_type) in labels.items() - } + main_image = cv2.cvtColor(main_image, cv2.COLOR_RGB2BGR).astype( + np.uint8 + ) + np_labels = {task: label.numpy() for task, label in labels.items()} h, w, _ = main_image.shape new_h, new_w = int(h * size_multiplier), int(w * size_multiplier) From fbbbc26cd44fdbbc433587ae45f3a23b93b51aa1 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Wed, 15 Jan 2025 18:20:31 -0500 Subject: [PATCH 19/31] fixed tests --- .../reconstruction_segmentation_loss.py | 64 ++++--------- .../predefined_models/classification_model.py | 4 +- .../predefined_models/detection_model.py | 4 +- .../keypoint_detection_model.py | 4 +- .../predefined_models/segmentation_model.py | 4 +- luxonis_train/core/utils/infer_utils.py | 12 +-- luxonis_train/loaders/base_loader.py | 9 +- luxonis_train/loaders/luxonis_loader_torch.py | 11 ++- .../loaders/luxonis_perlin_loader_torch.py | 96 ++++++++----------- luxonis_train/loaders/perlin.py | 16 ++-- luxonis_train/models/luxonis_lightning.py | 3 +- tests/configs/smart_cfg_populate_config.yaml | 1 + tests/integration/parking_lot.json | 17 ++-- .../test_unsupervised_anomaly_detection.py | 23 +++-- .../test_metrics/test_confusion_matrix.py | 3 +- 15 files changed, 123 insertions(+), 148 deletions(-) diff --git a/luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py b/luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py index 1937ad0b..a5b50f2b 100644 --- a/luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py +++ b/luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py @@ -1,18 +1,14 @@ import logging from math import exp -from typing import Any, Literal, Union +from typing import Literal import torch -import torch.nn as nn import torch.nn.functional as F -from torch import Tensor +from torch import Tensor, nn from luxonis_train.enums import TaskType from luxonis_train.nodes import DiscSubNetHead -from luxonis_train.utils import ( - Labels, - Packet, -) +from luxonis_train.utils import Labels, Packet from .base_loss import BaseLoss from .softmax_focal_loss import SoftmaxFocalLoss @@ -30,7 +26,7 @@ def __init__( gamma: float = 2.0, reduction: Literal["none", "mean", "sum"] = "mean", smooth: float = 1e-5, - **kwargs: Any, + **kwargs, ): """ReconstructionSegmentationLoss implements a combined loss function for reconstruction and segmentation tasks. @@ -55,28 +51,17 @@ def __init__( self.loss_ssim = SSIM() def prepare( - self, - inputs: Packet[Tensor], - labels: Labels, + self, inputs: Packet[Tensor], labels: Labels ) -> tuple[Tensor, Tensor, Tensor, Tensor]: recon = self.get_input_tensors(inputs, "reconstructed")[0] - seg_out = self.get_input_tensors(inputs, "segmentation")[0] - an_mask = labels["segmentation"][0] - orig = labels["original"][0] - - return ( - orig, - recon, - seg_out, - an_mask, - ) + seg_out = self.get_input_tensors(inputs)[0] + an_mask = self.get_label(labels) + orig = labels[f"{self.node.task_name}/original/segmentation"] + + return orig, recon, seg_out, an_mask def forward( - self, - orig: Tensor, - recon: Tensor, - seg_out: Tensor, - an_mask: Tensor, + self, orig: Tensor, recon: Tensor, seg_out: Tensor, an_mask: Tensor ): l2 = self.loss_l2(recon, orig) ssim = self.loss_ssim(recon, orig) @@ -93,14 +78,14 @@ def forward( return total_loss, sub_losses -class SSIM(torch.nn.Module): +class SSIM(nn.Module): def __init__( self, window_size: int = 11, size_average: bool = True, val_range: float | None = None, ): - super(SSIM, self).__init__() + super().__init__() self.window_size = window_size self.size_average = size_average self.val_range = val_range @@ -123,7 +108,7 @@ def forward(self, img1: Tensor, img2: Tensor) -> Tensor: self.window = window self.channel = channel - s_score, ssim_map = ssim( + s_score = ssim( img1, img2, window=window, @@ -134,11 +119,9 @@ def forward(self, img1: Tensor, img2: Tensor) -> Tensor: def create_window(window_size: int, channel: int = 1) -> Tensor: - _1D_window = gaussian(window_size, 1.5).unsqueeze(1) - _2D_window = ( - _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) - ) - window = _2D_window.expand( + window_1d = gaussian(window_size, 1.5).unsqueeze(1) + widnow_2d = window_1d.mm(window_1d.t()).float().unsqueeze(0).unsqueeze(0) + window = widnow_2d.expand( channel, 1, window_size, window_size ).contiguous() return window @@ -160,9 +143,8 @@ def ssim( window_size: int = 11, window: Tensor | None = None, size_average=True, - full=False, val_range=None, -) -> Union[Tensor, tuple[Tensor, Tensor]]: +) -> Tensor: if val_range is None: if torch.max(img1) > 128: max_val = 255 @@ -205,15 +187,9 @@ def ssim( v1 = 2.0 * sigma12 + c2 v2 = sigma1_sq + sigma2_sq + c2 - cs = torch.mean(v1 / v2) # contrast sensitivity ssim_map = ((2 * mu1_mu2 + c1) * v1) / ((mu1_sq + mu2_sq + c1) * v2) if size_average: - ret = ssim_map.mean() - else: - ret = ssim_map.mean(1).mean(1).mean(1) - - if full: - return ret, cs - return ret, ssim_map + return ssim_map.mean() + return ssim_map.mean(1).mean(1).mean(1) diff --git a/luxonis_train/config/predefined_models/classification_model.py b/luxonis_train/config/predefined_models/classification_model.py index a7405890..7508a278 100644 --- a/luxonis_train/config/predefined_models/classification_model.py +++ b/luxonis_train/config/predefined_models/classification_model.py @@ -129,8 +129,8 @@ def metrics(self) -> list[MetricModuleConfig]: metrics.append( MetricModuleConfig( name="ConfusionMatrix", - alias=f"ConfusionMatrix-{self.task_name}", - attached_to=f"ClassificationHead-{self.task_name}", + alias=f"{self.task_name}/ConfusionMatrix", + attached_to=f"{self.task_name}/ClassificationHead", params={**self.confusion_matrix_params}, ) ) diff --git a/luxonis_train/config/predefined_models/detection_model.py b/luxonis_train/config/predefined_models/detection_model.py index 3c1cabb9..0a827d36 100644 --- a/luxonis_train/config/predefined_models/detection_model.py +++ b/luxonis_train/config/predefined_models/detection_model.py @@ -149,8 +149,8 @@ def metrics(self) -> list[MetricModuleConfig]: metrics.append( MetricModuleConfig( name="ConfusionMatrix", - alias=f"ConfusionMatrix-{self.task_name}", - attached_to=f"EfficientBBoxHead-{self.task_name}", + alias=f"{self.task_name}/ConfusionMatrix", + attached_to=f"{self.task_name}/EfficientBBoxHead", params={**self.confusion_matrix_params}, ) ) diff --git a/luxonis_train/config/predefined_models/keypoint_detection_model.py b/luxonis_train/config/predefined_models/keypoint_detection_model.py index d5706d40..8144de46 100644 --- a/luxonis_train/config/predefined_models/keypoint_detection_model.py +++ b/luxonis_train/config/predefined_models/keypoint_detection_model.py @@ -153,8 +153,8 @@ def metrics(self) -> list[MetricModuleConfig]: metrics.append( MetricModuleConfig( name="ConfusionMatrix", - alias=f"ConfusionMatrix-{self.kpt_task_name}", - attached_to=f"EfficientKeypointBBoxHead-{self.kpt_task_name}", + alias=f"{self.task_name}/ConfusionMatrix", + attached_to=f"{self.task_name}/EfficientKeypointBBoxHead", params={**self.confusion_matrix_params}, ) ) diff --git a/luxonis_train/config/predefined_models/segmentation_model.py b/luxonis_train/config/predefined_models/segmentation_model.py index 9fab84c1..54beeb1f 100644 --- a/luxonis_train/config/predefined_models/segmentation_model.py +++ b/luxonis_train/config/predefined_models/segmentation_model.py @@ -166,8 +166,8 @@ def metrics(self) -> list[MetricModuleConfig]: metrics.append( MetricModuleConfig( name="ConfusionMatrix", - alias=f"ConfusionMatrix-{self.task_name}", - attached_to=f"DDRNetSegmentationHead-{self.task_name}", + alias=f"{self.task_name}/ConfusionMatrix", + attached_to=f"{self.task_name}/DDRNetSegmentationHead", params={**self.confusion_matrix_params}, ) ) diff --git a/luxonis_train/core/utils/infer_utils.py b/luxonis_train/core/utils/infer_utils.py index b4996751..7f009d8f 100644 --- a/luxonis_train/core/utils/infer_utils.py +++ b/luxonis_train/core/utils/infer_utils.py @@ -47,12 +47,10 @@ def process_visualizations( def prepare_and_infer_image( - model: "luxonis_train.core.LuxonisModel", - img: np.ndarray, + model: "luxonis_train.core.LuxonisModel", img: Tensor ): """Prepares the image for inference and runs the model.""" - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - img, _ = model.val_augmentations([(img, {})]) + img = model.loaders["val"].augment_test_image(img) # type: ignore inputs = { "image": torch.tensor(img).unsqueeze(0).permute(0, 3, 1, 2).float() @@ -94,9 +92,11 @@ def infer_from_video( ret, frame = cap.read() if not ret: # pragma: no cover break + if model.cfg.trainer.preprocessing.color_format == "RGB": + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # TODO: batched inference - outputs = prepare_and_infer_image(model, frame) + outputs = prepare_and_infer_image(model, torch.tensor(frame)) renders = process_visualizations(outputs.visualizations, batch_size=1) for (node_name, viz_name), [viz] in renders.items(): @@ -213,8 +213,8 @@ def generator(): dataset_name=dataset_name, image_source="image", view="test", - augmentations=model.val_augmentations, ) + loader.loader.augmentations = model.loaders["val"].loader.augmentations # type: ignore loader = torch_data.DataLoader( loader, batch_size=model.cfg.trainer.batch_size ) diff --git a/luxonis_train/loaders/base_loader.py b/luxonis_train/loaders/base_loader.py index 3e3589dd..a9998e4b 100644 --- a/luxonis_train/loaders/base_loader.py +++ b/luxonis_train/loaders/base_loader.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from luxonis_ml.utils.registry import AutoRegisterMeta -from torch import Size +from torch import Size, Tensor from torch.utils.data import Dataset from luxonis_train.utils.registry import LOADERS @@ -83,6 +83,13 @@ def input_shape(self) -> Size: """ return self.input_shapes[self.image_source] + def augment_test_image(self, img: Tensor) -> Tensor: + raise NotImplementedError( + f"{self.__class__.__name__} does not expose interface " + "for test-time augmentation. Implement " + "`augment_test_image` method to expose this functionality." + ) + @abstractmethod def __len__(self) -> int: """Returns length of the dataset.""" diff --git a/luxonis_train/loaders/luxonis_loader_torch.py b/luxonis_train/loaders/luxonis_loader_torch.py index 445fe641..d6efe97b 100644 --- a/luxonis_train/loaders/luxonis_loader_torch.py +++ b/luxonis_train/loaders/luxonis_loader_torch.py @@ -3,6 +3,7 @@ from typing import Any, Literal import numpy as np +import torch from luxonis_ml.data import ( BucketStorage, BucketType, @@ -123,13 +124,19 @@ def __getitem__(self, idx: int) -> LuxonisLoaderTorchOutput: return {self.image_source: tensor_img}, tensor_labels def get_classes(self) -> dict[str, list[str]]: - _, classes = self.dataset.get_classes() - return {task: classes[task] for task in classes} + return self.dataset.get_classes() def get_n_keypoints(self) -> dict[str, int]: skeletons = self.dataset.get_skeletons() return {task: len(skeletons[task][0]) for task in skeletons} + def augment_test_image(self, img: Tensor) -> Tensor: + if self.loader.augmentations is None: + return img + return torch.tensor( + self.loader.augmentations.apply([(img.numpy(), {})])[0] + ) + def _parse_dataset( self, dataset_dir: str, diff --git a/luxonis_train/loaders/luxonis_perlin_loader_torch.py b/luxonis_train/loaders/luxonis_perlin_loader_torch.py index c1660ffb..05e2c449 100644 --- a/luxonis_train/loaders/luxonis_perlin_loader_torch.py +++ b/luxonis_train/loaders/luxonis_perlin_loader_torch.py @@ -1,16 +1,13 @@ -import glob -import os import random -from typing import Callable, List +from typing import cast import numpy as np import torch import torch.nn.functional as F +from luxonis_ml.data import AlbumentationsEngine from luxonis_ml.utils import LuxonisFileSystem from torch import Tensor -from luxonis_train.enums import TaskType - from .base_loader import LuxonisLoaderTorchOutput from .luxonis_loader_torch import LuxonisLoaderTorch from .perlin import apply_anomaly_to_img @@ -32,53 +29,45 @@ def __init__( @param noise_prob: The probability with which to apply Perlin noise (only used during training). """ - if not anomaly_source_path: - raise ValueError("anomaly_source_path must be a valid string.") - super().__init__(*args, **kwargs) - lux_fs = LuxonisFileSystem(path=anomaly_source_path) - if lux_fs.protocol in ["s3", "gcs"]: - anomaly_source_path = str( - lux_fs.get_dir( - remote_paths=[anomaly_source_path], local_dir="./data" - ) + + try: + self.anomaly_source_path = LuxonisFileSystem.download( + anomaly_source_path, dest="./data" ) - else: - anomaly_source_path = str(lux_fs.path) + except Exception as e: + raise FileNotFoundError( + "The anomaly source path is invalid." + ) from e + + from luxonis_train.core.utils.infer_utils import IMAGE_FORMATS - if anomaly_source_path and os.path.exists(anomaly_source_path): - self.anomaly_source_paths = sorted( - glob.glob(os.path.join(anomaly_source_path, "*/*.jpg")) + self.anomaly_files = [ + f + for f in self.anomaly_source_path.rglob("*") + if f.suffix.lower() in IMAGE_FORMATS + ] + if not self.anomaly_files: + raise FileNotFoundError( + "No image files found at the specified path." ) - if not self.anomaly_source_paths: - raise FileNotFoundError( - "No .jpg files found at the specified path." - ) - else: - raise ValueError("Invalid or unspecified anomaly source path.") - self.anomaly_source_path = anomaly_source_path self.noise_prob = noise_prob - self.base_loader.add_background = True # type: ignore - self.base_loader.class_mappings["segmentation"]["background"] = 0 - self.base_loader.class_mappings["segmentation"] = { - k: (v + 1 if k != "background" else v) - for k, v in self.base_loader.class_mappings["segmentation"].items() - } - - if ( - self.augmentations is None - or self.augmentations.pixel_transform is None - ): - self.pixel_augs: List[Callable] = [] + if len(self.loader.dataset.get_tasks()) > 1: + # TODO: Can be extended to multiple tasks + raise ValueError( + "This loader only supports datasets with a single task." + ) + self.task_name = next(iter(self.loader.dataset.get_tasks())) + + augmentations = cast(AlbumentationsEngine, self.loader.augmentations) + if augmentations is None or augmentations.pixel_transform is None: + self.pixel_augs = None else: - self.pixel_augs: List[Callable] = [ - transform - for transform in self.augmentations.pixel_transform.transforms - ] + self.pixel_augs = augmentations.pixel_transform def __getitem__(self, idx: int) -> LuxonisLoaderTorchOutput: - img, labels = self.base_loader[idx] + img, labels = self.loader[idx] img = np.transpose(img, (2, 0, 1)) tensor_img = Tensor(img) @@ -86,7 +75,7 @@ def __getitem__(self, idx: int) -> LuxonisLoaderTorchOutput: if self.view[0] == "train" and random.random() < self.noise_prob: aug_tensor_img, an_mask = apply_anomaly_to_img( tensor_img, - anomaly_source_paths=self.anomaly_source_paths, + anomaly_source_paths=self.anomaly_files, pixel_augs=self.pixel_augs, ) else: @@ -94,17 +83,12 @@ def __getitem__(self, idx: int) -> LuxonisLoaderTorchOutput: h, w = aug_tensor_img.shape[-2:] an_mask = torch.zeros((h, w)) - tensor_labels = {"original": (tensor_img, TaskType.ARRAY)} - if self.view[0] == "train": - tensor_labels["segmentation"] = ( - F.one_hot(an_mask.long(), 2).permute(2, 0, 1).float(), - TaskType.SEGMENTATION, - ) - else: - for task, (array, label_type) in labels.items(): - tensor_labels[task] = ( - Tensor(array), - TaskType(label_type.value), - ) + tensor_labels = {f"{self.task_name}/original/segmentation": tensor_img} + for task, array in labels.items(): + tensor_labels[task] = Tensor(array) + + tensor_labels[f"{self.task_name}/segmentation"] = ( + F.one_hot(an_mask.long(), 2).permute(2, 0, 1).float() + ) return {self.image_source: aug_tensor_img}, tensor_labels diff --git a/luxonis_train/loaders/perlin.py b/luxonis_train/loaders/perlin.py index a54e3d7c..687fd015 100644 --- a/luxonis_train/loaders/perlin.py +++ b/luxonis_train/loaders/perlin.py @@ -1,4 +1,5 @@ import random +from pathlib import Path from typing import Callable, List, Tuple import cv2 @@ -135,17 +136,17 @@ def generate_perlin_noise( return perlin_mask -def load_image_as_numpy(img_path: str) -> np.ndarray: - image = cv2.imread(img_path, cv2.IMREAD_COLOR) +def load_image_as_numpy(img_path: Path) -> np.ndarray: + image = cv2.imread(str(img_path), cv2.IMREAD_COLOR) image = image.astype(np.float32) / 255.0 return image def apply_anomaly_to_img( img: torch.Tensor, - anomaly_source_paths: List[str], + anomaly_source_paths: List[Path], beta: float | None = None, - pixel_augs: List[Callable] | None = None, + pixel_augs: Callable | None = None, # type: ignore ) -> Tuple[torch.Tensor, torch.Tensor]: """Applies Perlin noise-based anomalies to a single image (C, H, W). @@ -165,7 +166,9 @@ def apply_anomaly_to_img( """ if pixel_augs is None: - pixel_augs = [] + + def pixel_augs(image): + return {"image": image} sampled_anomaly_image_path = random.choice(anomaly_source_paths) @@ -177,8 +180,7 @@ def apply_anomaly_to_img( interpolation=cv2.INTER_LINEAR, ) - for aug in pixel_augs: - anomaly_image = aug(image=anomaly_image)["image"] + anomaly_image = pixel_augs(image=anomaly_image)["image"] anomaly_image = torch.tensor(anomaly_image).permute(2, 0, 1) diff --git a/luxonis_train/models/luxonis_lightning.py b/luxonis_train/models/luxonis_lightning.py index afd422ae..c478c126 100644 --- a/luxonis_train/models/luxonis_lightning.py +++ b/luxonis_train/models/luxonis_lightning.py @@ -989,7 +989,7 @@ def _init_attached_module( loader = self._core.loaders["train"] dataset = getattr(loader, "dataset", None) if isinstance(dataset, LuxonisDataset): - n_classes = len(dataset.get_classes()[1][node.task_name]) + n_classes = len(dataset.get_classes()[node.task_name]) if n_classes == 1: cfg.params["task"] = "binary" else: @@ -1034,7 +1034,6 @@ def _print_results( ) if self.main_metric is not None: - print(self.main_metric) *main_metric_node, main_metric_name = self.main_metric.split("/") main_metric_node = "/".join(main_metric_node) diff --git a/tests/configs/smart_cfg_populate_config.yaml b/tests/configs/smart_cfg_populate_config.yaml index 0c81752b..87f8a96d 100644 --- a/tests/configs/smart_cfg_populate_config.yaml +++ b/tests/configs/smart_cfg_populate_config.yaml @@ -9,6 +9,7 @@ model: - EfficientRep - name: EfficientBBoxHead + task_name: vehicle_type inputs: - RepPANNeck diff --git a/tests/integration/parking_lot.json b/tests/integration/parking_lot.json index 66ce962d..bf3e3835 100644 --- a/tests/integration/parking_lot.json +++ b/tests/integration/parking_lot.json @@ -96,7 +96,7 @@ "dtype": "float32", "shape": [ 1, - 7, + 8, 32, 40 ], @@ -107,7 +107,7 @@ "dtype": "float32", "shape": [ 1, - 7, + 8, 16, 20 ], @@ -118,11 +118,11 @@ "dtype": "float32", "shape": [ 1, - 7, + 8, 8, 10 ], - "layout": "NCHW" + "layout": "NCDE" } ], "heads": [ @@ -169,9 +169,9 @@ "metadata": { "postprocessor_path": null, "classes": [ - "background", + "motorbike", "car", - "motorbike" + "background" ], "n_classes": 3, "is_softmax": false @@ -185,9 +185,10 @@ "postprocessor_path": null, "classes": [ "motorbike", - "car" + "car", + "background" ], - "n_classes": 2, + "n_classes": 3, "iou_threshold": 0.45, "conf_threshold": 0.25, "max_det": 300, diff --git a/tests/integration/test_unsupervised_anomaly_detection.py b/tests/integration/test_unsupervised_anomaly_detection.py index 10c0a6f9..bf9347b5 100644 --- a/tests/integration/test_unsupervised_anomaly_detection.py +++ b/tests/integration/test_unsupervised_anomaly_detection.py @@ -11,7 +11,7 @@ PathType = Union[str, Path] -def get_opts() -> dict[str, Any]: +def get_config() -> dict[str, Any]: return { "model": { "name": "DREAM", @@ -38,12 +38,15 @@ def get_opts() -> dict[str, Any]: "keep_aspect_ratio": False, "normalize": {"active": True}, }, - "batch_size": 1, - "epochs": 1, + "batch_size": 4, + "epochs": 800, "num_workers": 0, "validation_interval": 10, "num_sanity_val_steps": 0, }, + "tracker": { + "save_directory": "tests/integration/save-directory", + }, } @@ -67,22 +70,19 @@ def dummy_generator( train_paths: List[PathType], test_paths: List[PathType] ): for path in train_paths: + img = cv2.imread(str(path)) + img_h, img_w, _ = img.shape + mask = np.zeros((img_h, img_w), dtype=np.uint8) yield { "file": path, "annotation": { "class": "object", - "segmentation": { - "height": 256, - "width": 256, - "counts": "0" * (256 * 256), - }, + "segmentation": {"mask": mask}, }, } for path in test_paths: img = cv2.imread(str(path)) - if img is None: - continue img_h, img_w, _ = img.shape mask = random_square_mask((img_h, img_w)) poly = cv2.findContours( @@ -130,7 +130,6 @@ def test_anomaly_detection(): create_dummy_anomaly_detection_dataset( Path("tests/data/COCO_people_subset/person_val2017_subset/*") ) - config = get_opts() - model = LuxonisModel(config) + model = LuxonisModel(get_config()) model.train() model.test() diff --git a/tests/unittests/test_metrics/test_confusion_matrix.py b/tests/unittests/test_metrics/test_confusion_matrix.py index 07429d8b..40a59dbb 100644 --- a/tests/unittests/test_metrics/test_confusion_matrix.py +++ b/tests/unittests/test_metrics/test_confusion_matrix.py @@ -11,8 +11,7 @@ def test_compute_detection_confusion_matrix_specific_case(): class DummyNodeDetection(BaseNode): tasks = [TaskType.BOUNDINGBOX] - def forward(self, _): - pass + def forward(self, _): ... metric = ConfusionMatrix( node=DummyNodeDetection(n_classes=3), iou_threshold=0.5 From bb5e88295a25a08e0725f53de6ef40033bce8112 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Wed, 15 Jan 2025 19:11:46 -0500 Subject: [PATCH 20/31] fix debug config --- tests/integration/test_unsupervised_anomaly_detection.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integration/test_unsupervised_anomaly_detection.py b/tests/integration/test_unsupervised_anomaly_detection.py index bf9347b5..be00cfb0 100644 --- a/tests/integration/test_unsupervised_anomaly_detection.py +++ b/tests/integration/test_unsupervised_anomaly_detection.py @@ -38,8 +38,8 @@ def get_config() -> dict[str, Any]: "keep_aspect_ratio": False, "normalize": {"active": True}, }, - "batch_size": 4, - "epochs": 800, + "batch_size": 1, + "epochs": 1, "num_workers": 0, "validation_interval": 10, "num_sanity_val_steps": 0, From 785f2f819d203492e8fbc2264563fd40e69b72a3 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Wed, 15 Jan 2025 19:14:02 -0500 Subject: [PATCH 21/31] updated perlin --- luxonis_train/loaders/perlin.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/luxonis_train/loaders/perlin.py b/luxonis_train/loaders/perlin.py index 687fd015..cd201d35 100644 --- a/luxonis_train/loaders/perlin.py +++ b/luxonis_train/loaders/perlin.py @@ -14,7 +14,7 @@ def compute_gradients(res: tuple[int, int]) -> torch.Tensor: @torch.jit.script -def lerp_torch( +def lerp_torch( # pragma: no cover x: torch.Tensor, y: torch.Tensor, w: torch.Tensor ) -> torch.Tensor: return (y - x) * w + x @@ -92,7 +92,7 @@ def rand_perlin_2d( @torch.jit.script -def rotate_noise(noise: torch.Tensor) -> torch.Tensor: +def rotate_noise(noise: torch.Tensor) -> torch.Tensor: # pragma: no cover angle = torch.rand(1) * 2 * torch.pi h, w = noise.shape center_y, center_x = h // 2, w // 2 @@ -165,11 +165,6 @@ def apply_anomaly_to_img( - perlin_mask (torch.Tensor): The Perlin noise mask applied to the image. """ - if pixel_augs is None: - - def pixel_augs(image): - return {"image": image} - sampled_anomaly_image_path = random.choice(anomaly_source_paths) anomaly_image = load_image_as_numpy(sampled_anomaly_image_path) @@ -180,7 +175,8 @@ def pixel_augs(image): interpolation=cv2.INTER_LINEAR, ) - anomaly_image = pixel_augs(image=anomaly_image)["image"] + if pixel_augs is not None: + anomaly_image = pixel_augs(image=anomaly_image)["image"] anomaly_image = torch.tensor(anomaly_image).permute(2, 0, 1) From c09336363f499a3f03c3f13aeab890c5a3fd193e Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Wed, 15 Jan 2025 19:18:07 -0500 Subject: [PATCH 22/31] missing doc --- luxonis_train/loaders/luxonis_loader_torch.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/luxonis_train/loaders/luxonis_loader_torch.py b/luxonis_train/loaders/luxonis_loader_torch.py index d6efe97b..d44c92ea 100644 --- a/luxonis_train/loaders/luxonis_loader_torch.py +++ b/luxonis_train/loaders/luxonis_loader_torch.py @@ -76,6 +76,39 @@ def __init__( view of the dataset. Each split is a string that represents a subset of the dataset. The available splits depend on the dataset, but usually include 'train', 'val', and 'test'. Defaults to 'train'. + @type augmentation_engine: Union[Literal["albumentations"], str] + @param augmentation_engine: The augmentation engine to use. + Defaults to C{"albumentations"}. + @type augmentation_config: Optional[Union[List[Dict[str, Any]], + PathType]] + @param augmentation_config: The configuration for the + augmentations. This can be either a list of C{Dict[str, Any]} or + a path to a configuration file. + The config member is a dictionary with two keys: C{name} and + C{params}. C{name} is the name of the augmentation to + instantiate and C{params} is an optional dictionary + of parameters to pass to the augmentation. + + Example:: + + [ + {"name": "HorizontalFlip", "params": {"p": 0.5}}, + {"name": "RandomBrightnessContrast", "params": {"p": 0.1}}, + {"name": "Defocus"} + ] + + @type height: Optional[int] + @param height: The height of the output images. Defaults to + C{None}. + @type width: Optional[int] + @param width: The width of the output images. Defaults to + C{None}. + @type keep_aspect_ratio: bool + @param keep_aspect_ratio: Whether to keep the aspect ratio of the + images. Defaults to C{True}. + @type out_image_format: Literal["RGB", "BGR"] + @param out_image_format: The format of the output images. Defaults + to C{"RGB"}. """ super().__init__(view=view, **kwargs) if dataset_dir is not None: From f2cdfa3ba3c884b2a7c10d1f718573f7c9eb815e Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Wed, 15 Jan 2025 19:21:17 -0500 Subject: [PATCH 23/31] reverted bacj to train_rgb --- configs/README.md | 18 +++++++++--------- configs/complex_model.yaml | 1 - luxonis_train/config/config.py | 2 +- luxonis_train/core/core.py | 6 ++++-- luxonis_train/core/utils/infer_utils.py | 2 +- tests/configs/cli_commands.yaml | 1 - tests/configs/parking_lot_config.yaml | 1 - 7 files changed, 15 insertions(+), 16 deletions(-) diff --git a/configs/README.md b/configs/README.md index 37d150f9..e5ae8c81 100644 --- a/configs/README.md +++ b/configs/README.md @@ -280,14 +280,14 @@ We use [`Albumentations`](https://albumentations.ai/docs/) library for `augmenta Additionally, we support `Mosaic4` and `MixUp` batch augmentations and letterbox resizing if `keep_aspect_ratio: true`. -| Key | Type | Default value | Description | -| ------------------- | ----------------------- | ------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `train_image_size` | `list[int]` | `[256, 256]` | Image size used for training as `[height, width]` | -| `keep_aspect_ratio` | `bool` | `True` | Whether to keep the aspect ratio while resizing | -| `color_format` | `Literal["RGB", "BGR"]` | `"RGB"` | Whether to train on RGB or BGR images | -| `normalize.active` | `bool` | `True` | Whether to use normalization | -| `normalize.params` | `dict` | `{}` | Parameters for normalization, see [Normalize](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Normalize) | -| `augmentations` | `list[dict]` | `[]` | List of `Albumentations` augmentations | +| Key | Type | Default value | Description | +| ------------------- | ------------ | ------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `train_image_size` | `list[int]` | `[256, 256]` | Image size used for training as `[height, width]` | +| `keep_aspect_ratio` | `bool` | `True` | Whether to keep the aspect ratio while resizing | +| `train_rgb` | `bool` | `"RGB"` | Whether to train on RGB or BGR images | +| `normalize.active` | `bool` | `True` | Whether to use normalization | +| `normalize.params` | `dict` | `{}` | Parameters for normalization, see [Normalize](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Normalize) | +| `augmentations` | `list[dict]` | `[]` | List of `Albumentations` augmentations | #### Augmentations @@ -306,7 +306,7 @@ trainer: # using YAML capture to reuse the image size train_image_size: [&height 384, &width 384] keep_aspect_ratio: true - color_format: "RGB" + train_rgb: true normalize: active: true augmentations: diff --git a/configs/complex_model.yaml b/configs/complex_model.yaml index fed25c23..ee8fa037 100644 --- a/configs/complex_model.yaml +++ b/configs/complex_model.yaml @@ -100,7 +100,6 @@ trainer: preprocessing: train_image_size: [&height 384, &width 384] keep_aspect_ratio: true - color_format: RGB normalize: active: true augmentations: diff --git a/luxonis_train/config/config.py b/luxonis_train/config/config.py index 014caaf2..3d3690e0 100644 --- a/luxonis_train/config/config.py +++ b/luxonis_train/config/config.py @@ -318,7 +318,7 @@ class PreprocessingConfig(BaseModelExtraForbid): ImageSize, Field(default=[256, 256], min_length=2, max_length=2) ] = ImageSize(256, 256) keep_aspect_ratio: bool = True - color_format: Literal["RGB", "BGR"] = "RGB" + train_rgb: bool = True normalize: NormalizeAugmentationConfig = NormalizeAugmentationConfig() augmentations: list[AugmentationConfig] = [] diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index 2fb59457..a0529030 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -132,7 +132,9 @@ def __init__( i.model_dump() for i in self.cfg.trainer.preprocessing.get_active_augmentations() ], - out_image_format=self.cfg.trainer.preprocessing.color_format, + out_image_format="RGB" + if self.cfg.trainer.preprocessing.train_rgb + else "BGR", keep_aspect_ratio=self.cfg.trainer.preprocessing.keep_aspect_ratio, **self.cfg.loader.params, ) @@ -721,7 +723,7 @@ def _mult(lst: list[float | int]) -> list[float]: self.cfg.trainer.preprocessing.normalize.params["std"] ), "dai_type": "RGB888p" - if self.cfg.trainer.preprocessing.color_format == "RGB" + if self.cfg.trainer.preprocessing.train_rgb else "BGR888p", } diff --git a/luxonis_train/core/utils/infer_utils.py b/luxonis_train/core/utils/infer_utils.py index 7f009d8f..26e994e5 100644 --- a/luxonis_train/core/utils/infer_utils.py +++ b/luxonis_train/core/utils/infer_utils.py @@ -92,7 +92,7 @@ def infer_from_video( ret, frame = cap.read() if not ret: # pragma: no cover break - if model.cfg.trainer.preprocessing.color_format == "RGB": + if model.cfg.trainer.preprocessing.train_rgb: frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # TODO: batched inference diff --git a/tests/configs/cli_commands.yaml b/tests/configs/cli_commands.yaml index 7123534c..b8e604ae 100644 --- a/tests/configs/cli_commands.yaml +++ b/tests/configs/cli_commands.yaml @@ -51,7 +51,6 @@ trainer: preprocessing: train_image_size: [256, 320] keep_aspect_ratio: true - color_format: RGB normalize: active: true diff --git a/tests/configs/parking_lot_config.yaml b/tests/configs/parking_lot_config.yaml index 2754197e..a5ee50cb 100644 --- a/tests/configs/parking_lot_config.yaml +++ b/tests/configs/parking_lot_config.yaml @@ -87,7 +87,6 @@ trainer: preprocessing: train_image_size: [256, 320] keep_aspect_ratio: false - color_format: RGB normalize: active: true augmentations: From e32f6ea7b93fd5c8f9660512ba1fe8f47a632c26 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Thu, 16 Jan 2025 17:55:32 -0500 Subject: [PATCH 24/31] fix type issues --- README.md | 2 +- configs/README.md | 18 +-- .../losses/adaptive_detection_loss.py | 6 +- luxonis_train/callbacks/__init__.py | 24 +-- .../callbacks/archive_on_train_end.py | 2 +- luxonis_train/callbacks/ema.py | 7 +- luxonis_train/callbacks/gpu_stats_monitor.py | 10 +- luxonis_train/callbacks/gradcam_visializer.py | 4 +- .../callbacks/luxonis_progress_bar.py | 4 +- luxonis_train/callbacks/metadata_logger.py | 2 +- luxonis_train/callbacks/test_on_train_end.py | 2 +- luxonis_train/callbacks/training_manager.py | 2 +- luxonis_train/config/config.py | 21 ++- luxonis_train/core/core.py | 33 ++-- luxonis_train/core/utils/infer_utils.py | 2 +- luxonis_train/loaders/base_loader.py | 148 ++++++++++++++++-- luxonis_train/loaders/luxonis_loader_torch.py | 4 +- luxonis_train/optimizers/optimizers.py | 2 +- luxonis_train/schedulers/schedulers.py | 2 +- luxonis_train/strategies/base_strategy.py | 2 +- luxonis_train/strategies/triple_lr_sgd.py | 2 +- tests/unittests/test_callbacks/test_ema.py | 2 +- 22 files changed, 216 insertions(+), 85 deletions(-) diff --git a/README.md b/README.md index 6cb85a00..10a69acd 100644 --- a/README.md +++ b/README.md @@ -595,7 +595,7 @@ from luxonis_train import LuxonisLightningModule from luxonis_train.utils.registry import CALLBACKS -@CALLBACKS.register_module() +@CALLBACKS.register() class CustomCallback(pl.Callback): def __init__(self, message: str, **kwargs): super().__init__(**kwargs) diff --git a/configs/README.md b/configs/README.md index e5ae8c81..b336ef98 100644 --- a/configs/README.md +++ b/configs/README.md @@ -280,14 +280,14 @@ We use [`Albumentations`](https://albumentations.ai/docs/) library for `augmenta Additionally, we support `Mosaic4` and `MixUp` batch augmentations and letterbox resizing if `keep_aspect_ratio: true`. -| Key | Type | Default value | Description | -| ------------------- | ------------ | ------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `train_image_size` | `list[int]` | `[256, 256]` | Image size used for training as `[height, width]` | -| `keep_aspect_ratio` | `bool` | `True` | Whether to keep the aspect ratio while resizing | -| `train_rgb` | `bool` | `"RGB"` | Whether to train on RGB or BGR images | -| `normalize.active` | `bool` | `True` | Whether to use normalization | -| `normalize.params` | `dict` | `{}` | Parameters for normalization, see [Normalize](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Normalize) | -| `augmentations` | `list[dict]` | `[]` | List of `Albumentations` augmentations | +| Key | Type | Default value | Description | +| ------------------- | ----------------------- | ------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `train_image_size` | `list[int]` | `[256, 256]` | Image size used for training as `[height, width]` | +| `keep_aspect_ratio` | `bool` | `True` | Whether to keep the aspect ratio while resizing | +| `color_space` | `Literal["RGB", "BGR"]` | `"RGB"` | Whether to train on RGB or BGR images | +| `normalize.active` | `bool` | `True` | Whether to use normalization | +| `normalize.params` | `dict` | `{}` | Parameters for normalization, see [Normalize](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Normalize) | +| `augmentations` | `list[dict]` | `[]` | List of `Albumentations` augmentations | #### Augmentations @@ -306,7 +306,7 @@ trainer: # using YAML capture to reuse the image size train_image_size: [&height 384, &width 384] keep_aspect_ratio: true - train_rgb: true + color_space: "RGB" normalize: active: true augmentations: diff --git a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py index a81d5a45..952d11c3 100644 --- a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py +++ b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py @@ -3,7 +3,7 @@ import torch import torch.nn.functional as F -from torch import Tensor, nn +from torch import Tensor, amp, nn from torchvision.ops import box_convert from luxonis_train.assigners import ATSSAssigner, TaskAlignedAssigner @@ -270,9 +270,7 @@ def forward( self.alpha * pred_score.pow(self.gamma) * (1 - label) + target_score * label ) - with torch.amp.autocast( - device_type=pred_score.device.type, enabled=False - ): + with amp.autocast(device_type=pred_score.device.type, enabled=False): ce_loss = F.binary_cross_entropy( pred_score.float(), target_score.float(), reduction="none" ) diff --git a/luxonis_train/callbacks/__init__.py b/luxonis_train/callbacks/__init__.py index 7bea71a9..d374916c 100644 --- a/luxonis_train/callbacks/__init__.py +++ b/luxonis_train/callbacks/__init__.py @@ -28,18 +28,18 @@ from .training_manager import TrainingManager from .upload_checkpoint import UploadCheckpoint -CALLBACKS.register_module(module=EarlyStopping) -CALLBACKS.register_module(module=LearningRateMonitor) -CALLBACKS.register_module(module=ModelCheckpoint) -CALLBACKS.register_module(module=RichModelSummary) -CALLBACKS.register_module(module=DeviceStatsMonitor) -CALLBACKS.register_module(module=GradientAccumulationScheduler) -CALLBACKS.register_module(module=StochasticWeightAveraging) -CALLBACKS.register_module(module=Timer) -CALLBACKS.register_module(module=ModelPruning) -CALLBACKS.register_module(module=GradCamCallback) -CALLBACKS.register_module(module=EMACallback) -CALLBACKS.register_module(module=TrainingManager) +CALLBACKS.register(module=EarlyStopping) +CALLBACKS.register(module=LearningRateMonitor) +CALLBACKS.register(module=ModelCheckpoint) +CALLBACKS.register(module=RichModelSummary) +CALLBACKS.register(module=DeviceStatsMonitor) +CALLBACKS.register(module=GradientAccumulationScheduler) +CALLBACKS.register(module=StochasticWeightAveraging) +CALLBACKS.register(module=Timer) +CALLBACKS.register(module=ModelPruning) +CALLBACKS.register(module=GradCamCallback) +CALLBACKS.register(module=EMACallback) +CALLBACKS.register(module=TrainingManager) __all__ = [ diff --git a/luxonis_train/callbacks/archive_on_train_end.py b/luxonis_train/callbacks/archive_on_train_end.py index 0ed69bb5..67e27ab7 100644 --- a/luxonis_train/callbacks/archive_on_train_end.py +++ b/luxonis_train/callbacks/archive_on_train_end.py @@ -10,7 +10,7 @@ logger = logging.getLogger(__name__) -@CALLBACKS.register_module() +@CALLBACKS.register() class ArchiveOnTrainEnd(NeedsCheckpoint): def on_train_end( self, diff --git a/luxonis_train/callbacks/ema.py b/luxonis_train/callbacks/ema.py index 20c01c04..e76e8c1a 100644 --- a/luxonis_train/callbacks/ema.py +++ b/luxonis_train/callbacks/ema.py @@ -3,10 +3,9 @@ from copy import deepcopy from typing import Any -import pytorch_lightning as pl +import lightning.pytorch as pl import torch -from pytorch_lightning.callbacks import Callback -from pytorch_lightning.utilities.types import STEP_OUTPUT +from lightning.pytorch.utilities.types import STEP_OUTPUT from torch import nn logger = logging.getLogger(__name__) @@ -101,7 +100,7 @@ def update(self, model: pl.LightningModule) -> None: ) -class EMACallback(Callback): +class EMACallback(pl.Callback): """Callback that updates the stored parameters using a moving average.""" diff --git a/luxonis_train/callbacks/gpu_stats_monitor.py b/luxonis_train/callbacks/gpu_stats_monitor.py index a189ed3f..c8774aa6 100644 --- a/luxonis_train/callbacks/gpu_stats_monitor.py +++ b/luxonis_train/callbacks/gpu_stats_monitor.py @@ -25,20 +25,20 @@ import time from typing import Any, Dict, List, Optional, Tuple -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from lightning.pytorch.accelerators.cuda import CUDAAccelerator +from lightning.pytorch.utilities import rank_zero_only +from lightning.pytorch.utilities.parsing import AttributeDict +from lightning.pytorch.utilities.types import STEP_OUTPUT from lightning_fabric.utilities.exceptions import ( MisconfigurationException, # noqa: F401 ) -from pytorch_lightning.utilities import rank_zero_only -from pytorch_lightning.utilities.parsing import AttributeDict -from pytorch_lightning.utilities.types import STEP_OUTPUT from luxonis_train.utils.registry import CALLBACKS -@CALLBACKS.register_module() +@CALLBACKS.register() class GPUStatsMonitor(pl.Callback): def __init__( self, diff --git a/luxonis_train/callbacks/gradcam_visializer.py b/luxonis_train/callbacks/gradcam_visializer.py index 28863502..1d9616f2 100644 --- a/luxonis_train/callbacks/gradcam_visializer.py +++ b/luxonis_train/callbacks/gradcam_visializer.py @@ -1,16 +1,16 @@ import logging from typing import Any, Union +import lightning.pytorch as pl import numpy as np -import pytorch_lightning as pl import torch +from lightning.pytorch.utilities.types import STEP_OUTPUT from pytorch_grad_cam import HiResCAM from pytorch_grad_cam.utils.image import show_cam_on_image from pytorch_grad_cam.utils.model_targets import ( ClassifierOutputTarget, SemanticSegmentationTarget, ) -from pytorch_lightning.utilities.types import STEP_OUTPUT from luxonis_train.attached_modules.visualizers import ( get_denormalized_images, diff --git a/luxonis_train/callbacks/luxonis_progress_bar.py b/luxonis_train/callbacks/luxonis_progress_bar.py index b8bf6512..20665ced 100644 --- a/luxonis_train/callbacks/luxonis_progress_bar.py +++ b/luxonis_train/callbacks/luxonis_progress_bar.py @@ -46,7 +46,7 @@ def print_results( ... -@CALLBACKS.register_module() +@CALLBACKS.register() class LuxonisTQDMProgressBar(TQDMProgressBar, BaseLuxonisProgressBar): """Custom text progress bar based on TQDMProgressBar from Pytorch Lightning.""" @@ -104,7 +104,7 @@ def print_results( self._rule() -@CALLBACKS.register_module() +@CALLBACKS.register() class LuxonisRichProgressBar(RichProgressBar, BaseLuxonisProgressBar): """Custom rich text progress bar based on RichProgressBar from Pytorch Lightning.""" diff --git a/luxonis_train/callbacks/metadata_logger.py b/luxonis_train/callbacks/metadata_logger.py index 997ccbcd..0d7e3905 100644 --- a/luxonis_train/callbacks/metadata_logger.py +++ b/luxonis_train/callbacks/metadata_logger.py @@ -10,7 +10,7 @@ from luxonis_train.utils.registry import CALLBACKS -@CALLBACKS.register_module() +@CALLBACKS.register() class MetadataLogger(pl.Callback): def __init__(self, hyperparams: list[str]): """Callback that logs training metadata. diff --git a/luxonis_train/callbacks/test_on_train_end.py b/luxonis_train/callbacks/test_on_train_end.py index a60a16dd..9a437ff4 100644 --- a/luxonis_train/callbacks/test_on_train_end.py +++ b/luxonis_train/callbacks/test_on_train_end.py @@ -5,7 +5,7 @@ from luxonis_train.utils.registry import CALLBACKS -@CALLBACKS.register_module() +@CALLBACKS.register() class TestOnTrainEnd(pl.Callback): """Callback to perform a test run at the end of the training.""" diff --git a/luxonis_train/callbacks/training_manager.py b/luxonis_train/callbacks/training_manager.py index 9131fa84..d9cc7002 100644 --- a/luxonis_train/callbacks/training_manager.py +++ b/luxonis_train/callbacks/training_manager.py @@ -1,4 +1,4 @@ -import pytorch_lightning as pl +import lightning.pytorch as pl from luxonis_train.strategies.base_strategy import BaseTrainingStrategy diff --git a/luxonis_train/config/config.py b/luxonis_train/config/config.py index 3d3690e0..a73737cb 100644 --- a/luxonis_train/config/config.py +++ b/luxonis_train/config/config.py @@ -4,6 +4,7 @@ from typing import Annotated, Any, Literal, NamedTuple, TypeAlias from luxonis_ml.enums import DatasetType +from luxonis_ml.typing import ConfigItem from luxonis_ml.utils import ( BaseModelExtraForbid, Environ, @@ -318,10 +319,20 @@ class PreprocessingConfig(BaseModelExtraForbid): ImageSize, Field(default=[256, 256], min_length=2, max_length=2) ] = ImageSize(256, 256) keep_aspect_ratio: bool = True - train_rgb: bool = True + color_space: Literal["RGB", "BGR"] = "RGB" normalize: NormalizeAugmentationConfig = NormalizeAugmentationConfig() augmentations: list[AugmentationConfig] = [] + @model_validator(mode="before") + @classmethod + def validate_train_rgb(cls, data: dict[str, Any]) -> dict[str, Any]: + if "train_rgb" in data: + warnings.warn( + "Field `train_rgb` is deprecated. Use `color_space` instead." + ) + data["color_space"] = "RGB" if data.pop("train_rgb") else "BGR" + return data + @model_validator(mode="after") def check_normalize(self) -> Self: if self.normalize.active: @@ -332,13 +343,17 @@ def check_normalize(self) -> Self: ) return self - def get_active_augmentations(self) -> list[AugmentationConfig]: + def get_active_augmentations(self) -> list[ConfigItem]: """Returns list of augmentations that are active. @rtype: list[AugmentationConfig] @return: Filtered list of active augmentation configs """ - return [aug for aug in self.augmentations if aug.active] + return [ + ConfigItem(name=aug.name, params=aug.params) + for aug in self.augmentations + if aug.active + ] class CallbackConfig(BaseModelExtraForbid): diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index a0529030..2ae78162 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -75,6 +75,8 @@ def __init__( else: self.cfg = Config.get_config(cfg, opts) + self.cfg_preprocessing = self.cfg.trainer.preprocessing + rich.traceback.install(suppress=[pl, torch], show_locals=False) self.tracker = LuxonisTrackerPL( @@ -126,16 +128,11 @@ def __init__( "test": self.cfg.loader.test_view, }[view], image_source=self.cfg.loader.image_source, - height=self.cfg.trainer.preprocessing.train_image_size.height, - width=self.cfg.trainer.preprocessing.train_image_size.width, - augmentation_config=[ - i.model_dump() - for i in self.cfg.trainer.preprocessing.get_active_augmentations() - ], - out_image_format="RGB" - if self.cfg.trainer.preprocessing.train_rgb - else "BGR", - keep_aspect_ratio=self.cfg.trainer.preprocessing.keep_aspect_ratio, + height=self.cfg_preprocessing.train_image_size.height, + width=self.cfg_preprocessing.train_image_size.width, + augmentation_config=self.cfg_preprocessing.get_active_augmentations(), + color_space=self.cfg_preprocessing.color_space, + keep_aspect_ratio=self.cfg_preprocessing.keep_aspect_ratio, **self.cfg.loader.params, ) @@ -598,9 +595,7 @@ def _objective(trial: optuna.trial.Trial) -> float: "You have to specify the `tuner` section in config." ) - all_augs = [ - a.name for a in self.cfg.trainer.preprocessing.augmentations - ] + all_augs = [a.name for a in self.cfg_preprocessing.augmentations] rank = rank_zero_only.rank cfg_tracker = self.cfg.tracker tracker_params = cfg_tracker.model_dump() @@ -716,15 +711,9 @@ def _mult(lst: list[float | int]) -> list[float]: return [round(x * 255.0, 5) for x in lst] preprocessing = { # TODO: keep preprocessing same for each input? - "mean": _mult( - self.cfg.trainer.preprocessing.normalize.params["mean"] - ), - "scale": _mult( - self.cfg.trainer.preprocessing.normalize.params["std"] - ), - "dai_type": "RGB888p" - if self.cfg.trainer.preprocessing.train_rgb - else "BGR888p", + "mean": _mult(self.cfg_preprocessing.normalize.params["mean"]), + "scale": _mult(self.cfg_preprocessing.normalize.params["std"]), + "dai_type": f"{self.cfg_preprocessing.color_space}888p", } inputs_dict = get_inputs(path) diff --git a/luxonis_train/core/utils/infer_utils.py b/luxonis_train/core/utils/infer_utils.py index 26e994e5..22cd7521 100644 --- a/luxonis_train/core/utils/infer_utils.py +++ b/luxonis_train/core/utils/infer_utils.py @@ -92,7 +92,7 @@ def infer_from_video( ret, frame = cap.read() if not ret: # pragma: no cover break - if model.cfg.trainer.preprocessing.train_rgb: + if model.cfg.trainer.preprocessing.color_space == "RGB": frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # TODO: batched inference diff --git a/luxonis_train/loaders/base_loader.py b/luxonis_train/loaders/base_loader.py index a9998e4b..19fc4dc6 100644 --- a/luxonis_train/loaders/base_loader.py +++ b/luxonis_train/loaders/base_loader.py @@ -1,5 +1,7 @@ from abc import ABC, abstractmethod +from typing import Any, Literal +from luxonis_ml.typing import ConfigItem from luxonis_ml.utils.registry import AutoRegisterMeta from torch import Size, Tensor from torch.utils.data import Dataset @@ -16,28 +18,136 @@ class BaseLoaderTorch( register=False, registry=LOADERS, ): - """Base abstract loader class that enforces LuxonisLoaderTorchOutput - output label structure.""" - def __init__( self, - view: str | list[str], - image_source: str | None = None, + view: list[str], + height: int, + width: int, + augmentation_engine: str = "albumentations", + augmentation_config: list[ConfigItem] | None = None, + image_source: str = "default", + keep_aspect_ratio: bool = True, + color_space: Literal["RGB", "BGR"] = "RGB", ): - self.view = view if isinstance(view, list) else [view] + """Base abstract loader class that enforces + LuxonisLoaderTorchOutput output label structure. + + @type view: list[str] + @param view: List of view names. Usually contains only one element, + e.g. C{["train"]} or C{["test"]}. However, more complex datasets + can make use of multiple views, e.g. C{["train_synthetic", + "train_real"]} + + @type height: int + @param height: Height of the output image. + + @type width: int + @param width: Width of the output image. + + @type augmentation_engine: str + @param augmentation_engine: Name of the augmentation engine. Can + be used to enable swapping between different augmentation engines or making use of pre-defined engines, e.g. C{AlbumentationsEngine}. + + @type augmentation_config: list[ConfigItem] | None + @param augmentation_config: List of augmentation configurations. + Individual configurations are in the form of:: + + class ConfigItem: + name: str + params: dict[str, JsonValue] + + Where C{name} is the name of the augmentation and C{params} is a + dictionary of its parameters. + + Example:: + + ConfigItem( + name="HorizontalFlip", + params={"p": 0.5}, + ) + + @type image_source: str + @param image_source: Name of the input image group. This can be used for datasets with multiple image sources, e.g. left and right cameras or RGB and depth images. Irrelevant for datasets with only one image source. + + @type keep_aspect_ratio: bool + @param keep_aspect_ratio: Whether to keep the aspect ratio of the output image after resizing. + + @type color_space: Literal["RGB", "BGR"] + @param color_space: Color space of the output image. + """ + self._view = view self._image_source = image_source + self._augmentation_engine = augmentation_engine + self._augmentation_config = augmentation_config + self._height = height + self._width = width + self._keep_aspect_ratio = keep_aspect_ratio + self._color_space = color_space @property def image_source(self) -> str: """Name of the input image group. - Example: C{"image"} + @type: str + """ + return self._getter_check_none("image_source") + + @property + def view(self) -> list[str]: + """List of view names. + + @type: list[str] + """ + return self._view + + @property + def augmentation_engine(self) -> str: + """Name of the augmentation engine. @type: str """ - if self._image_source is None: - raise ValueError("image_source is not set") - return self._image_source + return self._getter_check_none("augmentation_engine") + + @property + def augmentation_config(self) -> list[ConfigItem]: + """List of augmentation configurations. + + @type: list[ConfigItem] + """ + return self._getter_check_none("augmentation_config") + + @property + def height(self) -> int: + """Height of the output image. + + @type: int + """ + return self._getter_check_none("height") + + @property + def width(self) -> int: + """Width of the output image. + + @type: int + """ + return self._getter_check_none("width") + + @property + def keep_aspect_ratio(self) -> bool: + """Whether to keep the aspect ratio of the output image after + resizing. + + @type: bool + """ + return self._getter_check_none("keep_aspect_ratio") + + @property + def color_space(self) -> str: + """Color space of the output image. + + @type: str + """ + return self._getter_check_none("color_space") @property @abstractmethod @@ -124,3 +234,21 @@ def get_n_keypoints(self) -> dict[str, int] | None: definitions. """ return None + + def _getter_check_none( + self, + attribute: Literal[ + "view", + "image_source", + "augmentation_engine", + "augmentation_config", + "height", + "width", + "keep_aspect_ratio", + "color_space", + ], + ) -> Any: + value = getattr(self, f"_{attribute}") + if value is None: + raise ValueError(f"{attribute} is not set") + return value diff --git a/luxonis_train/loaders/luxonis_loader_torch.py b/luxonis_train/loaders/luxonis_loader_torch.py index d44c92ea..7514ff91 100644 --- a/luxonis_train/loaders/luxonis_loader_torch.py +++ b/luxonis_train/loaders/luxonis_loader_torch.py @@ -110,7 +110,9 @@ def __init__( @param out_image_format: The format of the output images. Defaults to C{"RGB"}. """ - super().__init__(view=view, **kwargs) + super().__init__( + view=view if isinstance(view, list) else [view], **kwargs + ) if dataset_dir is not None: self.dataset = self._parse_dataset( dataset_dir, dataset_name, dataset_type, delete_existing diff --git a/luxonis_train/optimizers/optimizers.py b/luxonis_train/optimizers/optimizers.py index c2a4bf12..42f63f8a 100644 --- a/luxonis_train/optimizers/optimizers.py +++ b/luxonis_train/optimizers/optimizers.py @@ -16,4 +16,4 @@ optim.RMSprop, optim.SGD, ]: - OPTIMIZERS.register_module(module=optimizer) + OPTIMIZERS.register(module=optimizer) diff --git a/luxonis_train/schedulers/schedulers.py b/luxonis_train/schedulers/schedulers.py index 488a7498..0497a3c3 100644 --- a/luxonis_train/schedulers/schedulers.py +++ b/luxonis_train/schedulers/schedulers.py @@ -19,4 +19,4 @@ lr_scheduler.OneCycleLR, lr_scheduler.CosineAnnealingWarmRestarts, ]: - SCHEDULERS.register_module(module=scheduler) + SCHEDULERS.register(module=scheduler) diff --git a/luxonis_train/strategies/base_strategy.py b/luxonis_train/strategies/base_strategy.py index 8de6386d..5b812ebf 100644 --- a/luxonis_train/strategies/base_strategy.py +++ b/luxonis_train/strategies/base_strategy.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod -import pytorch_lightning as pl +import lightning.pytorch as pl from luxonis_ml.utils.registry import AutoRegisterMeta from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler diff --git a/luxonis_train/strategies/triple_lr_sgd.py b/luxonis_train/strategies/triple_lr_sgd.py index 33f7dfe3..b76b5ee8 100644 --- a/luxonis_train/strategies/triple_lr_sgd.py +++ b/luxonis_train/strategies/triple_lr_sgd.py @@ -1,8 +1,8 @@ # strategies/triple_lr_sgd.py import math +import lightning.pytorch as pl import numpy as np -import pytorch_lightning as pl import torch from torch.optim import SGD from torch.optim.lr_scheduler import LambdaLR diff --git a/tests/unittests/test_callbacks/test_ema.py b/tests/unittests/test_callbacks/test_ema.py index 0780e783..d117eb88 100644 --- a/tests/unittests/test_callbacks/test_ema.py +++ b/tests/unittests/test_callbacks/test_ema.py @@ -2,7 +2,7 @@ import pytest import torch -from pytorch_lightning import LightningModule, Trainer +from lightning.pytorch import LightningModule, Trainer from luxonis_train.callbacks.ema import EMACallback, ModelEma From eef219a7132f0d927c9af759d903f9bf3dfaa2ae Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Thu, 16 Jan 2025 18:26:21 -0500 Subject: [PATCH 25/31] replaced deprecated `register_module` --- luxonis_train/callbacks/export_on_train_end.py | 2 +- luxonis_train/callbacks/upload_checkpoint.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/luxonis_train/callbacks/export_on_train_end.py b/luxonis_train/callbacks/export_on_train_end.py index 80d2a648..195524c7 100644 --- a/luxonis_train/callbacks/export_on_train_end.py +++ b/luxonis_train/callbacks/export_on_train_end.py @@ -10,7 +10,7 @@ logger = logging.getLogger(__name__) -@CALLBACKS.register_module() +@CALLBACKS.register() class ExportOnTrainEnd(NeedsCheckpoint): def on_train_end( self, diff --git a/luxonis_train/callbacks/upload_checkpoint.py b/luxonis_train/callbacks/upload_checkpoint.py index b9753e94..0954737a 100644 --- a/luxonis_train/callbacks/upload_checkpoint.py +++ b/luxonis_train/callbacks/upload_checkpoint.py @@ -10,7 +10,7 @@ from luxonis_train.utils.registry import CALLBACKS -@CALLBACKS.register_module() +@CALLBACKS.register() class UploadCheckpoint(pl.Callback): """Callback that uploads best checkpoint based on the validation loss.""" From 0379b2a21842556eb6aaabb204dabf9a75b48f35 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Thu, 16 Jan 2025 18:26:42 -0500 Subject: [PATCH 26/31] removed init arguments --- luxonis_train/loaders/luxonis_loader_torch.py | 110 ++++++------------ 1 file changed, 34 insertions(+), 76 deletions(-) diff --git a/luxonis_train/loaders/luxonis_loader_torch.py b/luxonis_train/loaders/luxonis_loader_torch.py index 7514ff91..be31e28e 100644 --- a/luxonis_train/loaders/luxonis_loader_torch.py +++ b/luxonis_train/loaders/luxonis_loader_torch.py @@ -1,6 +1,5 @@ import logging -from pathlib import Path -from typing import Any, Literal +from typing import Literal import numpy as np import torch @@ -31,88 +30,45 @@ def __init__( bucket_type: Literal["internal", "external"] = "internal", bucket_storage: Literal["local", "s3", "gcs", "azure"] = "local", delete_existing: bool = True, - view: str | list[str] = "train", - augmentation_engine: str - | Literal["albumentations"] = "albumentations", - augmentation_config: Path | str | list[dict[str, Any]] | None = None, - height: int | None = None, - width: int | None = None, - keep_aspect_ratio: bool = True, - out_image_format: Literal["RGB", "BGR"] = "RGB", **kwargs, ): """Torch-compatible loader for Luxonis datasets. - Can either use an already existing dataset or parse a new one from a directory. + Can either use an already existing dataset or parse a new one + from a directory. @type dataset_name: str | None - @param dataset_name: Name of the dataset to load. If not provided, the - C{dataset_dir} argument must be provided instead. If both C{dataset_dir} and - C{dataset_name} are provided, the dataset will be parsed from the directory - and saved with the provided name. + @param dataset_name: Name of the dataset to load. If not + provided, the C{dataset_dir} argument must be provided + instead. If both C{dataset_dir} and C{dataset_name} are + provided, the dataset will be parsed from the directory and + saved with the provided name. @type dataset_dir: str | None - @param dataset_dir: Path to the dataset directory. It can be either a local path - or a URL. The data can be in a zip file. If not provided, C{dataset_name} of - an existing dataset must be provided. + @param dataset_dir: Path to the dataset directory. It can be + either a local path or a URL. The data can be in a zip file. + If not provided, C{dataset_name} of an existing dataset must + be provided. @type dataset_type: str | None - @param dataset_type: Type of the dataset. Only relevant when C{dataset_dir} is - provided. If not provided, the type will be inferred from the directory - structure. + @param dataset_type: Type of the dataset. Only relevant when + C{dataset_dir} is provided. If not provided, the type will + be inferred from the directory structure. @type team_id: str | None @param team_id: Optional unique team identifier for the cloud. @type bucket_type: Literal["internal", "external"] - @param bucket_type: Type of the bucket. Only relevant for remote datasets. - Defaults to 'internal'. + @param bucket_type: Type of the bucket. Only relevant for remote + datasets. Defaults to 'internal'. @type bucket_storage: Literal["local", "s3", "gcs", "azure"] - @param bucket_storage: Type of the bucket storage. Defaults to 'local'. + @param bucket_storage: Type of the bucket storage. Defaults to + 'local'. @type delete_existing: bool - @param delete_existing: Only relevant when C{dataset_dir} is provided. By - default, the dataset is parsed again every time the loader is created - because the underlying data might have changed. If C{delete_existing} is set - to C{False} and a dataset of the same name already exists, the existing + @param delete_existing: Only relevant when C{dataset_dir} is + provided. By default, the dataset is parsed again every time + the loader is created because the underlying data might have + changed. If C{delete_existing} is set to C{False} and a + dataset of the same name already exists, the existing dataset will be used instead of re-parsing the data. - @type view: str | list[str] - @param view: A single split or a list of splits that will be used to create a - view of the dataset. Each split is a string that represents a subset of the - dataset. The available splits depend on the dataset, but usually include - 'train', 'val', and 'test'. Defaults to 'train'. - @type augmentation_engine: Union[Literal["albumentations"], str] - @param augmentation_engine: The augmentation engine to use. - Defaults to C{"albumentations"}. - @type augmentation_config: Optional[Union[List[Dict[str, Any]], - PathType]] - @param augmentation_config: The configuration for the - augmentations. This can be either a list of C{Dict[str, Any]} or - a path to a configuration file. - The config member is a dictionary with two keys: C{name} and - C{params}. C{name} is the name of the augmentation to - instantiate and C{params} is an optional dictionary - of parameters to pass to the augmentation. - - Example:: - - [ - {"name": "HorizontalFlip", "params": {"p": 0.5}}, - {"name": "RandomBrightnessContrast", "params": {"p": 0.1}}, - {"name": "Defocus"} - ] - - @type height: Optional[int] - @param height: The height of the output images. Defaults to - C{None}. - @type width: Optional[int] - @param width: The width of the output images. Defaults to - C{None}. - @type keep_aspect_ratio: bool - @param keep_aspect_ratio: Whether to keep the aspect ratio of the - images. Defaults to C{True}. - @type out_image_format: Literal["RGB", "BGR"] - @param out_image_format: The format of the output images. Defaults - to C{"RGB"}. """ - super().__init__( - view=view if isinstance(view, list) else [view], **kwargs - ) + super().__init__(**kwargs) if dataset_dir is not None: self.dataset = self._parse_dataset( dataset_dir, dataset_name, dataset_type, delete_existing @@ -130,13 +86,15 @@ def __init__( ) self.loader = LuxonisLoader( dataset=self.dataset, - view=view, - augmentation_engine=augmentation_engine, - augmentation_config=augmentation_config, - height=height, - width=width, - keep_aspect_ratio=keep_aspect_ratio, - out_image_format=out_image_format, + view=self.view, + augmentation_engine=self.augmentation_engine, + augmentation_config=[ + aug.model_dump() for aug in self.augmentation_config + ], + height=self.height, + width=self.width, + keep_aspect_ratio=self.keep_aspect_ratio, + out_image_format=self.color_space, ) def __len__(self) -> int: From d6344efa3cad6cb5cd19a2c8c18b106692ceb52d Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Thu, 16 Jan 2025 19:18:36 -0500 Subject: [PATCH 27/31] added missing types --- luxonis_train/assigners/tal_assigner.py | 8 ++++++-- .../losses/adaptive_detection_loss.py | 6 +++--- .../losses/efficient_keypoint_bbox_loss.py | 4 ++-- .../losses/reconstruction_segmentation_loss.py | 6 +++--- .../metrics/mean_average_precision.py | 2 +- luxonis_train/callbacks/training_manager.py | 2 +- luxonis_train/config/config.py | 6 ++++-- luxonis_train/core/core.py | 4 ++-- luxonis_train/core/utils/export_utils.py | 3 ++- luxonis_train/core/utils/infer_utils.py | 7 ++++--- luxonis_train/models/luxonis_lightning.py | 17 ++++++++--------- luxonis_train/nodes/backbones/ddrnet/ddrnet.py | 2 +- .../backbones/efficientrep/efficientrep.py | 2 +- .../nodes/backbones/mobileone/blocks.py | 2 +- .../nodes/backbones/mobileone/mobileone.py | 6 ++++-- .../nodes/backbones/recsubnet/blocks.py | 2 +- luxonis_train/nodes/base_node.py | 2 +- .../nodes/heads/efficient_bbox_head.py | 4 ++-- .../nodes/necks/reppan_neck/reppan_neck.py | 2 +- luxonis_train/strategies/base_strategy.py | 6 ++---- luxonis_train/strategies/triple_lr_sgd.py | 6 +++--- 21 files changed, 53 insertions(+), 46 deletions(-) diff --git a/luxonis_train/assigners/tal_assigner.py b/luxonis_train/assigners/tal_assigner.py index c9435afa..b289fbd6 100644 --- a/luxonis_train/assigners/tal_assigner.py +++ b/luxonis_train/assigners/tal_assigner.py @@ -143,7 +143,7 @@ def _get_alignment_metric( pred_bboxes: Tensor, gt_labels: Tensor, gt_bboxes: Tensor, - ): + ) -> tuple[Tensor, Tensor]: """Calculates anchor alignment metric and IoU between GTs and predicted bboxes. @@ -155,7 +155,11 @@ def _get_alignment_metric( @param gt_labels: Initial GT labels [bs, n_max_boxes, 1] @type gt_bboxes: Tensor @param gt_bboxes: Initial GT bboxes [bs, n_max_boxes, 4] + @rtype: tuple[Tensor, Tensor] + @return: Anchor alignment metric and IoU between GTs and + predicted bboxes. """ + pred_scores = pred_scores.permute(0, 2, 1) gt_labels = gt_labels.to(torch.long) ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long) @@ -175,7 +179,7 @@ def _select_topk_candidates( metrics: Tensor, largest: bool = True, topk_mask: Tensor | None = None, - ): + ) -> Tensor: """Selects k anchors based on provided metrics tensor. @type metrics: Tensor diff --git a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py index 952d11c3..6a7f57f2 100644 --- a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py +++ b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py @@ -132,7 +132,7 @@ def forward( assigned_labels: Tensor, assigned_scores: Tensor, mask_positive: Tensor, - ): + ) -> tuple[Tensor, dict[str, Tensor]]: one_hot_label = F.one_hot(assigned_labels.long(), self.n_classes + 1)[ ..., :-1 ] @@ -161,7 +161,7 @@ def forward( return loss, sub_losses - def _init_parameters(self, features: list[Tensor]): + def _init_parameters(self, features: list[Tensor]) -> None: if not hasattr(self, "gt_bboxes_scale"): self.gt_bboxes_scale = torch.tensor( [ @@ -235,7 +235,7 @@ def _preprocess_bbox_target( out_target[..., 1:] = box_convert(scaled_target, "xywh", "xyxy") return out_target - def _log_assigner_change(self): + def _log_assigner_change(self) -> None: if self._logged_assigner_change: return diff --git a/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py b/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py index d9a191e9..98630742 100644 --- a/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py +++ b/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py @@ -187,7 +187,7 @@ def forward( gt_kpts: Tensor, pred_kpts: Tensor, area: Tensor, - ): + ) -> tuple[Tensor, dict[str, Tensor]]: device = pred_bboxes.device sigmas = self.sigmas.to(device) d = (gt_kpts[..., 0] - pred_kpts[..., 0]).pow(2) + ( @@ -272,7 +272,7 @@ def dist2kpts_noscale(self, anchor_points: Tensor, kpts: Tensor) -> Tensor: adj_kpts[..., 1] += y_adj return adj_kpts - def _init_parameters(self, features: list[Tensor]): + def _init_parameters(self, features: list[Tensor]) -> None: device = features[0].device super()._init_parameters(features) self.gt_kpts_scale = torch.tensor( diff --git a/luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py b/luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py index a5b50f2b..6ba109d0 100644 --- a/luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py +++ b/luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py @@ -62,7 +62,7 @@ def prepare( def forward( self, orig: Tensor, recon: Tensor, seg_out: Tensor, an_mask: Tensor - ): + ) -> tuple[Tensor, dict[str, Tensor]]: l2 = self.loss_l2(recon, orig) ssim = self.loss_ssim(recon, orig) focal = self.loss_focal(seg_out, an_mask) @@ -142,8 +142,8 @@ def ssim( img2: Tensor, window_size: int = 11, window: Tensor | None = None, - size_average=True, - val_range=None, + size_average: bool = True, + val_range: float | None = None, ) -> Tensor: if val_range is None: if torch.max(img1) > 128: diff --git a/luxonis_train/attached_modules/metrics/mean_average_precision.py b/luxonis_train/attached_modules/metrics/mean_average_precision.py index d4731988..c082ee39 100644 --- a/luxonis_train/attached_modules/metrics/mean_average_precision.py +++ b/luxonis_train/attached_modules/metrics/mean_average_precision.py @@ -31,7 +31,7 @@ def update( self, outputs: list[dict[str, Tensor]], labels: list[dict[str, Tensor]], - ): + ) -> None: self.metric.update(outputs, labels) def prepare( diff --git a/luxonis_train/callbacks/training_manager.py b/luxonis_train/callbacks/training_manager.py index d9cc7002..390f49b6 100644 --- a/luxonis_train/callbacks/training_manager.py +++ b/luxonis_train/callbacks/training_manager.py @@ -15,7 +15,7 @@ def __init__(self, strategy: BaseTrainingStrategy | None = None): def on_after_backward( self, trainer: pl.Trainer, pl_module: pl.LightningModule - ): + ) -> None: """PyTorch Lightning hook that is called after the backward pass. diff --git a/luxonis_train/config/config.py b/luxonis_train/config/config.py index a73737cb..d44ac480 100644 --- a/luxonis_train/config/config.py +++ b/luxonis_train/config/config.py @@ -498,7 +498,9 @@ class ExportConfig(ArchiveConfig): @model_validator(mode="after") def check_values(self) -> Self: - def pad_values(values: float | list[float] | None): + def pad_values( + values: float | list[float] | None, + ) -> list[float] | None: if values is None: return None if isinstance(values, float): @@ -644,7 +646,7 @@ def is_acyclic(graph: dict[str, list[str]]) -> bool: """ graph = graph.copy() - def dfs(node: str, visited: set[str], recursion_stack: set[str]): + def dfs(node: str, visited: set[str], recursion_stack: set[str]) -> bool: visited.add(node) recursion_stack.add(node) diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index 2ae78162..5f469590 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -208,7 +208,7 @@ def __init__( self._exported_models: dict[str, Path] = {} - def _train(self, resume: str | None, *args, **kwargs): + def _train(self, resume: str | None, *args, **kwargs) -> None: status = "success" try: self.pl_trainer.fit(*args, ckpt_path=resume, **kwargs) @@ -245,7 +245,7 @@ def train( LuxonisFileSystem.download(resume_weights, self.run_save_dir) ) - def graceful_exit(signum: int, _): # pragma: no cover + def graceful_exit(signum: int, _: Any) -> None: # pragma: no cover logger.info( f"{signal.Signals(signum).name} received, stopping training..." ) diff --git a/luxonis_train/core/utils/export_utils.py b/luxonis_train/core/utils/export_utils.py index 25e1a3ff..7190c889 100644 --- a/luxonis_train/core/utils/export_utils.py +++ b/luxonis_train/core/utils/export_utils.py @@ -1,4 +1,5 @@ import logging +from collections.abc import Generator from contextlib import contextmanager from pathlib import Path @@ -12,7 +13,7 @@ def replace_weights( module: "luxonis_train.models.LuxonisLightningModule", weights: str | Path | None = None, -): +) -> Generator[None, None, None]: old_weights = None if weights is not None: old_weights = module.state_dict() diff --git a/luxonis_train/core/utils/infer_utils.py b/luxonis_train/core/utils/infer_utils.py index 22cd7521..c4f9085f 100644 --- a/luxonis_train/core/utils/infer_utils.py +++ b/luxonis_train/core/utils/infer_utils.py @@ -7,12 +7,13 @@ import numpy as np import torch import torch.utils.data as torch_data -from luxonis_ml.data import LuxonisDataset +from luxonis_ml.data import DatasetIterator, LuxonisDataset from torch import Tensor import luxonis_train from luxonis_train.attached_modules.visualizers import get_denormalized_images from luxonis_train.loaders import LuxonisLoaderTorch +from luxonis_train.models.luxonis_output import LuxonisOutput IMAGE_FORMATS = { ".bmp", @@ -48,7 +49,7 @@ def process_visualizations( def prepare_and_infer_image( model: "luxonis_train.core.LuxonisModel", img: Tensor -): +) -> LuxonisOutput: """Prepares the image for inference and runs the model.""" img = model.loaders["val"].augment_test_image(img) # type: ignore @@ -196,7 +197,7 @@ def infer_from_directory( """ img_paths = list(img_paths) - def generator(): + def generator() -> DatasetIterator: for img_path in img_paths: yield { "file": img_path, diff --git a/luxonis_train/models/luxonis_lightning.py b/luxonis_train/models/luxonis_lightning.py index c478c126..9643a6f2 100644 --- a/luxonis_train/models/luxonis_lightning.py +++ b/luxonis_train/models/luxonis_lightning.py @@ -9,6 +9,7 @@ from lightning.pytorch.callbacks import ModelCheckpoint, RichModelSummary from lightning.pytorch.utilities import rank_zero_only # type: ignore from luxonis_ml.data import LuxonisDataset +from luxonis_ml.typing import ConfigItem from torch import Size, Tensor, nn import luxonis_train @@ -601,7 +602,7 @@ def export_onnx(self, save_path: str, **kwargs) -> list[str]: old_forward = self.forward - def export_forward(inputs) -> tuple[Tensor, ...]: + def export_forward(inputs: dict[str, Tensor]) -> tuple[Tensor, ...]: outputs = old_forward( inputs, None, @@ -904,14 +905,12 @@ def configure_optimizers( } optimizer = OPTIMIZERS.get(cfg_optimizer.name)(**optim_params) - def get_scheduler(scheduler_cfg, optimizer): - scheduler_class = SCHEDULERS.get( - scheduler_cfg["name"] - ) # For dictionary access - scheduler_params = scheduler_cfg["params"] | { - "optimizer": optimizer - } # Dictionary access for params - return scheduler_class(**scheduler_params) + def get_scheduler( + scheduler_cfg: ConfigItem, optimizer: torch.optim.Optimizer + ) -> torch.optim.lr_scheduler._LRScheduler: + scheduler_class = SCHEDULERS.get(scheduler_cfg.name) + scheduler_params = scheduler_cfg.params | {"optimizer": optimizer} + return scheduler_class(**scheduler_params) # type: ignore if cfg_scheduler.name == "SequentialLR": schedulers_list = [ diff --git a/luxonis_train/nodes/backbones/ddrnet/ddrnet.py b/luxonis_train/nodes/backbones/ddrnet/ddrnet.py index b029dfff..2698c26d 100644 --- a/luxonis_train/nodes/backbones/ddrnet/ddrnet.py +++ b/luxonis_train/nodes/backbones/ddrnet/ddrnet.py @@ -297,7 +297,7 @@ def forward(self, inputs: Tensor) -> list[Tensor]: else: return [x] - def init_params(self): + def init_params(self) -> None: for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_( diff --git a/luxonis_train/nodes/backbones/efficientrep/efficientrep.py b/luxonis_train/nodes/backbones/efficientrep/efficientrep.py index 76b0cf18..529cd816 100644 --- a/luxonis_train/nodes/backbones/efficientrep/efficientrep.py +++ b/luxonis_train/nodes/backbones/efficientrep/efficientrep.py @@ -140,7 +140,7 @@ def __init__( f"No checkpoint available for {self.name}, skipping." ) - def initialize_weights(self): + def initialize_weights(self) -> None: for m in self.modules(): if isinstance(m, nn.Conv2d): pass diff --git a/luxonis_train/nodes/backbones/mobileone/blocks.py b/luxonis_train/nodes/backbones/mobileone/blocks.py index a9006c7e..4b926038 100644 --- a/luxonis_train/nodes/backbones/mobileone/blocks.py +++ b/luxonis_train/nodes/backbones/mobileone/blocks.py @@ -133,7 +133,7 @@ def forward(self, inputs: Tensor) -> Tensor: return self.activation(self.se(out)) - def reparameterize(self): + def reparameterize(self) -> None: """Following works like U{RepVGG: Making VGG-style ConvNets Great Again } architecture used at training time to obtain a plain CNN-like structure diff --git a/luxonis_train/nodes/backbones/mobileone/mobileone.py b/luxonis_train/nodes/backbones/mobileone/mobileone.py index 1ed476cf..c2fe93c0 100644 --- a/luxonis_train/nodes/backbones/mobileone/mobileone.py +++ b/luxonis_train/nodes/backbones/mobileone/mobileone.py @@ -142,7 +142,9 @@ def set_export_mode(self, mode: bool = True) -> None: if hasattr(module, "reparameterize"): module.reparameterize() - def _make_stage(self, planes: int, n_blocks: int, n_se_blocks: int): + def _make_stage( + self, planes: int, n_blocks: int, n_se_blocks: int + ) -> nn.Sequential: """Build a stage of MobileOne model. @type planes: int @@ -161,7 +163,7 @@ def _make_stage(self, planes: int, n_blocks: int, n_se_blocks: int): use_se = False if n_se_blocks > n_blocks: raise ValueError( - "Number of SE blocks cannot " "exceed number of layers." + "Number of SE blocks cannot exceed number of layers." ) if ix >= (n_blocks - n_se_blocks): use_se = True diff --git a/luxonis_train/nodes/backbones/recsubnet/blocks.py b/luxonis_train/nodes/backbones/recsubnet/blocks.py index 0090a3ca..1d557aff 100644 --- a/luxonis_train/nodes/backbones/recsubnet/blocks.py +++ b/luxonis_train/nodes/backbones/recsubnet/blocks.py @@ -118,7 +118,7 @@ def __init__(self, input_channels: int, width: int) -> None: self.encoder_block2 = ConvBlock(width, int(width * 1.1)) self.pool2 = nn.MaxPool2d(2) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: enc1 = self.encoder_block1(x) enc1_pool = self.pool1(enc1) enc2 = self.encoder_block2(enc1_pool) diff --git a/luxonis_train/nodes/base_node.py b/luxonis_train/nodes/base_node.py index 7dcdbcf3..0a9a208a 100644 --- a/luxonis_train/nodes/base_node.py +++ b/luxonis_train/nodes/base_node.py @@ -378,7 +378,7 @@ def in_width(self) -> int | list[int]: """ return self._get_nth_size(-1) - def load_checkpoint(self, path: str, strict: bool = True): + def load_checkpoint(self, path: str, strict: bool = True) -> None: """Loads checkpoint for the module. If path is url then it downloads it locally and stores it in cache. diff --git a/luxonis_train/nodes/heads/efficient_bbox_head.py b/luxonis_train/nodes/heads/efficient_bbox_head.py index 8d1a55f4..95ebe1be 100644 --- a/luxonis_train/nodes/heads/efficient_bbox_head.py +++ b/luxonis_train/nodes/heads/efficient_bbox_head.py @@ -111,7 +111,7 @@ def __init__( f"No checkpoint available for {self.name}, skipping." ) - def initialize_weights(self): + def initialize_weights(self) -> None: for m in self.modules(): if isinstance(m, nn.Conv2d): pass @@ -201,7 +201,7 @@ def wrap( "distributions": [reg_tensor], } - def _fit_stride_to_n_heads(self): + def _fit_stride_to_n_heads(self) -> Tensor: """Returns correct stride for number of heads and attach index.""" stride = torch.tensor( diff --git a/luxonis_train/nodes/necks/reppan_neck/reppan_neck.py b/luxonis_train/nodes/necks/reppan_neck/reppan_neck.py index 383160e3..f7e0552e 100644 --- a/luxonis_train/nodes/necks/reppan_neck/reppan_neck.py +++ b/luxonis_train/nodes/necks/reppan_neck/reppan_neck.py @@ -180,7 +180,7 @@ def __init__( f"No checkpoint available for {self.name}, skipping." ) - def initialize_weights(self): + def initialize_weights(self) -> None: for m in self.modules(): if isinstance(m, nn.Conv2d): pass diff --git a/luxonis_train/strategies/base_strategy.py b/luxonis_train/strategies/base_strategy.py index 5b812ebf..09bc5392 100644 --- a/luxonis_train/strategies/base_strategy.py +++ b/luxonis_train/strategies/base_strategy.py @@ -20,9 +20,7 @@ def __init__(self, pl_module: pl.LightningModule): @abstractmethod def configure_optimizers( self, - ) -> tuple[list[Optimizer], list[LRScheduler]]: - pass + ) -> tuple[list[Optimizer], list[LRScheduler]]: ... @abstractmethod - def update_parameters(self, *args, **kwargs): - pass + def update_parameters(self, *args, **kwargs) -> None: ... diff --git a/luxonis_train/strategies/triple_lr_sgd.py b/luxonis_train/strategies/triple_lr_sgd.py index b76b5ee8..570c7bc7 100644 --- a/luxonis_train/strategies/triple_lr_sgd.py +++ b/luxonis_train/strategies/triple_lr_sgd.py @@ -51,7 +51,7 @@ def __init__( + 1 ) - def create_scheduler(self): + def create_scheduler(self) -> LambdaLR: scheduler = LambdaLR(self.optimizer, lr_lambda=self.lf) return scheduler @@ -103,7 +103,7 @@ def __init__(self, model: torch.nn.Module, params: dict) -> None: if params: self.params.update(params) - def create_optimizer(self): + def create_optimizer(self) -> torch.optim.Optimizer: batch_norm_weights, regular_weights, biases = [], [], [] for module in self.model.modules(): @@ -166,6 +166,6 @@ def __init__(self, pl_module: pl.LightningModule, params: dict): def configure_optimizers(self) -> tuple[list[Optimizer], list[LambdaLR]]: return [self.optimizer], [self.scheduler.create_scheduler()] - def update_parameters(self, *args, **kwargs): + def update_parameters(self, *args, **kwargs) -> None: current_epoch = self.model.current_epoch self.scheduler.update_learning_rate(current_epoch) From 44adfcb8aa6559f57aa1ce39879e752dc345a86e Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Thu, 16 Jan 2025 21:41:14 -0500 Subject: [PATCH 28/31] fixed anomaly detection --- .../reconstruction_segmentation_loss.py | 2 +- luxonis_train/core/utils/export_utils.py | 2 +- luxonis_train/loaders/base_loader.py | 50 +++++++++++-- luxonis_train/loaders/luxonis_loader_torch.py | 22 +++--- .../loaders/luxonis_perlin_loader_torch.py | 64 +++++++++++------ luxonis_train/loaders/perlin.py | 71 ++++++------------- luxonis_train/utils/general.py | 4 +- 7 files changed, 128 insertions(+), 87 deletions(-) diff --git a/luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py b/luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py index 6ba109d0..1e3ff449 100644 --- a/luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py +++ b/luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py @@ -55,7 +55,7 @@ def prepare( ) -> tuple[Tensor, Tensor, Tensor, Tensor]: recon = self.get_input_tensors(inputs, "reconstructed")[0] seg_out = self.get_input_tensors(inputs)[0] - an_mask = self.get_label(labels) + an_mask = labels[f"{self.node.task_name}/segmentation"] orig = labels[f"{self.node.task_name}/original/segmentation"] return orig, recon, seg_out, an_mask diff --git a/luxonis_train/core/utils/export_utils.py b/luxonis_train/core/utils/export_utils.py index 7190c889..fb1af27c 100644 --- a/luxonis_train/core/utils/export_utils.py +++ b/luxonis_train/core/utils/export_utils.py @@ -13,7 +13,7 @@ def replace_weights( module: "luxonis_train.models.LuxonisLightningModule", weights: str | Path | None = None, -) -> Generator[None, None, None]: +) -> Generator: old_weights = None if weights is not None: old_weights = module.state_dict() diff --git a/luxonis_train/loaders/base_loader.py b/luxonis_train/loaders/base_loader.py index 19fc4dc6..bade09b4 100644 --- a/luxonis_train/loaders/base_loader.py +++ b/luxonis_train/loaders/base_loader.py @@ -1,12 +1,17 @@ from abc import ABC, abstractmethod from typing import Any, Literal +import cv2 +import numpy as np +import numpy.typing as npt +import torch from luxonis_ml.typing import ConfigItem from luxonis_ml.utils.registry import AutoRegisterMeta from torch import Size, Tensor from torch.utils.data import Dataset from luxonis_train.utils.registry import LOADERS +from luxonis_train.utils.types import Labels from .utils import LuxonisLoaderTorchOutput @@ -67,7 +72,9 @@ class ConfigItem: ) @type image_source: str - @param image_source: Name of the input image group. This can be used for datasets with multiple image sources, e.g. left and right cameras or RGB and depth images. Irrelevant for datasets with only one image source. + @param image_source: Name of the image source. Only relevant for + datasets with multiple image sources, e.g. C{"left"} and C{"right"}. This parameter defines which of these sources is used for + visualizations. @type keep_aspect_ratio: bool @param keep_aspect_ratio: Whether to keep the aspect ratio of the output image after resizing. @@ -142,10 +149,10 @@ def keep_aspect_ratio(self) -> bool: return self._getter_check_none("keep_aspect_ratio") @property - def color_space(self) -> str: + def color_space(self) -> Literal["RGB", "BGR"]: """Color space of the output image. - @type: str + @type: Literal["RGB", "BGR"] """ return self._getter_check_none("color_space") @@ -200,13 +207,19 @@ def augment_test_image(self, img: Tensor) -> Tensor: "`augment_test_image` method to expose this functionality." ) + def __getitem__(self, idx: int) -> LuxonisLoaderTorchOutput: + img, labels = self.get(idx) + if isinstance(img, Tensor): + img = {self.image_source: img} + return img, labels + @abstractmethod def __len__(self) -> int: """Returns length of the dataset.""" ... @abstractmethod - def __getitem__(self, idx: int) -> LuxonisLoaderTorchOutput: + def get(self, idx: int) -> tuple[Tensor | dict[str, Tensor], Labels]: """Loads sample from dataset. @type idx: int @@ -235,6 +248,35 @@ def get_n_keypoints(self) -> dict[str, int] | None: """ return None + def dict_numpy_to_torch( + self, numpy_dictionary: dict[str, np.ndarray] + ) -> dict[str, Tensor]: + """Converts a dictionary of numpy arrays to a dictionary of + torch tensors. + + @type numpy_dictionary: dict[str, np.ndarray] + @param numpy_dictionary: Dictionary of numpy arrays. + @rtype: dict[str, torch.Tensor] + @return: Dictionary of torch tensors. + """ + return { + task: torch.tensor(array) + for task, array in numpy_dictionary.items() + } + + def read_image(self, path: str) -> npt.NDArray[np.float32]: + """Reads an image from a file. + + @type path: str + @param path: Path to the image file. + @rtype: np.ndarray[np.float32] + @return: Image as a numpy array. + """ + img = cv2.imread(path, cv2.IMREAD_COLOR) + if self.color_space == "RGB": + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img.astype(np.float32) / 255.0 + def _getter_check_none( self, attribute: Literal[ diff --git a/luxonis_train/loaders/luxonis_loader_torch.py b/luxonis_train/loaders/luxonis_loader_torch.py index be31e28e..c21d5230 100644 --- a/luxonis_train/loaders/luxonis_loader_torch.py +++ b/luxonis_train/loaders/luxonis_loader_torch.py @@ -12,15 +12,17 @@ from luxonis_ml.data.parsers import LuxonisParser from luxonis_ml.enums import DatasetType from torch import Size, Tensor -from typeguard import typechecked +from typing_extensions import override -from .base_loader import BaseLoaderTorch, LuxonisLoaderTorchOutput +from luxonis_train.utils.types import Labels + +from .base_loader import BaseLoaderTorch logger = logging.getLogger(__name__) class LuxonisLoaderTorch(BaseLoaderTorch): - @typechecked + @override def __init__( self, dataset_name: str | None = None, @@ -97,28 +99,30 @@ def __init__( out_image_format=self.color_space, ) + @override def __len__(self) -> int: return len(self.loader) @property + @override def input_shapes(self) -> dict[str, Size]: img = self[0][0][self.image_source] return {self.image_source: img.shape} - def __getitem__(self, idx: int) -> LuxonisLoaderTorchOutput: + @override + def get(self, idx: int) -> tuple[Tensor, Labels]: img, labels = self.loader[idx] img = np.transpose(img, (2, 0, 1)) # HWC to CHW - tensor_img = Tensor(img) - tensor_labels: dict[str, Tensor] = {} - for task, array in labels.items(): - tensor_labels[task] = Tensor(array) + tensor_img = torch.tensor(img) - return {self.image_source: tensor_img}, tensor_labels + return tensor_img, self.dict_numpy_to_torch(labels) + @override def get_classes(self) -> dict[str, list[str]]: return self.dataset.get_classes() + @override def get_n_keypoints(self) -> dict[str, int]: skeletons = self.dataset.get_skeletons() return {task: len(skeletons[task][0]) for task in skeletons} diff --git a/luxonis_train/loaders/luxonis_perlin_loader_torch.py b/luxonis_train/loaders/luxonis_perlin_loader_torch.py index 05e2c449..68b5512b 100644 --- a/luxonis_train/loaders/luxonis_perlin_loader_torch.py +++ b/luxonis_train/loaders/luxonis_perlin_loader_torch.py @@ -1,24 +1,28 @@ import random -from typing import cast +from collections.abc import Generator +from contextlib import contextmanager import numpy as np import torch import torch.nn.functional as F -from luxonis_ml.data import AlbumentationsEngine from luxonis_ml.utils import LuxonisFileSystem from torch import Tensor +from typing_extensions import override + +from luxonis_train.utils.types import Labels -from .base_loader import LuxonisLoaderTorchOutput from .luxonis_loader_torch import LuxonisLoaderTorch from .perlin import apply_anomaly_to_img class LuxonisLoaderPerlinNoise(LuxonisLoaderTorch): + @override def __init__( self, *args, anomaly_source_path: str, noise_prob: float = 0.5, + beta: float | None = None, **kwargs, ): """Custom loader for Luxonis datasets that adds Perlin noise @@ -58,37 +62,51 @@ def __init__( raise ValueError( "This loader only supports datasets with a single task." ) + self.beta = beta self.task_name = next(iter(self.loader.dataset.get_tasks())) + self.augmentations = self.loader.augmentations - augmentations = cast(AlbumentationsEngine, self.loader.augmentations) - if augmentations is None or augmentations.pixel_transform is None: - self.pixel_augs = None - else: - self.pixel_augs = augmentations.pixel_transform + @override + def get(self, idx: int) -> tuple[Tensor, Labels]: + with _freeze_seed(): + img, labels = self.loader[idx] - def __getitem__(self, idx: int) -> LuxonisLoaderTorchOutput: - img, labels = self.loader[idx] + an_mask = torch.tensor(labels.pop(f"{self.task_name}/segmentation"))[ + 0, ... + ] img = np.transpose(img, (2, 0, 1)) - tensor_img = Tensor(img) + tensor_img = torch.tensor(img) + tensor_labels = self.dict_numpy_to_torch(labels) if self.view[0] == "train" and random.random() < self.noise_prob: + anomaly_path = random.choice(self.anomaly_files) + anomaly_img = self.read_image(str(anomaly_path)) + + if self.augmentations is not None: + anomaly_img = self.augmentations.apply([(anomaly_img, {})])[0] + + anomaly_img = torch.tensor(anomaly_img).permute(2, 0, 1) aug_tensor_img, an_mask = apply_anomaly_to_img( - tensor_img, - anomaly_source_paths=self.anomaly_files, - pixel_augs=self.pixel_augs, + tensor_img, anomaly_img, self.beta ) else: aug_tensor_img = tensor_img - h, w = aug_tensor_img.shape[-2:] - an_mask = torch.zeros((h, w)) - tensor_labels = {f"{self.task_name}/original/segmentation": tensor_img} - for task, array in labels.items(): - tensor_labels[task] = Tensor(array) + an_mask = F.one_hot(an_mask.long(), 2).permute(2, 0, 1).float() + + tensor_labels = { + f"{self.task_name}/original/segmentation": tensor_img, + f"{self.task_name}/segmentation": an_mask, + } + + return aug_tensor_img, tensor_labels - tensor_labels[f"{self.task_name}/segmentation"] = ( - F.one_hot(an_mask.long(), 2).permute(2, 0, 1).float() - ) - return {self.image_source: aug_tensor_img}, tensor_labels +@contextmanager +def _freeze_seed() -> Generator: + python_seed = random.getstate() + numpy_seed = np.random.get_state() + yield + random.setstate(python_seed) + np.random.set_state(numpy_seed) diff --git a/luxonis_train/loaders/perlin.py b/luxonis_train/loaders/perlin.py index cd201d35..6a973d85 100644 --- a/luxonis_train/loaders/perlin.py +++ b/luxonis_train/loaders/perlin.py @@ -1,13 +1,10 @@ -import random -from pathlib import Path -from typing import Callable, List, Tuple +from typing import Callable, Tuple -import cv2 -import numpy as np import torch +from torch import Tensor -def compute_gradients(res: tuple[int, int]) -> torch.Tensor: +def compute_gradients(res: tuple[int, int]) -> Tensor: angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1) gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1) return gradients @@ -15,21 +12,21 @@ def compute_gradients(res: tuple[int, int]) -> torch.Tensor: @torch.jit.script def lerp_torch( # pragma: no cover - x: torch.Tensor, y: torch.Tensor, w: torch.Tensor -) -> torch.Tensor: + x: Tensor, y: Tensor, w: Tensor +) -> Tensor: return (y - x) * w + x -def fade_function(t: torch.Tensor) -> torch.Tensor: +def fade_function(t: Tensor) -> Tensor: return 6 * t**5 - 15 * t**4 + 10 * t**3 def tile_grads( slice1: Tuple[int, int | None], slice2: Tuple[int, int | None], - gradients: torch.Tensor, + gradients: Tensor, d: Tuple[int, int], -) -> torch.Tensor: +) -> Tensor: return ( gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]] .repeat_interleave(d[0], 0) @@ -38,11 +35,11 @@ def tile_grads( def dot( - grad: torch.Tensor, + grad: Tensor, shift: Tuple[int, int], - grid: torch.Tensor, + grid: Tensor, shape: Tuple[int, int], -) -> torch.Tensor: +) -> Tensor: return ( torch.stack( ( @@ -58,8 +55,8 @@ def dot( def rand_perlin_2d( shape: Tuple[int, int], res: Tuple[int, int], - fade: Callable[[torch.Tensor], torch.Tensor] = fade_function, -) -> torch.Tensor: + fade: Callable[[Tensor], Tensor] = fade_function, +) -> Tensor: delta = (res[0] / shape[0], res[1] / shape[1]) d = (shape[0] // res[0], shape[1] // res[1]) grid_x, grid_y = torch.meshgrid( @@ -92,7 +89,7 @@ def rand_perlin_2d( @torch.jit.script -def rotate_noise(noise: torch.Tensor) -> torch.Tensor: # pragma: no cover +def rotate_noise(noise: Tensor) -> Tensor: # pragma: no cover angle = torch.rand(1) * 2 * torch.pi h, w = noise.shape center_y, center_x = h // 2, w // 2 @@ -117,7 +114,7 @@ def generate_perlin_noise( min_perlin_scale: int = 0, perlin_scale: int = 6, threshold: float = 0.5, -) -> torch.Tensor: +) -> Tensor: perlin_scalex = 2 ** int( torch.randint(min_perlin_scale, perlin_scale, (1,)).item() ) @@ -136,21 +133,14 @@ def generate_perlin_noise( return perlin_mask -def load_image_as_numpy(img_path: Path) -> np.ndarray: - image = cv2.imread(str(img_path), cv2.IMREAD_COLOR) - image = image.astype(np.float32) / 255.0 - return image - - def apply_anomaly_to_img( - img: torch.Tensor, - anomaly_source_paths: List[Path], + img: Tensor, + anomaly_img: Tensor, beta: float | None = None, - pixel_augs: Callable | None = None, # type: ignore -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[Tensor, Tensor]: """Applies Perlin noise-based anomalies to a single image (C, H, W). - @type img: torch.Tensor + @type img: Tensor @param img: The input image tensor of shape (C, H, W). @type anomaly_source_paths: List[str] @param anomaly_source_paths: List of file paths to the anomaly images. @@ -159,27 +149,12 @@ def apply_anomaly_to_img( @type beta: float | None @param beta: A blending factor for anomaly and noise. If None, a random value in the range [0, 0.8] is used. Defaults to C{None}. - @rtype: Tuple[torch.Tensor, torch.Tensor] + @rtype: Tuple[Tensor, Tensor] @return: A tuple containing: - - augmented_img (torch.Tensor): The augmented image with applied anomaly and Perlin noise. - - perlin_mask (torch.Tensor): The Perlin noise mask applied to the image. + - augmented_img (Tensor): The augmented image with applied anomaly and Perlin noise. + - perlin_mask (Tensor): The Perlin noise mask applied to the image. """ - sampled_anomaly_image_path = random.choice(anomaly_source_paths) - - anomaly_image = load_image_as_numpy(sampled_anomaly_image_path) - - anomaly_image = cv2.resize( - anomaly_image, - (img.shape[2], img.shape[1]), - interpolation=cv2.INTER_LINEAR, - ) - - if pixel_augs is not None: - anomaly_image = pixel_augs(image=anomaly_image)["image"] - - anomaly_image = torch.tensor(anomaly_image).permute(2, 0, 1) - perlin_mask = generate_perlin_noise( shape=(img.shape[1], img.shape[2]), ) @@ -189,7 +164,7 @@ def apply_anomaly_to_img( augmented_img = ( (1 - perlin_mask).unsqueeze(0) * img - + (1 - beta) * perlin_mask.unsqueeze(0) * anomaly_image + + (1 - beta) * perlin_mask.unsqueeze(0) * anomaly_img + beta * perlin_mask.unsqueeze(0) * img ) diff --git a/luxonis_train/utils/general.py b/luxonis_train/utils/general.py index 390ddfd2..746df792 100644 --- a/luxonis_train/utils/general.py +++ b/luxonis_train/utils/general.py @@ -187,7 +187,9 @@ def safe_download( if i == retry: logger.warning("Download failed, retry limit reached.") return None - logger.warning(f"Download failed, retrying {i+1}/{retry} ...") + logger.warning( + f"Download failed, retrying {i + 1}/{retry} ..." + ) def clean_url(url: str) -> str: From c76135c5cfacddcc77de9c9182c7ec52a7de368f Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Fri, 17 Jan 2025 00:24:02 -0500 Subject: [PATCH 29/31] converting to float --- luxonis_train/loaders/base_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/luxonis_train/loaders/base_loader.py b/luxonis_train/loaders/base_loader.py index bade09b4..bb3b9c13 100644 --- a/luxonis_train/loaders/base_loader.py +++ b/luxonis_train/loaders/base_loader.py @@ -260,7 +260,7 @@ def dict_numpy_to_torch( @return: Dictionary of torch tensors. """ return { - task: torch.tensor(array) + task: torch.tensor(array).float() for task, array in numpy_dictionary.items() } From 058f449770cab742d1623ad375d3f4d3dff802df Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Fri, 17 Jan 2025 04:24:43 -0500 Subject: [PATCH 30/31] helper function --- luxonis_train/core/core.py | 2 +- luxonis_train/loaders/base_loader.py | 6 ++-- luxonis_train/utils/__init__.py | 2 ++ luxonis_train/utils/general.py | 41 +++++++++++++++++++++++++++- 4 files changed, 45 insertions(+), 6 deletions(-) diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index 5f469590..1ea49bab 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -133,7 +133,7 @@ def __init__( augmentation_config=self.cfg_preprocessing.get_active_augmentations(), color_space=self.cfg_preprocessing.color_space, keep_aspect_ratio=self.cfg_preprocessing.keep_aspect_ratio, - **self.cfg.loader.params, + **self.cfg.loader.params, # type: ignore ) for name, loader in self.loaders.items(): diff --git a/luxonis_train/loaders/base_loader.py b/luxonis_train/loaders/base_loader.py index bb3b9c13..92e54794 100644 --- a/luxonis_train/loaders/base_loader.py +++ b/luxonis_train/loaders/base_loader.py @@ -10,6 +10,7 @@ from torch import Size, Tensor from torch.utils.data import Dataset +from luxonis_train.utils.general import get_attribute_check_none from luxonis_train.utils.registry import LOADERS from luxonis_train.utils.types import Labels @@ -290,7 +291,4 @@ def _getter_check_none( "color_space", ], ) -> Any: - value = getattr(self, f"_{attribute}") - if value is None: - raise ValueError(f"{attribute} is not set") - return value + return get_attribute_check_none(self, attribute) diff --git a/luxonis_train/utils/__init__.py b/luxonis_train/utils/__init__.py index d7c0be9f..2f2b550a 100644 --- a/luxonis_train/utils/__init__.py +++ b/luxonis_train/utils/__init__.py @@ -9,6 +9,7 @@ from .dataset_metadata import DatasetMetadata from .exceptions import IncompatibleException from .general import ( + get_attribute_check_none, get_with_default, infer_upscale_factor, make_divisible, @@ -42,4 +43,5 @@ "get_sigmas", "traverse_graph", "insert_class", + "get_attribute_check_none", ] diff --git a/luxonis_train/utils/general.py b/luxonis_train/utils/general.py index 746df792..94cd1613 100644 --- a/luxonis_train/utils/general.py +++ b/luxonis_train/utils/general.py @@ -3,7 +3,7 @@ import os import urllib.parse from pathlib import Path, PurePosixPath -from typing import TypeVar +from typing import Any, TypeVar import torch from torch import Size, Tensor @@ -205,3 +205,42 @@ def clean_url(url: str) -> str: def url2file(url: str) -> str: """Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt.""" return Path(clean_url(url)).name + + +def get_attribute_check_none(obj: object, attribute: str) -> Any: + """Get private attribute from object and check if it is not None. + + Example: + + >>> class Person: + ... def __init__(self, age: int | None = None): + ... self._age = age + ... + ... @property + ... def age(self): + ... return get_attribute_check_none(self, "age") + + >>> mike = Person(20) + >>> print(mike.age) + 20 + + >>> amanda = Person() + >>> print(amanda.age) + Traceback (most recent call last): + ValueError: attribute 'age' was not set + + @type obj: object + @param obj: Object to get attribute from. + + @type attribute: str + @param attribute: Name of the attribute to get. + + @rtype: Any + @return: Value of the attribute. + + @raise ValueError: If the attribute is None. + """ + value = getattr(obj, f"_{attribute}") + if value is None: + raise ValueError(f"attribute '{attribute}' was not set") + return value From b8f8d7d907b3ee0d85762925ee145fe3d57eb41b Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Fri, 17 Jan 2025 13:21:25 +0100 Subject: [PATCH 31/31] fix predefined models --- .../instance_segmentation_heavy_model.yaml | 36 +++++++++------- .../instance_segmentation_light_model.yaml | 36 +++++++++------- .../instance_segmentation_model.py | 43 ++++++++++++------- 3 files changed, 69 insertions(+), 46 deletions(-) diff --git a/configs/instance_segmentation_heavy_model.yaml b/configs/instance_segmentation_heavy_model.yaml index 42cedd87..58331434 100644 --- a/configs/instance_segmentation_heavy_model.yaml +++ b/configs/instance_segmentation_heavy_model.yaml @@ -6,6 +6,10 @@ model: name: InstanceSegmentationModel params: variant: heavy + loss_params: + bbox_loss_weight: 60 # Should be 7.5 * accumulate_grad_batches for best results + class_loss_weight: 4 # Should be 0.5 * accumulate_grad_batches for best results + dfl_loss_weight: 12 # Should be 1.5 * accumulate_grad_batches for best results loader: params: @@ -19,27 +23,29 @@ trainer: active: true batch_size: 8 - epochs: &epochs 200 - n_workers: 4 + epochs: &epochs 300 + accumulate_grad_batches: 8 # For best results, always accumulate gradients to effectively use 64 batch size + n_workers: 8 validation_interval: 10 n_log_images: 8 callbacks: + - name: EMACallback + params: + decay: 0.9999 + use_dynamic_decay: True + decay_tau: 2000 - name: ExportOnTrainEnd - name: TestOnTrainEnd - optimizer: - name: SGD - params: + training_strategy: + name: "TripleLRSGDStrategy" + params: + warmup_epochs: 3 + warmup_bias_lr: 0.1 + warmup_momentum: 0.8 lr: 0.01 - momentum: 0.937 + lre: 0.0001 + momentum: 0.937 weight_decay: 0.0005 - dampening: 0.0 - nesterov: true - - scheduler: - name: CosineAnnealingLR - params: - T_max: *epochs - eta_min: 0.0001 - last_epoch: -1 + nesterov: True \ No newline at end of file diff --git a/configs/instance_segmentation_light_model.yaml b/configs/instance_segmentation_light_model.yaml index 24d764ed..1517998c 100644 --- a/configs/instance_segmentation_light_model.yaml +++ b/configs/instance_segmentation_light_model.yaml @@ -6,6 +6,10 @@ model: name: InstanceSegmentationModel params: variant: light + loss_params: + bbox_loss_weight: 60 # Should be 7.5 * accumulate_grad_batches for best results + class_loss_weight: 4 # Should be 0.5 * accumulate_grad_batches for best results + dfl_loss_weight: 12 # Should be 1.5 * accumulate_grad_batches for best results loader: params: @@ -19,27 +23,29 @@ trainer: active: true batch_size: 8 - epochs: &epochs 200 - n_workers: 4 + epochs: &epochs 300 + accumulate_grad_batches: 8 # For best results, always accumulate gradients to effectively use 64 batch size + n_workers: 8 validation_interval: 10 n_log_images: 8 callbacks: + - name: EMACallback + params: + decay: 0.9999 + use_dynamic_decay: True + decay_tau: 2000 - name: ExportOnTrainEnd - name: TestOnTrainEnd - optimizer: - name: SGD - params: + training_strategy: + name: "TripleLRSGDStrategy" + params: + warmup_epochs: 3 + warmup_bias_lr: 0.1 + warmup_momentum: 0.8 lr: 0.01 - momentum: 0.937 + lre: 0.0001 + momentum: 0.937 weight_decay: 0.0005 - dampening: 0.0 - nesterov: true - - scheduler: - name: CosineAnnealingLR - params: - T_max: *epochs - eta_min: 0.0001 - last_epoch: -1 + nesterov: True diff --git a/luxonis_train/config/predefined_models/instance_segmentation_model.py b/luxonis_train/config/predefined_models/instance_segmentation_model.py index 28477572..b30ca6f4 100644 --- a/luxonis_train/config/predefined_models/instance_segmentation_model.py +++ b/luxonis_train/config/predefined_models/instance_segmentation_model.py @@ -61,7 +61,9 @@ def __init__( head_params: Params | None = None, loss_params: Params | None = None, visualizer_params: Params | None = None, - task_name: str | None = None, + task_name: str = "", + enable_confusion_matrix: bool = True, + confusion_matrix_params: Params | None = None, ): var_config = get_variant(variant) @@ -74,9 +76,11 @@ def __init__( self.backbone = backbone or var_config.backbone self.neck_params = neck_params or var_config.neck_params self.head_params = head_params or {} - self.loss_params = loss_params or {"n_warmup_epochs": 0} + self.loss_params = loss_params or {} self.visualizer_params = visualizer_params or {} - self.task_name = task_name or "instance_segmentation" + self.task_name = task_name + self.enable_confusion_matrix = enable_confusion_matrix + self.confusion_matrix_params = confusion_matrix_params or {} @property def nodes(self) -> list[ModelNodeConfig]: @@ -85,7 +89,7 @@ def nodes(self) -> list[ModelNodeConfig]: nodes = [ ModelNodeConfig( name=self.backbone, - alias=f"{self.backbone}-{self.task_name}", + alias=f"{self.task_name}/{self.backbone}", freezing=self.backbone_params.pop("freezing", {}), params=self.backbone_params, ), @@ -94,8 +98,8 @@ def nodes(self) -> list[ModelNodeConfig]: nodes.append( ModelNodeConfig( name="RepPANNeck", - alias=f"RepPANNeck-{self.task_name}", - inputs=[f"{self.backbone}-{self.task_name}"], + alias=f"{self.task_name}/RepPANNeck", + inputs=[f"{self.task_name}/{self.backbone}"], freezing=self.neck_params.pop("freezing", {}), params=self.neck_params, ) @@ -104,13 +108,13 @@ def nodes(self) -> list[ModelNodeConfig]: nodes.append( ModelNodeConfig( name="PrecisionSegmentBBoxHead", - alias=f"PrecisionSegmentBBoxHead-{self.task_name}", + alias=f"{self.task_name}/PrecisionSegmentBBoxHead", freezing=self.head_params.pop("freezing", {}), - inputs=[f"RepPANNeck-{self.task_name}"] + inputs=[f"{self.task_name}/RepPANNeck"] if self.use_neck else [f"{self.backbone}-{self.task_name}"], params=self.head_params, - task=self.task_name, + # task=self.task_name, ) ) return nodes @@ -121,8 +125,7 @@ def losses(self) -> list[LossModuleConfig]: return [ LossModuleConfig( name="PrecisionDFLSegmentationLoss", - alias=f"PrecisionDFLSegmentationLoss-{self.task_name}", - attached_to=f"PrecisionSegmentBBoxHead-{self.task_name}", + attached_to=f"{self.task_name}/PrecisionSegmentBBoxHead", params=self.loss_params, weight=1.0, ) @@ -131,14 +134,23 @@ def losses(self) -> list[LossModuleConfig]: @property def metrics(self) -> list[MetricModuleConfig]: """Defines the metrics used for evaluation.""" - return [ + metrics = [ MetricModuleConfig( name="MeanAveragePrecision", - alias=f"MeanAveragePrecision-{self.task_name}", - attached_to=f"PrecisionSegmentBBoxHead-{self.task_name}", + attached_to=f"{self.task_name}/PrecisionSegmentBBoxHead", is_main_metric=True, ), ] + if self.enable_confusion_matrix: + metrics.append( + MetricModuleConfig( + name="ConfusionMatrix", + alias=f"{self.task_name}/ConfusionMatrix", + attached_to=f"{self.task_name}/PrecisionSegmentBBoxHead", + params={**self.confusion_matrix_params}, + ) + ) + return metrics @property def visualizers(self) -> list[AttachedModuleConfig]: @@ -146,8 +158,7 @@ def visualizers(self) -> list[AttachedModuleConfig]: return [ AttachedModuleConfig( name="InstanceSegmentationVisualizer", - alias=f"InstanceSegmentationVisualizer-{self.task_name}", - attached_to=f"PrecisionSegmentBBoxHead-{self.task_name}", + attached_to=f"{self.task_name}/PrecisionSegmentBBoxHead", params=self.visualizer_params, ) ]