From de0c71970a5b2eba2136a0463bccf0de4b8969e4 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Mon, 13 May 2024 21:31:59 +0200 Subject: [PATCH 1/6] fixed reset not being called --- .../attached_modules/metrics/mean_average_precision.py | 3 +++ .../metrics/mean_average_precision_keypoints.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/luxonis_train/attached_modules/metrics/mean_average_precision.py b/luxonis_train/attached_modules/metrics/mean_average_precision.py index 34adbcd9..edb7430f 100644 --- a/luxonis_train/attached_modules/metrics/mean_average_precision.py +++ b/luxonis_train/attached_modules/metrics/mean_average_precision.py @@ -62,6 +62,9 @@ def prepare( return output_list, label_list + def reset(self) -> None: + self.metric.reset() + def compute(self) -> tuple[Tensor, dict[str, Tensor]]: metric_dict = self.metric.compute() diff --git a/luxonis_train/attached_modules/metrics/mean_average_precision_keypoints.py b/luxonis_train/attached_modules/metrics/mean_average_precision_keypoints.py index 3740f58e..622ae1e5 100644 --- a/luxonis_train/attached_modules/metrics/mean_average_precision_keypoints.py +++ b/luxonis_train/attached_modules/metrics/mean_average_precision_keypoints.py @@ -201,6 +201,9 @@ def update( item.get("iscrowd", torch.zeros_like(item["labels"])) ) + def reset(self) -> None: + self.metric.reset() + def compute(self) -> tuple[Tensor, dict[str, Tensor]]: """Torchmetric compute function.""" coco_target, coco_preds = COCO(), COCO() From 194c253ca7a009dcb79646de937f9d72a68226e0 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Mon, 13 May 2024 21:36:08 +0200 Subject: [PATCH 2/6] added metric resets --- luxonis_train/attached_modules/metrics/common.py | 3 +++ .../attached_modules/metrics/object_keypoint_similarity.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/luxonis_train/attached_modules/metrics/common.py b/luxonis_train/attached_modules/metrics/common.py index 27d1069a..fe1891df 100644 --- a/luxonis_train/attached_modules/metrics/common.py +++ b/luxonis_train/attached_modules/metrics/common.py @@ -55,6 +55,9 @@ def update(self, preds, target, *args, **kwargs): def compute(self): return self.metric.compute() + def reset(self) -> None: + self.metric.reset() + class Accuracy(TorchMetricWrapper): Metric = torchmetrics.Accuracy diff --git a/luxonis_train/attached_modules/metrics/object_keypoint_similarity.py b/luxonis_train/attached_modules/metrics/object_keypoint_similarity.py index c5e4a19b..d9cffcbc 100644 --- a/luxonis_train/attached_modules/metrics/object_keypoint_similarity.py +++ b/luxonis_train/attached_modules/metrics/object_keypoint_similarity.py @@ -133,6 +133,9 @@ def update( self.groundtruth_keypoints.append(keypoints) self.groundtruth_scales.append(item["scales"]) + def reset(self) -> None: + self.metric.reset() + def compute(self) -> Tensor: """Computes the OKS metric based on the inner state.""" From 9bc90960274e97f379d6c67a4eeb931a87c886f0 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 14 May 2024 01:12:44 +0200 Subject: [PATCH 3/6] removed inheritance --- .../attached_modules/metrics/mean_average_precision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/luxonis_train/attached_modules/metrics/mean_average_precision.py b/luxonis_train/attached_modules/metrics/mean_average_precision.py index edb7430f..0a58d061 100644 --- a/luxonis_train/attached_modules/metrics/mean_average_precision.py +++ b/luxonis_train/attached_modules/metrics/mean_average_precision.py @@ -12,7 +12,7 @@ from .base_metric import BaseMetric -class MeanAveragePrecision(BaseMetric, detection.MeanAveragePrecision): +class MeanAveragePrecision(BaseMetric): """Compute the Mean-Average-Precision (mAP) and Mean-Average-Recall (mAR) for object detection predictions. From 65b0c4e37c4a6ad80895be7e26eda6f751db8e99 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 14 May 2024 01:21:56 +0200 Subject: [PATCH 4/6] proper oks reset --- .../attached_modules/metrics/object_keypoint_similarity.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/luxonis_train/attached_modules/metrics/object_keypoint_similarity.py b/luxonis_train/attached_modules/metrics/object_keypoint_similarity.py index d9cffcbc..fdef971f 100644 --- a/luxonis_train/attached_modules/metrics/object_keypoint_similarity.py +++ b/luxonis_train/attached_modules/metrics/object_keypoint_similarity.py @@ -134,7 +134,9 @@ def update( self.groundtruth_scales.append(item["scales"]) def reset(self) -> None: - self.metric.reset() + self.pred_keypoints = [] + self.groundtruth_keypoints = [] + self.groundtruth_scales = [] def compute(self) -> Tensor: """Computes the OKS metric based on the inner state.""" From 617dee966576156f423fb8b9c8ed7596e66c6441 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 14 May 2024 01:26:24 +0200 Subject: [PATCH 5/6] removed unnecessary resets --- .../metrics/mean_average_precision_keypoints.py | 3 --- .../attached_modules/metrics/object_keypoint_similarity.py | 5 ----- 2 files changed, 8 deletions(-) diff --git a/luxonis_train/attached_modules/metrics/mean_average_precision_keypoints.py b/luxonis_train/attached_modules/metrics/mean_average_precision_keypoints.py index 622ae1e5..3740f58e 100644 --- a/luxonis_train/attached_modules/metrics/mean_average_precision_keypoints.py +++ b/luxonis_train/attached_modules/metrics/mean_average_precision_keypoints.py @@ -201,9 +201,6 @@ def update( item.get("iscrowd", torch.zeros_like(item["labels"])) ) - def reset(self) -> None: - self.metric.reset() - def compute(self) -> tuple[Tensor, dict[str, Tensor]]: """Torchmetric compute function.""" coco_target, coco_preds = COCO(), COCO() diff --git a/luxonis_train/attached_modules/metrics/object_keypoint_similarity.py b/luxonis_train/attached_modules/metrics/object_keypoint_similarity.py index fdef971f..c5e4a19b 100644 --- a/luxonis_train/attached_modules/metrics/object_keypoint_similarity.py +++ b/luxonis_train/attached_modules/metrics/object_keypoint_similarity.py @@ -133,11 +133,6 @@ def update( self.groundtruth_keypoints.append(keypoints) self.groundtruth_scales.append(item["scales"]) - def reset(self) -> None: - self.pred_keypoints = [] - self.groundtruth_keypoints = [] - self.groundtruth_scales = [] - def compute(self) -> Tensor: """Computes the OKS metric based on the inner state.""" From 95d7b4434cccd8c5d44f8677a65c0e34898123c6 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 14 May 2024 01:26:34 +0200 Subject: [PATCH 6/6] added annotations --- luxonis_train/attached_modules/metrics/common.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/luxonis_train/attached_modules/metrics/common.py b/luxonis_train/attached_modules/metrics/common.py index fe1891df..6d16a4b4 100644 --- a/luxonis_train/attached_modules/metrics/common.py +++ b/luxonis_train/attached_modules/metrics/common.py @@ -1,6 +1,7 @@ import logging import torchmetrics +from torch import Tensor from .base_metric import BaseMetric @@ -47,12 +48,12 @@ def __init__(self, **kwargs): self.metric = self.Metric(**kwargs) - def update(self, preds, target, *args, **kwargs): + def update(self, preds, target, *args, **kwargs) -> None: if self.task in ["multiclass"]: target = target.argmax(dim=1) self.metric.update(preds, target, *args, **kwargs) - def compute(self): + def compute(self) -> Tensor: return self.metric.compute() def reset(self) -> None: