From 18a580cd1bf786166e07cf43e5ed97137f56d51f Mon Sep 17 00:00:00 2001 From: PierreMarieCurie <154653205+PierreMarieCurie@users.noreply.github.com> Date: Thu, 21 Mar 2024 15:36:40 +0100 Subject: [PATCH 1/2] Update inference.py Add default values in predict_with_classes method to be coherent with predict to predict_with_caption method. --- groundingdino/util/inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/groundingdino/util/inference.py b/groundingdino/util/inference.py index 58528ed..8c5f4b0 100644 --- a/groundingdino/util/inference.py +++ b/groundingdino/util/inference.py @@ -192,8 +192,8 @@ def predict_with_classes( self, image: np.ndarray, classes: List[str], - box_threshold: float, - text_threshold: float + box_threshold: float = 0.35, + text_threshold: float = 0.25 ) -> sv.Detections: """ import cv2 From b8b98e3ff86c7e1d80f9bea134ce0fc6f9b2e22e Mon Sep 17 00:00:00 2001 From: PierreMarieCurie <154653205+PierreMarieCurie@users.noreply.github.com> Date: Thu, 21 Mar 2024 16:10:12 +0100 Subject: [PATCH 2/2] Update inference.py Add access to remove_combined parameter through new API --- groundingdino/util/inference.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/groundingdino/util/inference.py b/groundingdino/util/inference.py index 8c5f4b0..857f048 100644 --- a/groundingdino/util/inference.py +++ b/groundingdino/util/inference.py @@ -152,7 +152,8 @@ def predict_with_caption( image: np.ndarray, caption: str, box_threshold: float = 0.35, - text_threshold: float = 0.25 + text_threshold: float = 0.25, + remove_combined: bool = False ) -> Tuple[sv.Detections, List[str]]: """ import cv2 @@ -179,7 +180,8 @@ def predict_with_caption( caption=caption, box_threshold=box_threshold, text_threshold=text_threshold, - device=self.device) + device=self.device, + remove_combined=remove_combined) source_h, source_w, _ = image.shape detections = Model.post_process_result( source_h=source_h, @@ -193,7 +195,8 @@ def predict_with_classes( image: np.ndarray, classes: List[str], box_threshold: float = 0.35, - text_threshold: float = 0.25 + text_threshold: float = 0.25, + remove_combined: bool = False ) -> sv.Detections: """ import cv2 @@ -222,7 +225,8 @@ def predict_with_classes( caption=caption, box_threshold=box_threshold, text_threshold=text_threshold, - device=self.device) + device=self.device, + remove_combined=remove_combined) source_h, source_w, _ = image.shape detections = Model.post_process_result( source_h=source_h,