From 81d12a3ecd868d0b44011072c9481cc396ec7039 Mon Sep 17 00:00:00 2001 From: rolson24 Date: Thu, 25 Jul 2024 19:22:57 +0000 Subject: [PATCH 1/3] add support for mmdet tracking --- supervision/detection/core.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/supervision/detection/core.py b/supervision/detection/core.py index 89a526d62..e3e35aa04 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -423,6 +423,16 @@ def from_mmdetection(cls, mmdet_results) -> Detections: detections = sv.Detections.from_mmdetection(result) ``` """ # noqa: E501 // docs + if hasattr(mmdet_results, "pred_track_instances") and mmdet_results.pred_track_instances is not None: + return cls( + xyxy=mmdet_results.pred_track_instances.bboxes.cpu().numpy(), + confidence=mmdet_results.pred_track_instances.scores.cpu().numpy(), + class_id=mmdet_results.pred_track_instances.labels.cpu().numpy(), + mask=mmdet_results.pred_track_instances.masks.cpu().numpy() + if "masks" in mmdet_results.pred_track_instances + else None, + tracker_id=mmdet_results.pred_track_instances.instances_id.cpu().numpy() + ) return cls( xyxy=mmdet_results.pred_instances.bboxes.cpu().numpy(), From 40b44991d03c96381966e7082d8c30bbede83cf9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 26 Jul 2024 00:15:35 +0000 Subject: [PATCH 2/3] =?UTF-8?q?fix(pre=5Fcommit):=20=F0=9F=8E=A8=20auto=20?= =?UTF-8?q?format=20pre-commit=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- supervision/detection/core.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/supervision/detection/core.py b/supervision/detection/core.py index e3e35aa04..6566db008 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -423,7 +423,10 @@ def from_mmdetection(cls, mmdet_results) -> Detections: detections = sv.Detections.from_mmdetection(result) ``` """ # noqa: E501 // docs - if hasattr(mmdet_results, "pred_track_instances") and mmdet_results.pred_track_instances is not None: + if ( + hasattr(mmdet_results, "pred_track_instances") + and mmdet_results.pred_track_instances is not None + ): return cls( xyxy=mmdet_results.pred_track_instances.bboxes.cpu().numpy(), confidence=mmdet_results.pred_track_instances.scores.cpu().numpy(), @@ -431,7 +434,7 @@ def from_mmdetection(cls, mmdet_results) -> Detections: mask=mmdet_results.pred_track_instances.masks.cpu().numpy() if "masks" in mmdet_results.pred_track_instances else None, - tracker_id=mmdet_results.pred_track_instances.instances_id.cpu().numpy() + tracker_id=mmdet_results.pred_track_instances.instances_id.cpu().numpy(), ) return cls( From 0cd5aa1445e3221baa7494c88a4e8532a71f8dc1 Mon Sep 17 00:00:00 2001 From: Raif Olson <99894460+rolson24@users.noreply.github.com> Date: Mon, 29 Jul 2024 10:26:45 -0400 Subject: [PATCH 3/3] handle case where tracker_id does not exist --- supervision/detection/core.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/supervision/detection/core.py b/supervision/detection/core.py index 6566db008..30e44d848 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -434,7 +434,9 @@ def from_mmdetection(cls, mmdet_results) -> Detections: mask=mmdet_results.pred_track_instances.masks.cpu().numpy() if "masks" in mmdet_results.pred_track_instances else None, - tracker_id=mmdet_results.pred_track_instances.instances_id.cpu().numpy(), + tracker_id=mmdet_results.pred_track_instances.instances_id.cpu().numpy() + if "instances_id" in mmdet_results.pred_track_instances + else None, ) return cls(