From f91fc4d5d712ab4610a901b233f449fbad5c888c Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Thu, 21 Nov 2024 10:54:56 +0100 Subject: [PATCH 01/21] Confusion matrix metric for det, seg and cls tasks --- .../attached_modules/metrics/__init__.py | 2 + .../metrics/confusion_matrix.py | 247 ++++++++++++++++++ 2 files changed, 249 insertions(+) create mode 100644 luxonis_train/attached_modules/metrics/confusion_matrix.py diff --git a/luxonis_train/attached_modules/metrics/__init__.py b/luxonis_train/attached_modules/metrics/__init__.py index b1dc40ea..cdd0b3ac 100644 --- a/luxonis_train/attached_modules/metrics/__init__.py +++ b/luxonis_train/attached_modules/metrics/__init__.py @@ -1,4 +1,5 @@ from .base_metric import BaseMetric +from .confusion_matrix import ConfusionMatrix from .mean_average_precision import MeanAveragePrecision from .mean_average_precision_keypoints import MeanAveragePrecisionKeypoints from .object_keypoint_similarity import ObjectKeypointSimilarity @@ -14,4 +15,5 @@ "ObjectKeypointSimilarity", "Precision", "Recall", + "ConfusionMatrix", ] diff --git a/luxonis_train/attached_modules/metrics/confusion_matrix.py b/luxonis_train/attached_modules/metrics/confusion_matrix.py new file mode 100644 index 00000000..a3323cab --- /dev/null +++ b/luxonis_train/attached_modules/metrics/confusion_matrix.py @@ -0,0 +1,247 @@ +from typing import Literal + +import torch +from torch import Tensor +from torchvision.ops import box_convert, box_iou + +from luxonis_train.enums import TaskType +from luxonis_train.utils import Labels, Packet + +from .base_metric import BaseMetric + + +class ConfusionMatrix(BaseMetric[Tensor, Tensor]): + def __init__( + self, + box_format: Literal["xyxy", "xywh", "cxcywh"] = "xyxy", + iou_threshold: float = 0.45, + confidence_threshold: float = 0.25, + **kwargs, + ): + """Compute the confusion matrix for classification, + segmentation, and object detection tasks. + + @type box_format: Literal["xyxy", "xywh", "cxcywh"] + @param box_format: The format of the bounding boxes. Can be one + of "xyxy", "xywh", or "cxcywh". + @type iou_threshold: float + @param iou_threshold: The IoU threshold for matching predictions + to ground truth. + @type confidence_threshold: float + @param confidence_threshold: The confidence threshold for + filtering predictions. + """ + super().__init__(**kwargs) + allowed_box_formats = ("xyxy", "xywh", "cxcywh") + if box_format not in allowed_box_formats: + raise ValueError( + f"Expected argument `box_format` to be one of {allowed_box_formats} but got {box_format}" + ) + self.box_format = box_format + self.iou_threshold = iou_threshold + self.confidence_threshold = confidence_threshold + + self.is_classification = TaskType.CLASSIFICATION in self.node.tasks + self.is_detection = TaskType.BOUNDINGBOX in self.node.tasks + self.is_segmentation = TaskType.SEGMENTATION in self.node.tasks + + if self.is_classification: + self.add_state( + "classification_cm", + default=torch.zeros( + self.n_classes, self.n_classes, dtype=torch.int64 + ), + dist_reduce_fx="sum", + ) + if self.is_segmentation: + self.add_state( + "segmentation_cm", + default=torch.zeros( + self.n_classes, self.n_classes, dtype=torch.int64 + ), + dist_reduce_fx="sum", + ) + if self.is_detection: + self.add_state( + "detection_cm", + default=torch.zeros( + self.n_classes + 1, self.n_classes + 1, dtype=torch.int64 + ), # +1 for background + dist_reduce_fx="sum", + ) + + def prepare(self, inputs: Packet[Tensor], labels: Labels): + if self.is_detection: + out_bbox = self.get_input_tensors( + inputs, TaskType.BOUNDINGBOX + ) # Predictions list of bs elements with shape [N,6] where 6 is [x1, y1, x2, y2, score, class] + bbox = self.get_label( + labels, TaskType.BOUNDINGBOX + ) # Ground truth shape of [M, 6] where 6 is [image_idx, class, x1, y1, w, h] + bbox = bbox.to(out_bbox[0].device) + bbox[..., 2:6] = box_convert( + bbox[..., 2:6], "xywh", "xyxy" + ) # Convert ground truth to "xyxy" + scale_factors = torch.tensor( + [ + self.original_in_shape[2], + self.original_in_shape[1], + self.original_in_shape[2], + self.original_in_shape[1], + ], + device=bbox.device, + ) + bbox[..., 2:6] *= scale_factors + return out_bbox, bbox + + if self.is_classification: + prediction = ( + self.get_input_tensors(inputs, TaskType.CLASSIFICATION), + ) + target = self.get_label(labels, TaskType.CLASSIFICATION).to( + prediction[0].device + ) + return prediction, target + + if self.is_segmentation: + prediction = ( + self.get_input_tensors(inputs, TaskType.SEGMENTATION), + ) + target = self.get_label(labels, TaskType.SEGMENTATION).to( + prediction[0].device + ) + return prediction, target + + def update(self, preds: list[Tensor], target: Tensor) -> None: + if self.is_classification: + pred_classes = preds[0].argmax(dim=1) # [B] + target_classes = target.argmax(dim=1) # [B] + self.classification_cm += self._compute_confusion_matrix( + pred_classes, target_classes + ) + + if self.is_segmentation: + pred_masks = preds[0].argmax(dim=1) # [B, H, W] + target_masks = target.argmax(dim=1) # [B, H, W] + self.segmentation_cm += self._compute_confusion_matrix( + pred_masks.view(-1), target_masks.view(-1) + ) + + if self.is_detection: + self.detection_cm += self._compute_detection_confusion_matrix( + preds, target + ) + + def compute(self) -> dict[str, Tensor]: + results = {} + if self.is_classification: + results["classification_confusion_matrix"] = self.classification_cm + print("classification_confusion_matrix:\n", self.classification_cm) + if self.is_segmentation: + results["segmentation_confusion_matrix"] = self.segmentation_cm + print("segmentation_confusion_matrix:\n", self.segmentation_cm) + if self.is_detection: + results["detection_confusion_matrix"] = self.detection_cm + print("detection_confusion_matrix:\n", self.detection_cm) + + return torch.tensor( + [-1.0], dtype=torch.float32 + ) # Change this once luxonis-ml supports returning tensor as a metric + + def _compute_confusion_matrix( + self, preds: Tensor, targets: Tensor + ) -> Tensor: + """Compute a confusion matrix using efficient vectorized + operations.""" + mask = (targets >= 0) & (targets < self.n_classes) + preds = preds[mask] + targets = targets[mask] + + indices = targets * self.n_classes + preds + cm = torch.bincount( + indices, + minlength=self.n_classes * self.n_classes, + ).reshape(self.n_classes, self.n_classes) + return cm + + def _compute_detection_confusion_matrix( + self, preds: list[Tensor], targets: Tensor + ) -> Tensor: + """Compute a confusion matrix for object detection tasks. + + @type preds: list[Tensor] + @param preds: List of predictions for each image. Each tensor + has shape [N, 6] where 6 is for [x1, y1, x2, y2, score, + class] + """ + cm = torch.zeros( + self.n_classes + 1, + self.n_classes + 1, + dtype=torch.int64, + device=preds[0].device, + ) + + for img_idx, pred in enumerate(preds): + img_targets = targets[targets[:, 0] == img_idx] + + if img_targets.shape[0] == 0: + for pred_class in pred[:, 5].int(): + cm[pred_class, self.n_classes] += 1 # False positive + continue + + if pred.shape[0] == 0: + for gt_class in img_targets[:, 1].int(): + cm[self.n_classes, gt_class] += 1 # False negative + continue + + pred = pred[pred[:, 4] > self.confidence_threshold] + pred_boxes = pred[:, :4] + pred_classes = pred[:, 5].int() + + gt_boxes = img_targets[:, 2:] + gt_classes = img_targets[:, 1].int() + + iou = box_iou(gt_boxes, pred_boxes) + iou_thresholded = iou > self.iou_threshold + + if iou_thresholded.any(): + iou_max, pred_max_idx = torch.max( + iou, dim=1 + ) # Maximum IoU for each GT + iou_gt_mask = iou_max > self.iou_threshold + gt_match_idx = torch.arange( + len(gt_boxes), device=gt_boxes.device + )[iou_gt_mask] + pred_match_idx = pred_max_idx[iou_gt_mask] + + for gt_idx, pred_idx in zip(gt_match_idx, pred_match_idx): + gt_class = gt_classes[gt_idx] + pred_class = pred_classes[pred_idx] + cm[pred_class, gt_class] += 1 + + unmatched_gt_mask = ~torch.isin( + torch.arange(len(gt_boxes), device=gt_boxes.device), + gt_match_idx, + ) + for gt_idx in torch.arange( + len(gt_boxes), device=gt_boxes.device + )[unmatched_gt_mask]: + gt_class = gt_classes[gt_idx] + cm[self.n_classes, gt_class] += 1 + + unmatched_pred_mask = ~torch.isin( + torch.arange(len(pred_boxes), device=gt_boxes.device), + pred_match_idx, + ) + for pred_idx in torch.arange( + len(pred_boxes), device=gt_boxes.device + )[unmatched_pred_mask]: + pred_class = pred_classes[pred_idx] + cm[pred_class, self.n_classes] += 1 + else: + for gt_class in gt_classes: + cm[self.n_classes, gt_class] += 1 + for pred_class in pred_classes: + cm[pred_class, self.n_classes] += 1 + + return cm From 4d62b089cee7e1194fa2618d0c722c182e969f86 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Thu, 21 Nov 2024 11:03:48 +0100 Subject: [PATCH 02/21] removed comments --- .../metrics/confusion_matrix.py | 20 ++++++------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/luxonis_train/attached_modules/metrics/confusion_matrix.py b/luxonis_train/attached_modules/metrics/confusion_matrix.py index a3323cab..643e2081 100644 --- a/luxonis_train/attached_modules/metrics/confusion_matrix.py +++ b/luxonis_train/attached_modules/metrics/confusion_matrix.py @@ -72,16 +72,10 @@ def __init__( def prepare(self, inputs: Packet[Tensor], labels: Labels): if self.is_detection: - out_bbox = self.get_input_tensors( - inputs, TaskType.BOUNDINGBOX - ) # Predictions list of bs elements with shape [N,6] where 6 is [x1, y1, x2, y2, score, class] - bbox = self.get_label( - labels, TaskType.BOUNDINGBOX - ) # Ground truth shape of [M, 6] where 6 is [image_idx, class, x1, y1, w, h] + out_bbox = self.get_input_tensors(inputs, TaskType.BOUNDINGBOX) + bbox = self.get_label(labels, TaskType.BOUNDINGBOX) bbox = bbox.to(out_bbox[0].device) - bbox[..., 2:6] = box_convert( - bbox[..., 2:6], "xywh", "xyxy" - ) # Convert ground truth to "xyxy" + bbox[..., 2:6] = box_convert(bbox[..., 2:6], "xywh", "xyxy") scale_factors = torch.tensor( [ self.original_in_shape[2], @@ -186,12 +180,12 @@ def _compute_detection_confusion_matrix( if img_targets.shape[0] == 0: for pred_class in pred[:, 5].int(): - cm[pred_class, self.n_classes] += 1 # False positive + cm[pred_class, self.n_classes] += 1 continue if pred.shape[0] == 0: for gt_class in img_targets[:, 1].int(): - cm[self.n_classes, gt_class] += 1 # False negative + cm[self.n_classes, gt_class] += 1 continue pred = pred[pred[:, 4] > self.confidence_threshold] @@ -205,9 +199,7 @@ def _compute_detection_confusion_matrix( iou_thresholded = iou > self.iou_threshold if iou_thresholded.any(): - iou_max, pred_max_idx = torch.max( - iou, dim=1 - ) # Maximum IoU for each GT + iou_max, pred_max_idx = torch.max(iou, dim=1) iou_gt_mask = iou_max > self.iou_threshold gt_match_idx = torch.arange( len(gt_boxes), device=gt_boxes.device From dd9a7225684c85c56fb3778c16efa5d1caddd082 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Thu, 21 Nov 2024 11:55:57 +0100 Subject: [PATCH 03/21] minor prepare and update refactor --- .../metrics/confusion_matrix.py | 61 +++++++++++++++---- 1 file changed, 48 insertions(+), 13 deletions(-) diff --git a/luxonis_train/attached_modules/metrics/confusion_matrix.py b/luxonis_train/attached_modules/metrics/confusion_matrix.py index 643e2081..ea279ad0 100644 --- a/luxonis_train/attached_modules/metrics/confusion_matrix.py +++ b/luxonis_train/attached_modules/metrics/confusion_matrix.py @@ -70,7 +70,22 @@ def __init__( dist_reduce_fx="sum", ) - def prepare(self, inputs: Packet[Tensor], labels: Labels): + def prepare( + self, inputs: Packet[Tensor], labels: Labels + ) -> tuple[dict[str, Tensor], dict[str, Tensor]]: + """Prepare data for classification, segmentation, and detection + tasks. + + @type inputs: Packet[Tensor] + @param inputs: The inputs to the model. + @type labels: Labels + @param labels: The ground-truth labels. + @return: A tuple of two dictionaries: one for predictions and + one for targets. + """ + predictions = {} + targets = {} + if self.is_detection: out_bbox = self.get_input_tensors(inputs, TaskType.BOUNDINGBOX) bbox = self.get_label(labels, TaskType.BOUNDINGBOX) @@ -86,42 +101,62 @@ def prepare(self, inputs: Packet[Tensor], labels: Labels): device=bbox.device, ) bbox[..., 2:6] *= scale_factors - return out_bbox, bbox + predictions["detection"] = out_bbox + targets["detection"] = bbox if self.is_classification: - prediction = ( - self.get_input_tensors(inputs, TaskType.CLASSIFICATION), + prediction = self.get_input_tensors( + inputs, TaskType.CLASSIFICATION ) target = self.get_label(labels, TaskType.CLASSIFICATION).to( prediction[0].device ) - return prediction, target + predictions["classification"] = prediction + targets["classification"] = target if self.is_segmentation: - prediction = ( - self.get_input_tensors(inputs, TaskType.SEGMENTATION), - ) + prediction = self.get_input_tensors(inputs, TaskType.SEGMENTATION) target = self.get_label(labels, TaskType.SEGMENTATION).to( prediction[0].device ) - return prediction, target + predictions["segmentation"] = prediction + targets["segmentation"] = target - def update(self, preds: list[Tensor], target: Tensor) -> None: - if self.is_classification: + return predictions, targets + + def update( + self, predictions: dict[str, Tensor], targets: dict[str, Tensor] + ) -> None: + """Update the confusion matrices for all tasks using prepared + data. + + @type predictions: dict[str, Tensor] + @param predictions: A dictionary containing predictions for all + tasks. + @type targets: dict[str, Tensor] + @param targets: A dictionary containing targets for all tasks. + """ + if "classification" in predictions and "classification" in targets: + preds = predictions["classification"] + target = targets["classification"] pred_classes = preds[0].argmax(dim=1) # [B] target_classes = target.argmax(dim=1) # [B] self.classification_cm += self._compute_confusion_matrix( pred_classes, target_classes ) - if self.is_segmentation: + if "segmentation" in predictions and "segmentation" in targets: + preds = predictions["segmentation"] + target = targets["segmentation"] pred_masks = preds[0].argmax(dim=1) # [B, H, W] target_masks = target.argmax(dim=1) # [B, H, W] self.segmentation_cm += self._compute_confusion_matrix( pred_masks.view(-1), target_masks.view(-1) ) - if self.is_detection: + if "detection" in predictions and "detection" in targets: + preds = predictions["detection"] + target = targets["detection"] self.detection_cm += self._compute_detection_confusion_matrix( preds, target ) From b7a30797cec51f403d745e07047bb3a2f75a1267 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Thu, 21 Nov 2024 14:57:35 +0100 Subject: [PATCH 04/21] fix double scaling bug --- luxonis_train/attached_modules/metrics/confusion_matrix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/luxonis_train/attached_modules/metrics/confusion_matrix.py b/luxonis_train/attached_modules/metrics/confusion_matrix.py index ea279ad0..2c5ee4c6 100644 --- a/luxonis_train/attached_modules/metrics/confusion_matrix.py +++ b/luxonis_train/attached_modules/metrics/confusion_matrix.py @@ -89,7 +89,7 @@ def prepare( if self.is_detection: out_bbox = self.get_input_tensors(inputs, TaskType.BOUNDINGBOX) bbox = self.get_label(labels, TaskType.BOUNDINGBOX) - bbox = bbox.to(out_bbox[0].device) + bbox = bbox.to(out_bbox[0].device).clone() bbox[..., 2:6] = box_convert(bbox[..., 2:6], "xywh", "xyxy") scale_factors = torch.tensor( [ From 97274830dcc4c2c4d8f74564a0306b002efc54fc Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Thu, 21 Nov 2024 20:25:24 +0100 Subject: [PATCH 05/21] fixed logging --- .../metrics/confusion_matrix.py | 4 +--- luxonis_train/models/luxonis_lightning.py | 21 ++++++++++++------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/luxonis_train/attached_modules/metrics/confusion_matrix.py b/luxonis_train/attached_modules/metrics/confusion_matrix.py index 2c5ee4c6..bf8b73e6 100644 --- a/luxonis_train/attached_modules/metrics/confusion_matrix.py +++ b/luxonis_train/attached_modules/metrics/confusion_matrix.py @@ -173,9 +173,7 @@ def compute(self) -> dict[str, Tensor]: results["detection_confusion_matrix"] = self.detection_cm print("detection_confusion_matrix:\n", self.detection_cm) - return torch.tensor( - [-1.0], dtype=torch.float32 - ) # Change this once luxonis-ml supports returning tensor as a metric + return results def _compute_confusion_matrix( self, preds: Tensor, targets: Tensor diff --git a/luxonis_train/models/luxonis_lightning.py b/luxonis_train/models/luxonis_lightning.py index 08d0066f..6249a221 100644 --- a/luxonis_train/models/luxonis_lightning.py +++ b/luxonis_train/models/luxonis_lightning.py @@ -810,14 +810,19 @@ def _evaluation_epoch_end(self, mode: Literal["test", "val"]) -> None: logger.info("Metrics computed.") for node_name, metrics in computed_metrics.items(): for metric_name, metric_value in metrics.items(): - metric_results[node_name][metric_name] = ( - metric_value.cpu().item() - ) - self.log( - f"{mode}/metric/{node_name}/{metric_name}", - metric_value, - sync_dist=True, - ) + if "matrix" in metric_name: + self.logger.log_matrix( + matrix=metric_value.cpu().numpy(), name=metric_name + ) + else: + metric_results[node_name][metric_name] = ( + metric_value.cpu().item() + ) + self.log( + f"{mode}/metric/{node_name}/{metric_name}", + metric_value, + sync_dist=True, + ) if self.cfg.trainer.verbose: self._print_results( From c4c58cefb74f76de3ad601eab894c1b5c766f2ac Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Fri, 22 Nov 2024 07:58:13 +0100 Subject: [PATCH 06/21] hotfix: no logging of CM --- .../metrics/confusion_matrix.py | 22 ++++++++++++++++--- luxonis_train/models/luxonis_lightning.py | 4 +++- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/luxonis_train/attached_modules/metrics/confusion_matrix.py b/luxonis_train/attached_modules/metrics/confusion_matrix.py index bf8b73e6..05e75571 100644 --- a/luxonis_train/attached_modules/metrics/confusion_matrix.py +++ b/luxonis_train/attached_modules/metrics/confusion_matrix.py @@ -45,6 +45,20 @@ def __init__( self.is_detection = TaskType.BOUNDINGBOX in self.node.tasks self.is_segmentation = TaskType.SEGMENTATION in self.node.tasks + if ( + sum( + [ + self.is_classification, + self.is_detection, + self.is_segmentation, + ] + ) + > 1 + ): + raise ValueError( + "Multiple tasks detected in self.node.tasks. Only one task is allowed." + ) + if self.is_classification: self.add_state( "classification_cm", @@ -162,16 +176,15 @@ def update( ) def compute(self) -> dict[str, Tensor]: + """Compute confusion matrices for classification, segmentation, + and detection tasks.""" results = {} if self.is_classification: results["classification_confusion_matrix"] = self.classification_cm - print("classification_confusion_matrix:\n", self.classification_cm) if self.is_segmentation: results["segmentation_confusion_matrix"] = self.segmentation_cm - print("segmentation_confusion_matrix:\n", self.segmentation_cm) if self.is_detection: results["detection_confusion_matrix"] = self.detection_cm - print("detection_confusion_matrix:\n", self.detection_cm) return results @@ -200,6 +213,9 @@ def _compute_detection_confusion_matrix( @param preds: List of predictions for each image. Each tensor has shape [N, 6] where 6 is for [x1, y1, x2, y2, score, class] + @type targets: Tensor + @param targets: Ground truth boxes and classes. Shape [M, 6] + where first column is image index. """ cm = torch.zeros( self.n_classes + 1, diff --git a/luxonis_train/models/luxonis_lightning.py b/luxonis_train/models/luxonis_lightning.py index 6249a221..7319a028 100644 --- a/luxonis_train/models/luxonis_lightning.py +++ b/luxonis_train/models/luxonis_lightning.py @@ -810,7 +810,7 @@ def _evaluation_epoch_end(self, mode: Literal["test", "val"]) -> None: logger.info("Metrics computed.") for node_name, metrics in computed_metrics.items(): for metric_name, metric_value in metrics.items(): - if "matrix" in metric_name: + if "matrix" in metric_name.lower(): self.logger.log_matrix( matrix=metric_value.cpu().numpy(), name=metric_name ) @@ -1036,6 +1036,8 @@ def _print_results( if self.main_metric is not None: main_metric_node, main_metric_name = self.main_metric.split("/") + if "matrix" in main_metric_name.lower(): + return main_metric = metrics[main_metric_node][main_metric_name] logger.info( f"{stage} main metric ({self.main_metric}): {main_metric:.4f}" From 2549f276a9d8f6ddb3ac9bb699c1cc190735c48b Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Fri, 22 Nov 2024 13:52:27 +0100 Subject: [PATCH 07/21] minor refactor --- luxonis_train/models/luxonis_lightning.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/luxonis_train/models/luxonis_lightning.py b/luxonis_train/models/luxonis_lightning.py index 7319a028..38933617 100644 --- a/luxonis_train/models/luxonis_lightning.py +++ b/luxonis_train/models/luxonis_lightning.py @@ -812,7 +812,8 @@ def _evaluation_epoch_end(self, mode: Literal["test", "val"]) -> None: for metric_name, metric_value in metrics.items(): if "matrix" in metric_name.lower(): self.logger.log_matrix( - matrix=metric_value.cpu().numpy(), name=metric_name + matrix=metric_value.cpu().numpy(), + path=f"{mode}/metrics/{self.current_epoch}/{metric_name}", ) else: metric_results[node_name][metric_name] = ( From 8f19dc2f9c5230fba2d56a2face6a4f9e6dfe391 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Fri, 22 Nov 2024 14:11:55 +0100 Subject: [PATCH 08/21] rename path to name --- luxonis_train/models/luxonis_lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/luxonis_train/models/luxonis_lightning.py b/luxonis_train/models/luxonis_lightning.py index 38933617..7ab306b5 100644 --- a/luxonis_train/models/luxonis_lightning.py +++ b/luxonis_train/models/luxonis_lightning.py @@ -813,7 +813,7 @@ def _evaluation_epoch_end(self, mode: Literal["test", "val"]) -> None: if "matrix" in metric_name.lower(): self.logger.log_matrix( matrix=metric_value.cpu().numpy(), - path=f"{mode}/metrics/{self.current_epoch}/{metric_name}", + name=f"{mode}/metrics/{self.current_epoch}/{metric_name}", ) else: metric_results[node_name][metric_name] = ( From b98158bef8b0b31d2a66850253cb41952fa1b4b5 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Sat, 23 Nov 2024 09:01:16 +0100 Subject: [PATCH 09/21] remove is_main_metric hotfix --- luxonis_train/models/luxonis_lightning.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/luxonis_train/models/luxonis_lightning.py b/luxonis_train/models/luxonis_lightning.py index 7ab306b5..ee03b450 100644 --- a/luxonis_train/models/luxonis_lightning.py +++ b/luxonis_train/models/luxonis_lightning.py @@ -1037,8 +1037,6 @@ def _print_results( if self.main_metric is not None: main_metric_node, main_metric_name = self.main_metric.split("/") - if "matrix" in main_metric_name.lower(): - return main_metric = metrics[main_metric_node][main_metric_name] logger.info( f"{stage} main metric ({self.main_metric}): {main_metric:.4f}" From 6c705888f6344014cb94a31b7b6becce7cd534dc Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Sat, 23 Nov 2024 09:25:58 +0100 Subject: [PATCH 10/21] forbid CM as main metric --- luxonis_train/config/config.py | 17 +++++++++++++---- luxonis_train/models/luxonis_lightning.py | 1 + 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/luxonis_train/config/config.py b/luxonis_train/config/config.py index 941cd649..baf654c3 100644 --- a/luxonis_train/config/config.py +++ b/luxonis_train/config/config.py @@ -151,15 +151,24 @@ def check_predefined_model(self) -> Self: def check_main_metric(self) -> Self: for metric in self.metrics: if metric.is_main_metric: + if "matrix" in metric.name.lower(): + raise ValueError( + f"Main metric cannot contain 'matrix' in its name: `{metric.name}`" + ) logger.info(f"Main metric: `{metric.name}`") return self logger.warning("No main metric specified.") if self.metrics: - metric = self.metrics[0] - metric.is_main_metric = True - name = metric.alias or metric.name - logger.info(f"Setting '{name}' as main metric.") + for metric in self.metrics: + if "matrix" not in metric.name.lower(): + metric.is_main_metric = True + name = metric.alias or metric.name + logger.info(f"Setting '{name}' as main metric.") + return self + raise ValueError( + "[Configuration Error] No valid main metric can be set as all metrics contain 'matrix' in their names." + ) else: logger.warning( "[Ignore if using predefined model] " diff --git a/luxonis_train/models/luxonis_lightning.py b/luxonis_train/models/luxonis_lightning.py index ee03b450..2cb00e7e 100644 --- a/luxonis_train/models/luxonis_lightning.py +++ b/luxonis_train/models/luxonis_lightning.py @@ -814,6 +814,7 @@ def _evaluation_epoch_end(self, mode: Literal["test", "val"]) -> None: self.logger.log_matrix( matrix=metric_value.cpu().numpy(), name=f"{mode}/metrics/{self.current_epoch}/{metric_name}", + step=self.current_epoch, ) else: metric_results[node_name][metric_name] = ( From 76dcc889d9697875b512a0f98872b717faf8e1d1 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Mon, 25 Nov 2024 18:03:01 +0100 Subject: [PATCH 11/21] fix: failing type-check --- .../metrics/confusion_matrix.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/luxonis_train/attached_modules/metrics/confusion_matrix.py b/luxonis_train/attached_modules/metrics/confusion_matrix.py index 05e75571..b57d7069 100644 --- a/luxonis_train/attached_modules/metrics/confusion_matrix.py +++ b/luxonis_train/attached_modules/metrics/confusion_matrix.py @@ -41,9 +41,18 @@ def __init__( self.iou_threshold = iou_threshold self.confidence_threshold = confidence_threshold - self.is_classification = TaskType.CLASSIFICATION in self.node.tasks - self.is_detection = TaskType.BOUNDINGBOX in self.node.tasks - self.is_segmentation = TaskType.SEGMENTATION in self.node.tasks + self.is_classification = ( + self.node.tasks is not None + and TaskType.CLASSIFICATION in self.node.tasks + ) + self.is_detection = ( + self.node.tasks is not None + and TaskType.BOUNDINGBOX in self.node.tasks + ) + self.is_segmentation = ( + self.node.tasks is not None + and TaskType.SEGMENTATION in self.node.tasks + ) if ( sum( @@ -169,7 +178,7 @@ def update( ) if "detection" in predictions and "detection" in targets: - preds = predictions["detection"] + preds = predictions["detection"] # type: ignore target = targets["detection"] self.detection_cm += self._compute_detection_confusion_matrix( preds, target From 60c9aa21bcd00677335df72590c6b0674e6b5cc3 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Mon, 25 Nov 2024 18:08:30 +0100 Subject: [PATCH 12/21] fix: failing type-check --- luxonis_train/attached_modules/metrics/confusion_matrix.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/luxonis_train/attached_modules/metrics/confusion_matrix.py b/luxonis_train/attached_modules/metrics/confusion_matrix.py index b57d7069..d2319f6c 100644 --- a/luxonis_train/attached_modules/metrics/confusion_matrix.py +++ b/luxonis_train/attached_modules/metrics/confusion_matrix.py @@ -181,7 +181,8 @@ def update( preds = predictions["detection"] # type: ignore target = targets["detection"] self.detection_cm += self._compute_detection_confusion_matrix( - preds, target + preds, + target, # type: ignore ) def compute(self) -> dict[str, Tensor]: From 72f93b37e6a207dc6236961571db868420d88718 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Tue, 26 Nov 2024 13:47:41 +0100 Subject: [PATCH 13/21] torchmetrics for seg and cls --- .../metrics/confusion_matrix.py | 25 +++++-------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/luxonis_train/attached_modules/metrics/confusion_matrix.py b/luxonis_train/attached_modules/metrics/confusion_matrix.py index d2319f6c..fa50fdb4 100644 --- a/luxonis_train/attached_modules/metrics/confusion_matrix.py +++ b/luxonis_train/attached_modules/metrics/confusion_matrix.py @@ -2,6 +2,7 @@ import torch from torch import Tensor +from torchmetrics.classification import MulticlassConfusionMatrix from torchvision.ops import box_convert, box_iou from luxonis_train.enums import TaskType @@ -92,6 +93,9 @@ def __init__( ), # +1 for background dist_reduce_fx="sum", ) + self.compute_confusion_matrix = MulticlassConfusionMatrix( + num_classes=self.n_classes + ) def prepare( self, inputs: Packet[Tensor], labels: Labels @@ -164,16 +168,15 @@ def update( target = targets["classification"] pred_classes = preds[0].argmax(dim=1) # [B] target_classes = target.argmax(dim=1) # [B] - self.classification_cm += self._compute_confusion_matrix( + self.classification_cm += self.compute_confusion_matrix( pred_classes, target_classes ) - if "segmentation" in predictions and "segmentation" in targets: preds = predictions["segmentation"] target = targets["segmentation"] pred_masks = preds[0].argmax(dim=1) # [B, H, W] target_masks = target.argmax(dim=1) # [B, H, W] - self.segmentation_cm += self._compute_confusion_matrix( + self.segmentation_cm += self.compute_confusion_matrix( pred_masks.view(-1), target_masks.view(-1) ) @@ -198,22 +201,6 @@ def compute(self) -> dict[str, Tensor]: return results - def _compute_confusion_matrix( - self, preds: Tensor, targets: Tensor - ) -> Tensor: - """Compute a confusion matrix using efficient vectorized - operations.""" - mask = (targets >= 0) & (targets < self.n_classes) - preds = preds[mask] - targets = targets[mask] - - indices = targets * self.n_classes + preds - cm = torch.bincount( - indices, - minlength=self.n_classes * self.n_classes, - ).reshape(self.n_classes, self.n_classes) - return cm - def _compute_detection_confusion_matrix( self, preds: list[Tensor], targets: Tensor ) -> Tensor: From 0c1e422cf29ebb4c1047e54db07159bb47d98347 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Tue, 26 Nov 2024 17:32:31 +0100 Subject: [PATCH 14/21] test for det CM --- luxonis_train/models/luxonis_lightning.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/luxonis_train/models/luxonis_lightning.py b/luxonis_train/models/luxonis_lightning.py index 2cb00e7e..f6352094 100644 --- a/luxonis_train/models/luxonis_lightning.py +++ b/luxonis_train/models/luxonis_lightning.py @@ -460,6 +460,8 @@ def forward( else: del computed[computed_name] + torch.cuda.empty_cache() + outputs_dict = { node_name: outputs for node_name, outputs in computed.items() From d79aa1ca223ae48c469959bfce1266d56251a9b9 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Tue, 26 Nov 2024 17:36:41 +0100 Subject: [PATCH 15/21] test for det CM --- luxonis_train/models/luxonis_lightning.py | 2 - .../test_metrics/test_confusion_matrix.py | 56 +++++++++++++++++++ 2 files changed, 56 insertions(+), 2 deletions(-) create mode 100644 tests/unittests/test_metrics/test_confusion_matrix.py diff --git a/luxonis_train/models/luxonis_lightning.py b/luxonis_train/models/luxonis_lightning.py index f6352094..2cb00e7e 100644 --- a/luxonis_train/models/luxonis_lightning.py +++ b/luxonis_train/models/luxonis_lightning.py @@ -460,8 +460,6 @@ def forward( else: del computed[computed_name] - torch.cuda.empty_cache() - outputs_dict = { node_name: outputs for node_name, outputs in computed.items() diff --git a/tests/unittests/test_metrics/test_confusion_matrix.py b/tests/unittests/test_metrics/test_confusion_matrix.py new file mode 100644 index 00000000..21d04d17 --- /dev/null +++ b/tests/unittests/test_metrics/test_confusion_matrix.py @@ -0,0 +1,56 @@ +import torch + +from luxonis_train.attached_modules.metrics.confusion_matrix import ( + ConfusionMatrix, +) +from luxonis_train.enums import TaskType +from luxonis_train.nodes import BaseNode + + +def test_compute_detection_confusion_matrix_specific_case(): + class DummyNode(BaseNode): + tasks = [TaskType.BOUNDINGBOX] + + def forward(self, _): + pass + + metric = ConfusionMatrix(node=DummyNode(n_classes=3), iou_threshold=0.5) + + preds = [torch.empty((0, 6)) for _ in range(3)] + preds.append( + torch.tensor( + [ + [10, 20, 30, 50, 0.8, 2], + [10, 21, 30, 50, 0.8, 1], + [10, 20, 30, 50, 0.8, 1], + [51, 61, 71, 78, 0.9, 2], + ] + ) + ) + + # Targets: ground truth for 4 images + targets = torch.tensor( + [ + [3, 1, 10, 20, 30, 50], + [0, 1, 10, 20, 30, 40], + [1, 2, 50, 60, 70, 80], + [2, 2, 10, 60, 70, 80], + [3, 2, 50, 60, 70, 80], + ] + ) + + expected_cm = torch.tensor( + [ + [0, 0, 0, 0], + [0, 0, 0, 2], + [0, 1, 1, 0], + [0, 1, 2, 0], + ], + dtype=torch.int64, + ) + + computed_cm = metric._compute_detection_confusion_matrix(preds, targets) + + assert torch.equal( + computed_cm, expected_cm + ), f"Expected {expected_cm}, but got {computed_cm}" From 02ccbfa7950c7ab32d4bf1d1330e6953ae8b7cde Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Tue, 26 Nov 2024 17:40:33 +0100 Subject: [PATCH 16/21] fix type-check errors --- luxonis_train/attached_modules/metrics/confusion_matrix.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/luxonis_train/attached_modules/metrics/confusion_matrix.py b/luxonis_train/attached_modules/metrics/confusion_matrix.py index fa50fdb4..bdcc2864 100644 --- a/luxonis_train/attached_modules/metrics/confusion_matrix.py +++ b/luxonis_train/attached_modules/metrics/confusion_matrix.py @@ -183,8 +183,8 @@ def update( if "detection" in predictions and "detection" in targets: preds = predictions["detection"] # type: ignore target = targets["detection"] - self.detection_cm += self._compute_detection_confusion_matrix( - preds, + self.detection_cm += self._compute_detection_confusion_matrix( # type: ignore + preds, # type: ignore target, # type: ignore ) From 985233a6d384fa9afe9b1a6819a9fb444d0148b7 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Wed, 27 Nov 2024 14:23:57 +0100 Subject: [PATCH 17/21] actual torchmetric instance --- .../metrics/confusion_matrix.py | 39 +++++++------------ 1 file changed, 14 insertions(+), 25 deletions(-) diff --git a/luxonis_train/attached_modules/metrics/confusion_matrix.py b/luxonis_train/attached_modules/metrics/confusion_matrix.py index bdcc2864..3d9ddd18 100644 --- a/luxonis_train/attached_modules/metrics/confusion_matrix.py +++ b/luxonis_train/attached_modules/metrics/confusion_matrix.py @@ -69,22 +69,12 @@ def __init__( "Multiple tasks detected in self.node.tasks. Only one task is allowed." ) - if self.is_classification: - self.add_state( - "classification_cm", - default=torch.zeros( - self.n_classes, self.n_classes, dtype=torch.int64 - ), - dist_reduce_fx="sum", - ) - if self.is_segmentation: - self.add_state( - "segmentation_cm", - default=torch.zeros( - self.n_classes, self.n_classes, dtype=torch.int64 - ), - dist_reduce_fx="sum", + self.metric_cm = None + if self.is_classification or self.is_segmentation: + self.metric_cm = MulticlassConfusionMatrix( + num_classes=self.n_classes ) + if self.is_detection: self.add_state( "detection_cm", @@ -168,17 +158,14 @@ def update( target = targets["classification"] pred_classes = preds[0].argmax(dim=1) # [B] target_classes = target.argmax(dim=1) # [B] - self.classification_cm += self.compute_confusion_matrix( - pred_classes, target_classes - ) + self.metric_cm.update(pred_classes, target_classes) + if "segmentation" in predictions and "segmentation" in targets: preds = predictions["segmentation"] target = targets["segmentation"] pred_masks = preds[0].argmax(dim=1) # [B, H, W] target_masks = target.argmax(dim=1) # [B, H, W] - self.segmentation_cm += self.compute_confusion_matrix( - pred_masks.view(-1), target_masks.view(-1) - ) + self.metric_cm.update(pred_masks.view(-1), target_masks.view(-1)) if "detection" in predictions and "detection" in targets: preds = predictions["detection"] # type: ignore @@ -192,10 +179,12 @@ def compute(self) -> dict[str, Tensor]: """Compute confusion matrices for classification, segmentation, and detection tasks.""" results = {} - if self.is_classification: - results["classification_confusion_matrix"] = self.classification_cm - if self.is_segmentation: - results["segmentation_confusion_matrix"] = self.segmentation_cm + if self.metric_cm: + task_type = ( + "classification" if self.is_classification else "segmentation" + ) + results[f"{task_type}_confusion_matrix"] = self.metric_cm.compute() + self.metric_cm.reset() if self.is_detection: results["detection_confusion_matrix"] = self.detection_cm From 7d4e469633274014e430e5b27c3489f131fe2156 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Wed, 27 Nov 2024 14:30:18 +0100 Subject: [PATCH 18/21] remove compute_confusion_matrix --- luxonis_train/attached_modules/metrics/confusion_matrix.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/luxonis_train/attached_modules/metrics/confusion_matrix.py b/luxonis_train/attached_modules/metrics/confusion_matrix.py index 3d9ddd18..e88e27d8 100644 --- a/luxonis_train/attached_modules/metrics/confusion_matrix.py +++ b/luxonis_train/attached_modules/metrics/confusion_matrix.py @@ -83,9 +83,6 @@ def __init__( ), # +1 for background dist_reduce_fx="sum", ) - self.compute_confusion_matrix = MulticlassConfusionMatrix( - num_classes=self.n_classes - ) def prepare( self, inputs: Packet[Tensor], labels: Labels From a1b003aafe8d6661eff5e7f27f92c17839ff88ea Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Wed, 27 Nov 2024 14:32:18 +0100 Subject: [PATCH 19/21] fix type-check --- .../attached_modules/metrics/confusion_matrix.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/luxonis_train/attached_modules/metrics/confusion_matrix.py b/luxonis_train/attached_modules/metrics/confusion_matrix.py index e88e27d8..079a9315 100644 --- a/luxonis_train/attached_modules/metrics/confusion_matrix.py +++ b/luxonis_train/attached_modules/metrics/confusion_matrix.py @@ -155,14 +155,18 @@ def update( target = targets["classification"] pred_classes = preds[0].argmax(dim=1) # [B] target_classes = target.argmax(dim=1) # [B] - self.metric_cm.update(pred_classes, target_classes) + if self.metric_cm is not None: + self.metric_cm.update(pred_classes, target_classes) if "segmentation" in predictions and "segmentation" in targets: preds = predictions["segmentation"] target = targets["segmentation"] pred_masks = preds[0].argmax(dim=1) # [B, H, W] target_masks = target.argmax(dim=1) # [B, H, W] - self.metric_cm.update(pred_masks.view(-1), target_masks.view(-1)) + if self.metric_cm is not None: + self.metric_cm.update( + pred_masks.view(-1), target_masks.view(-1) + ) if "detection" in predictions and "detection" in targets: preds = predictions["detection"] # type: ignore From 1b565ded76bb0ae328d6ddeaeae74cc94b538e0b Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Wed, 27 Nov 2024 15:47:08 +0100 Subject: [PATCH 20/21] fix failing tests --- tests/unittests/test_metrics/test_confusion_matrix.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/unittests/test_metrics/test_confusion_matrix.py b/tests/unittests/test_metrics/test_confusion_matrix.py index 21d04d17..07429d8b 100644 --- a/tests/unittests/test_metrics/test_confusion_matrix.py +++ b/tests/unittests/test_metrics/test_confusion_matrix.py @@ -8,13 +8,15 @@ def test_compute_detection_confusion_matrix_specific_case(): - class DummyNode(BaseNode): + class DummyNodeDetection(BaseNode): tasks = [TaskType.BOUNDINGBOX] def forward(self, _): pass - metric = ConfusionMatrix(node=DummyNode(n_classes=3), iou_threshold=0.5) + metric = ConfusionMatrix( + node=DummyNodeDetection(n_classes=3), iou_threshold=0.5 + ) preds = [torch.empty((0, 6)) for _ in range(3)] preds.append( From d4170a927f35390820fcf88beabca228da17987f Mon Sep 17 00:00:00 2001 From: klemen1999 Date: Mon, 9 Dec 2024 12:30:07 +0100 Subject: [PATCH 21/21] update to latest luxonis-ml --- requirements-config.txt | 2 +- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements-config.txt b/requirements-config.txt index 0a7b2625..f8498e7d 100644 --- a/requirements-config.txt +++ b/requirements-config.txt @@ -1 +1 @@ -luxonis-ml[data,utils]@git+https://github.com/luxonis/luxonis-ml.git@dev +luxonis-ml[data,utils]@git+https://github.com/luxonis/luxonis-ml.git@main diff --git a/requirements.txt b/requirements.txt index 5d0fcb28..5ef87b3a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ blobconverter>=1.4.2 lightning>=2.4.0 -luxonis-ml[data,tracker]>=0.5.0 +luxonis-ml[data,tracker]>=0.5.1 onnx>=1.12.0 onnxruntime>=1.13.1 onnxsim>=0.4.10