From fdfc38c7dbfd2fe7fcfb69cca34730555433870a Mon Sep 17 00:00:00 2001 From: hrnn Date: Wed, 4 Dec 2024 23:29:03 -0400 Subject: [PATCH 01/13] feat: countgd sam2 video --- vision_agent/tools/__init__.py | 1 + vision_agent/tools/tools.py | 64 ++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+) diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index 32eab0fc..02382a79 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -71,6 +71,7 @@ video_temporal_localization, vit_image_classification, vit_nsfw_classification, + countgd_sam2_video_tracking, ) __new_tools__ = [ diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 6df6ed31..bed158c6 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -2439,6 +2439,69 @@ def _plot_counting( return image +def countgd_sam2_video_tracking( + prompt: str, + frames: List[np.ndarray], + chunk_length: Optional[int] = 10, + fine_tune_id: Optional[str] = None, +) -> List[List[Dict[str, Any]]]: + """`countgd_sam2_video_tracking` it is only a test method + """ + + results = [None] * len(frames) + + for idx in range(0, len(frames), chunk_length): + results[idx] = countgd_counting(prompt=prompt, image=frames[idx]) + + image_size = frames[0].shape[:2] + + def _transform_detections(input_list): + output_list = [] + + for idx, frame in enumerate(input_list): + if frame is not None: + labels = [detection['label'] for detection in frame] + bboxes = [denormalize_bbox(detection['bbox'], image_size) for detection in frame] + + output_list.append({ + "labels": labels, + "bboxes": bboxes, + }) + else: + output_list.append(None) + + return output_list + + + output = _transform_detections(results) + + buffer_bytes = frames_to_bytes(frames) + files = [("video", buffer_bytes)] + payload = { + "bboxes": json.dumps(output), + "chunk_length": chunk_length + } + metadata = {"function_name": "countgd_sam2_video_tracking"} + + detections = send_task_inference_request( + payload, + "sam2", + files=files, + metadata=metadata, + ) + + return_data = [] + for frame in detections: + return_frame_data = [] + for detection in frame: + mask = rle_decode_array(detection["mask"]) + label = str(detection["id"]) + ": " + detection["label"] + return_frame_data.append({"label": label, "mask": mask, "score": 1.0}) + return_data.append(return_frame_data) + return_data = add_bboxes_from_masks(return_data) + return nms(return_data, iou_threshold=0.95) + + FUNCTION_TOOLS = [ owl_v2_image, owl_v2_video, @@ -2461,6 +2524,7 @@ def _plot_counting( video_temporal_localization, flux_image_inpainting, siglip_classification, + countgd_sam2_video_tracking ] UTIL_TOOLS = [ From 6bcd2178df0f21a7f3082a19572bf82207252471 Mon Sep 17 00:00:00 2001 From: hrnn Date: Thu, 5 Dec 2024 09:38:35 -0400 Subject: [PATCH 02/13] added test --- vision_agent/tools/tools.py | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index b06dad85..686c71c3 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -2629,15 +2629,13 @@ def countgd_sam2_video_tracking( prompt: str, frames: List[np.ndarray], chunk_length: Optional[int] = 10, - fine_tune_id: Optional[str] = None, ) -> List[List[Dict[str, Any]]]: - """`countgd_sam2_video_tracking` it is only a test method - """ + """`countgd_sam2_video_tracking` it is only a test method""" results = [None] * len(frames) for idx in range(0, len(frames), chunk_length): - results[idx] = countgd_counting(prompt=prompt, image=frames[idx]) + results[idx] = countgd_object_detection(prompt=prompt, image=frames[idx]) image_size = frames[0].shape[:2] @@ -2646,27 +2644,28 @@ def _transform_detections(input_list): for idx, frame in enumerate(input_list): if frame is not None: - labels = [detection['label'] for detection in frame] - bboxes = [denormalize_bbox(detection['bbox'], image_size) for detection in frame] - - output_list.append({ - "labels": labels, - "bboxes": bboxes, - }) + labels = [detection["label"] for detection in frame] + bboxes = [ + denormalize_bbox(detection["bbox"], image_size) + for detection in frame + ] + + output_list.append( + { + "labels": labels, + "bboxes": bboxes, + } + ) else: output_list.append(None) return output_list - output = _transform_detections(results) buffer_bytes = frames_to_bytes(frames) files = [("video", buffer_bytes)] - payload = { - "bboxes": json.dumps(output), - "chunk_length": chunk_length - } + payload = {"bboxes": json.dumps(output), "chunk_length": chunk_length} metadata = {"function_name": "countgd_sam2_video_tracking"} detections = send_task_inference_request( @@ -2710,7 +2709,7 @@ def _transform_detections(input_list): video_temporal_localization, flux_image_inpainting, siglip_classification, - countgd_sam2_video_tracking + countgd_sam2_video_tracking, ] UTIL_TOOLS = [ From 50faa2e9ea4e66281a485260036d581084708b8d Mon Sep 17 00:00:00 2001 From: hrnn Date: Thu, 5 Dec 2024 09:40:44 -0400 Subject: [PATCH 03/13] added test --- tests/integ/test_tools.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py index 97dc15b7..3696b9a5 100644 --- a/tests/integ/test_tools.py +++ b/tests/integ/test_tools.py @@ -6,9 +6,10 @@ blip_image_caption, clip, closest_mask_distance, + countgd_example_based_counting, countgd_object_detection, countgd_sam2_object_detection, - countgd_example_based_counting, + countgd_sam2_video_tracking, depth_anything_v2, detr_segmentation, dpt_hybrid_midas, @@ -624,3 +625,18 @@ def test_flux_image_inpainting_resizing_big_image(): assert result.shape[0] == 512 assert result.shape[1] == 208 + + +def test_video_tracking_with_countgd(): + + frames = [ + np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10) + ] + result = countgd_sam2_video_tracking( + prompt="coin", + frames=frames, + ) + + assert len(result) == 10 + assert len([res["label"] for res in result[0]]) == 24 + assert len([res["mask"] for res in result[0]]) == 24 From db40ef4b260733fd80e87033d4a7b7b54550be5a Mon Sep 17 00:00:00 2001 From: hrnn Date: Thu, 5 Dec 2024 11:20:40 -0400 Subject: [PATCH 04/13] handle empty chunk length --- vision_agent/tools/tools.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 686c71c3..692a3429 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -2634,7 +2634,14 @@ def countgd_sam2_video_tracking( results = [None] * len(frames) - for idx in range(0, len(frames), chunk_length): + if chunk_length is None: + step = 1 # Process every frame + elif chunk_length <= 0: + raise ValueError("chunk_length must be a positive integer or None.") + else: + step = chunk_length # Process frames with the specified step size + + for idx in range(0, len(frames), step): results[idx] = countgd_object_detection(prompt=prompt, image=frames[idx]) image_size = frames[0].shape[:2] From e0a7cb281c2f61d10444f5e9a9c74b840ec6b315 Mon Sep 17 00:00:00 2001 From: hrnn Date: Thu, 5 Dec 2024 11:48:41 -0400 Subject: [PATCH 05/13] fixed mypy issues --- vision_agent/tools/tools.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 692a3429..97b7f733 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -2632,7 +2632,7 @@ def countgd_sam2_video_tracking( ) -> List[List[Dict[str, Any]]]: """`countgd_sam2_video_tracking` it is only a test method""" - results = [None] * len(frames) + results: List[Optional[List[Dict[str, Any]]]] = [None] * len(frames) if chunk_length is None: step = 1 # Process every frame @@ -2646,8 +2646,10 @@ def countgd_sam2_video_tracking( image_size = frames[0].shape[:2] - def _transform_detections(input_list): - output_list = [] + def _transform_detections( + input_list: List[Optional[List[Dict[str, Any]]]] + ) -> List[Optional[Dict[str, Any]]]: + output_list: List[Optional[Dict[str, Any]]] = [] for idx, frame in enumerate(input_list): if frame is not None: From 984eaa379264302b2a37500214f63b9dc1b14c86 Mon Sep 17 00:00:00 2001 From: hrnn Date: Thu, 5 Dec 2024 17:04:56 -0400 Subject: [PATCH 06/13] updated docstring --- vision_agent/tools/tools.py | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 97b7f733..49e970f9 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -2630,8 +2630,39 @@ def countgd_sam2_video_tracking( frames: List[np.ndarray], chunk_length: Optional[int] = 10, ) -> List[List[Dict[str, Any]]]: - """`countgd_sam2_video_tracking` it is only a test method""" + """'countgd_sam2_video_tracking' is a tool that can segment multiple objects given a text + prompt such as category names or referring expressions. The categories in the text + prompt are separated by commas. It returns a list of bounding boxes, label names, + mask file names and associated probability scores of 1.0. + + Parameters: + prompt (str): The prompt to ground to the image. + image (np.ndarray): The image to ground the prompt to. + + Returns: + List[Dict[str, Any]]: A list of dictionaries containing the score, label, + bounding box, and mask of the detected objects with normalized coordinates + (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left + and xmax and ymax are the coordinates of the bottom-right of the bounding box. + The mask is binary 2D numpy array where 1 indicates the object and 0 indicates + the background. + Example + ------- + >>> countgd_sam2_video_tracking("car, dinosaur", image) + [ + { + 'score': 1.0, + 'label': 'dinosaur', + 'bbox': [0.1, 0.11, 0.35, 0.4], + 'mask': array([[0, 0, 0, ..., 0, 0, 0], + [0, 0, 0, ..., 0, 0, 0], + ..., + [0, 0, 0, ..., 0, 0, 0], + [0, 0, 0, ..., 0, 0, 0]], dtype=uint8), + }, + ] + """ results: List[Optional[List[Dict[str, Any]]]] = [None] * len(frames) if chunk_length is None: From 87a787c2fd465eadeebf39f367eb6e134791f7bf Mon Sep 17 00:00:00 2001 From: hrnn Date: Thu, 5 Dec 2024 18:24:53 -0400 Subject: [PATCH 07/13] allow multiple od tools --- vision_agent/tools/tools.py | 92 ++++++++++++++++++++++++------------- 1 file changed, 59 insertions(+), 33 deletions(-) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 49e970f9..96220b90 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -5,6 +5,7 @@ import tempfile import urllib.request from concurrent.futures import ThreadPoolExecutor, as_completed +from enum import Enum from functools import lru_cache from importlib import resources from pathlib import Path @@ -2625,44 +2626,25 @@ def _plot_counting( return image -def countgd_sam2_video_tracking( +class ODModels(str, Enum): + COUNTGD = "countgd" + + +def od_sam2_video_tracking( + od_model: ODModels, prompt: str, frames: List[np.ndarray], chunk_length: Optional[int] = 10, + fine_tune_id: Optional[str] = None, ) -> List[List[Dict[str, Any]]]: - """'countgd_sam2_video_tracking' is a tool that can segment multiple objects given a text - prompt such as category names or referring expressions. The categories in the text - prompt are separated by commas. It returns a list of bounding boxes, label names, - mask file names and associated probability scores of 1.0. - - Parameters: - prompt (str): The prompt to ground to the image. - image (np.ndarray): The image to ground the prompt to. - Returns: - List[Dict[str, Any]]: A list of dictionaries containing the score, label, - bounding box, and mask of the detected objects with normalized coordinates - (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left - and xmax and ymax are the coordinates of the bottom-right of the bounding box. - The mask is binary 2D numpy array where 1 indicates the object and 0 indicates - the background. + if od_model == ODModels.COUNTGD: + detection_function = countgd_object_detection + else: + raise NotImplementedError( + f"Object detection model '{od_model.value}' is not implemented." + ) - Example - ------- - >>> countgd_sam2_video_tracking("car, dinosaur", image) - [ - { - 'score': 1.0, - 'label': 'dinosaur', - 'bbox': [0.1, 0.11, 0.35, 0.4], - 'mask': array([[0, 0, 0, ..., 0, 0, 0], - [0, 0, 0, ..., 0, 0, 0], - ..., - [0, 0, 0, ..., 0, 0, 0], - [0, 0, 0, ..., 0, 0, 0]], dtype=uint8), - }, - ] - """ results: List[Optional[List[Dict[str, Any]]]] = [None] * len(frames) if chunk_length is None: @@ -2673,7 +2655,7 @@ def countgd_sam2_video_tracking( step = chunk_length # Process frames with the specified step size for idx in range(0, len(frames), step): - results[idx] = countgd_object_detection(prompt=prompt, image=frames[idx]) + results[idx] = detection_function(prompt=prompt, image=frames[idx]) image_size = frames[0].shape[:2] @@ -2727,6 +2709,50 @@ def _transform_detections( return nms(return_data, iou_threshold=0.95) +def countgd_sam2_video_tracking( + prompt: str, + frames: List[np.ndarray], + chunk_length: Optional[int] = 10, +) -> List[List[Dict[str, Any]]]: + """'countgd_sam2_video_tracking' is a tool that can segment multiple objects given a text + prompt such as category names or referring expressions. The categories in the text + prompt are separated by commas. It returns a list of bounding boxes, label names, + mask file names and associated probability scores of 1.0. + + Parameters: + prompt (str): The prompt to ground to the image. + image (np.ndarray): The image to ground the prompt to. + + Returns: + List[Dict[str, Any]]: A list of dictionaries containing the score, label, + bounding box, and mask of the detected objects with normalized coordinates + (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left + and xmax and ymax are the coordinates of the bottom-right of the bounding box. + The mask is binary 2D numpy array where 1 indicates the object and 0 indicates + the background. + + Example + ------- + >>> countgd_sam2_video_tracking("car, dinosaur", image) + [ + { + 'score': 1.0, + 'label': 'dinosaur', + 'bbox': [0.1, 0.11, 0.35, 0.4], + 'mask': array([[0, 0, 0, ..., 0, 0, 0], + [0, 0, 0, ..., 0, 0, 0], + ..., + [0, 0, 0, ..., 0, 0, 0], + [0, 0, 0, ..., 0, 0, 0]], dtype=uint8), + }, + ] + """ + + return od_sam2_video_tracking( + ODModels.COUNTGD, prompt=prompt, frames=frames, chunk_length=chunk_length + ) + + FUNCTION_TOOLS = [ owl_v2_image, owl_v2_video, From 07b4e19b1931aea7c1e079a56a613c5c384e4fb1 Mon Sep 17 00:00:00 2001 From: hrnn Date: Thu, 5 Dec 2024 23:36:37 -0400 Subject: [PATCH 08/13] added owlv2 --- tests/integ/test_tools.py | 20 ++++++++++++- vision_agent/tools/tools.py | 57 +++++++++++++++++++++++++++++++++++-- 2 files changed, 74 insertions(+), 3 deletions(-) diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py index 3696b9a5..3246af9d 100644 --- a/tests/integ/test_tools.py +++ b/tests/integ/test_tools.py @@ -9,7 +9,6 @@ countgd_example_based_counting, countgd_object_detection, countgd_sam2_object_detection, - countgd_sam2_video_tracking, depth_anything_v2, detr_segmentation, dpt_hybrid_midas, @@ -41,6 +40,10 @@ vit_image_classification, vit_nsfw_classification, ) +from vision_agent.tools.tools import ( + countgd_sam2_video_tracking, + owlv2_sam2_video_tracking, +) FINE_TUNE_ID = "65ebba4a-88b7-419f-9046-0750e30250da" @@ -640,3 +643,18 @@ def test_video_tracking_with_countgd(): assert len(result) == 10 assert len([res["label"] for res in result[0]]) == 24 assert len([res["mask"] for res in result[0]]) == 24 + + +def test_video_tracking_with_owlv2(): + + frames = [ + np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10) + ] + result = owlv2_sam2_video_tracking( + prompt="coin", + frames=frames, + ) + + assert len(result) == 10 + assert len([res["label"] for res in result[0]]) == 24 + assert len([res["mask"] for res in result[0]]) == 24 diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 96220b90..d02589dd 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -2628,6 +2628,7 @@ def _plot_counting( class ODModels(str, Enum): COUNTGD = "countgd" + OWLV2 = "owlv2" def od_sam2_video_tracking( @@ -2640,6 +2641,10 @@ def od_sam2_video_tracking( if od_model == ODModels.COUNTGD: detection_function = countgd_object_detection + function_name = "countgd_object_detection" + elif od_model == ODModels.OWLV2: + detection_function = owl_v2_image + function_name = "owl_v2_image" else: raise NotImplementedError( f"Object detection model '{od_model.value}' is not implemented." @@ -2688,7 +2693,7 @@ def _transform_detections( buffer_bytes = frames_to_bytes(frames) files = [("video", buffer_bytes)] payload = {"bboxes": json.dumps(output), "chunk_length": chunk_length} - metadata = {"function_name": "countgd_sam2_video_tracking"} + metadata = {"function_name": function_name} detections = send_task_inference_request( payload, @@ -2753,6 +2758,55 @@ def countgd_sam2_video_tracking( ) +def owlv2_sam2_video_tracking( + prompt: str, + frames: List[np.ndarray], + chunk_length: Optional[int] = 10, + fine_tune_id: Optional[str] = None, +) -> List[List[Dict[str, Any]]]: + """'owlv2_sam2_video_tracking' is a tool that can segment multiple objects given a text + prompt such as category names or referring expressions. The categories in the text + prompt are separated by commas. It returns a list of bounding boxes, label names, + mask file names and associated probability scores of 1.0. + + Parameters: + prompt (str): The prompt to ground to the image. + image (np.ndarray): The image to ground the prompt to. + + Returns: + List[Dict[str, Any]]: A list of dictionaries containing the score, label, + bounding box, and mask of the detected objects with normalized coordinates + (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left + and xmax and ymax are the coordinates of the bottom-right of the bounding box. + The mask is binary 2D numpy array where 1 indicates the object and 0 indicates + the background. + + Example + ------- + >>> countgd_sam2_video_tracking("car, dinosaur", image) + [ + { + 'score': 1.0, + 'label': 'dinosaur', + 'bbox': [0.1, 0.11, 0.35, 0.4], + 'mask': array([[0, 0, 0, ..., 0, 0, 0], + [0, 0, 0, ..., 0, 0, 0], + ..., + [0, 0, 0, ..., 0, 0, 0], + [0, 0, 0, ..., 0, 0, 0]], dtype=uint8), + }, + ] + """ + + return od_sam2_video_tracking( + ODModels.OWLV2, + prompt=prompt, + frames=frames, + chunk_length=chunk_length, + fine_tune_id=fine_tune_id, + ) + + FUNCTION_TOOLS = [ owl_v2_image, owl_v2_video, @@ -2775,7 +2829,6 @@ def countgd_sam2_video_tracking( video_temporal_localization, flux_image_inpainting, siglip_classification, - countgd_sam2_video_tracking, ] UTIL_TOOLS = [ From f19343813f0627ff5209e4ad85f2d47e201cddd4 Mon Sep 17 00:00:00 2001 From: hrnn Date: Fri, 6 Dec 2024 09:22:31 -0400 Subject: [PATCH 09/13] fixed import --- tests/integ/test_tools.py | 6 ++---- vision_agent/tools/__init__.py | 3 ++- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py index 3246af9d..9d7e48bd 100644 --- a/tests/integ/test_tools.py +++ b/tests/integ/test_tools.py @@ -9,6 +9,7 @@ countgd_example_based_counting, countgd_object_detection, countgd_sam2_object_detection, + countgd_sam2_video_tracking, depth_anything_v2, detr_segmentation, dpt_hybrid_midas, @@ -32,6 +33,7 @@ ocr, owl_v2_image, owl_v2_video, + owlv2_sam2_video_tracking, qwen2_vl_images_vqa, qwen2_vl_video_vqa, siglip_classification, @@ -40,10 +42,6 @@ vit_image_classification, vit_nsfw_classification, ) -from vision_agent.tools.tools import ( - countgd_sam2_video_tracking, - owlv2_sam2_video_tracking, -) FINE_TUNE_ID = "65ebba4a-88b7-419f-9046-0750e30250da" diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index 2413b1ce..48a36c5d 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -31,6 +31,7 @@ countgd_example_based_counting, countgd_object_detection, countgd_sam2_object_detection, + countgd_sam2_video_tracking, depth_anything_v2, detr_segmentation, dpt_hybrid_midas, @@ -64,6 +65,7 @@ overlay_segmentation_masks, owl_v2_image, owl_v2_video, + owlv2_sam2_video_tracking, qwen2_vl_images_vqa, qwen2_vl_video_vqa, sam2, @@ -75,7 +77,6 @@ video_temporal_localization, vit_image_classification, vit_nsfw_classification, - countgd_sam2_video_tracking, ) __new_tools__ = [ From 12062e000245296b952e4c76cd876480acb39639 Mon Sep 17 00:00:00 2001 From: hrnn Date: Mon, 9 Dec 2024 12:18:59 -0400 Subject: [PATCH 10/13] fixed probability --- vision_agent/tools/tools.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index d02589dd..36fafdb5 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -2722,7 +2722,7 @@ def countgd_sam2_video_tracking( """'countgd_sam2_video_tracking' is a tool that can segment multiple objects given a text prompt such as category names or referring expressions. The categories in the text prompt are separated by commas. It returns a list of bounding boxes, label names, - mask file names and associated probability scores of 1.0. + mask file names and associated probability scores. Parameters: prompt (str): The prompt to ground to the image. @@ -2767,7 +2767,7 @@ def owlv2_sam2_video_tracking( """'owlv2_sam2_video_tracking' is a tool that can segment multiple objects given a text prompt such as category names or referring expressions. The categories in the text prompt are separated by commas. It returns a list of bounding boxes, label names, - mask file names and associated probability scores of 1.0. + mask file names and associated probability scores. Parameters: prompt (str): The prompt to ground to the image. From 1219fefd69e7b79766a511e8b64957cfae3ccd68 Mon Sep 17 00:00:00 2001 From: hrnn Date: Mon, 9 Dec 2024 16:27:25 -0400 Subject: [PATCH 11/13] fixed return example --- vision_agent/tools/tools.py | 48 ++++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 36fafdb5..ef03c8a3 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -2738,18 +2738,20 @@ def countgd_sam2_video_tracking( Example ------- - >>> countgd_sam2_video_tracking("car, dinosaur", image) + >>> countgd_sam2_video_tracking("car, dinosaur", frames) [ - { - 'score': 1.0, - 'label': 'dinosaur', - 'bbox': [0.1, 0.11, 0.35, 0.4], - 'mask': array([[0, 0, 0, ..., 0, 0, 0], - [0, 0, 0, ..., 0, 0, 0], - ..., - [0, 0, 0, ..., 0, 0, 0], - [0, 0, 0, ..., 0, 0, 0]], dtype=uint8), - }, + [ + { + 'label': '0: dinosaur', + 'bbox': [0.1, 0.11, 0.35, 0.4], + 'mask': array([[0, 0, 0, ..., 0, 0, 0], + [0, 0, 0, ..., 0, 0, 0], + ..., + [0, 0, 0, ..., 0, 0, 0], + [0, 0, 0, ..., 0, 0, 0]], dtype=uint8), + }, + ], + ... ] """ @@ -2783,18 +2785,20 @@ def owlv2_sam2_video_tracking( Example ------- - >>> countgd_sam2_video_tracking("car, dinosaur", image) + >>> countgd_sam2_video_tracking("car, dinosaur", frames) [ - { - 'score': 1.0, - 'label': 'dinosaur', - 'bbox': [0.1, 0.11, 0.35, 0.4], - 'mask': array([[0, 0, 0, ..., 0, 0, 0], - [0, 0, 0, ..., 0, 0, 0], - ..., - [0, 0, 0, ..., 0, 0, 0], - [0, 0, 0, ..., 0, 0, 0]], dtype=uint8), - }, + [ + { + 'label': '0: dinosaur', + 'bbox': [0.1, 0.11, 0.35, 0.4], + 'mask': array([[0, 0, 0, ..., 0, 0, 0], + [0, 0, 0, ..., 0, 0, 0], + ..., + [0, 0, 0, ..., 0, 0, 0], + [0, 0, 0, ..., 0, 0, 0]], dtype=uint8), + }, + ], + ... ] """ From 36607d2bd57f24a35a7b0f9e8d6c89e854e5cd47 Mon Sep 17 00:00:00 2001 From: hrnn Date: Mon, 9 Dec 2024 18:14:10 -0400 Subject: [PATCH 12/13] added florence2 support --- tests/integ/test_tools.py | 17 +++++++++++++++++ vision_agent/tools/__init__.py | 1 + vision_agent/tools/tools.py | 13 ++++++++++++- 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py index 9d7e48bd..61b18453 100644 --- a/tests/integ/test_tools.py +++ b/tests/integ/test_tools.py @@ -31,6 +31,7 @@ loca_visual_prompt_counting, loca_zero_shot_counting, ocr, + od_sam2_video_tracking, owl_v2_image, owl_v2_video, owlv2_sam2_video_tracking, @@ -656,3 +657,19 @@ def test_video_tracking_with_owlv2(): assert len(result) == 10 assert len([res["label"] for res in result[0]]) == 24 assert len([res["mask"] for res in result[0]]) == 24 + + +def test_video_tracking_by_given_model(): + + frames = [ + np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10) + ] + result = od_sam2_video_tracking( + od_model="florence2", + prompt="coin", + frames=frames, + ) + + assert len(result) == 10 + assert len([res["label"] for res in result[0]]) == 24 + assert len([res["mask"] for res in result[0]]) == 24 diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index 48a36c5d..99145416 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -60,6 +60,7 @@ loca_zero_shot_counting, minimum_distance, ocr, + od_sam2_video_tracking, overlay_bounding_boxes, overlay_heat_map, overlay_segmentation_masks, diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index ef03c8a3..bf7efb64 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -2628,6 +2628,7 @@ def _plot_counting( class ODModels(str, Enum): COUNTGD = "countgd" + FLORENCE2 = "florence2" OWLV2 = "owlv2" @@ -2639,12 +2640,21 @@ def od_sam2_video_tracking( fine_tune_id: Optional[str] = None, ) -> List[List[Dict[str, Any]]]: + params = { + "prompt": prompt, + } + if od_model == ODModels.COUNTGD: detection_function = countgd_object_detection function_name = "countgd_object_detection" elif od_model == ODModels.OWLV2: detection_function = owl_v2_image function_name = "owl_v2_image" + params["fine_tune_id"] = fine_tune_id + elif od_model == ODModels.FLORENCE2: + detection_function = florence2_sam2_image + function_name = "florence2_sam2_image" + params["fine_tune_id"] = fine_tune_id else: raise NotImplementedError( f"Object detection model '{od_model.value}' is not implemented." @@ -2660,7 +2670,8 @@ def od_sam2_video_tracking( step = chunk_length # Process frames with the specified step size for idx in range(0, len(frames), step): - results[idx] = detection_function(prompt=prompt, image=frames[idx]) + params["image"] = frames[idx] + results[idx] = detection_function(**params) image_size = frames[0].shape[:2] From c1320b8853aad4ae59628116f1b09779dbbf40b1 Mon Sep 17 00:00:00 2001 From: hrnn Date: Mon, 9 Dec 2024 18:47:16 -0400 Subject: [PATCH 13/13] fixed mypy --- vision_agent/tools/tools.py | 39 ++++++++++++++++--------------------- 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index bf7efb64..7e2f9fc3 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -2640,26 +2640,6 @@ def od_sam2_video_tracking( fine_tune_id: Optional[str] = None, ) -> List[List[Dict[str, Any]]]: - params = { - "prompt": prompt, - } - - if od_model == ODModels.COUNTGD: - detection_function = countgd_object_detection - function_name = "countgd_object_detection" - elif od_model == ODModels.OWLV2: - detection_function = owl_v2_image - function_name = "owl_v2_image" - params["fine_tune_id"] = fine_tune_id - elif od_model == ODModels.FLORENCE2: - detection_function = florence2_sam2_image - function_name = "florence2_sam2_image" - params["fine_tune_id"] = fine_tune_id - else: - raise NotImplementedError( - f"Object detection model '{od_model.value}' is not implemented." - ) - results: List[Optional[List[Dict[str, Any]]]] = [None] * len(frames) if chunk_length is None: @@ -2670,8 +2650,23 @@ def od_sam2_video_tracking( step = chunk_length # Process frames with the specified step size for idx in range(0, len(frames), step): - params["image"] = frames[idx] - results[idx] = detection_function(**params) + if od_model == ODModels.COUNTGD: + results[idx] = countgd_object_detection(prompt=prompt, image=frames[idx]) + function_name = "countgd_object_detection" + elif od_model == ODModels.OWLV2: + results[idx] = owl_v2_image( + prompt=prompt, image=frames[idx], fine_tune_id=fine_tune_id + ) + function_name = "owl_v2_image" + elif od_model == ODModels.FLORENCE2: + results[idx] = florence2_sam2_image( + prompt=prompt, image=frames[idx], fine_tune_id=fine_tune_id + ) + function_name = "florence2_sam2_image" + else: + raise NotImplementedError( + f"Object detection model '{od_model}' is not implemented." + ) image_size = frames[0].shape[:2]