Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: countgd sam2 video support #318

Merged
merged 15 commits into from
Dec 13, 2024
34 changes: 33 additions & 1 deletion tests/integ/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
blip_image_caption,
clip,
closest_mask_distance,
countgd_example_based_counting,
countgd_object_detection,
countgd_sam2_object_detection,
countgd_example_based_counting,
countgd_sam2_video_tracking,
depth_anything_v2,
detr_segmentation,
dpt_hybrid_midas,
Expand All @@ -32,6 +33,7 @@
ocr,
owl_v2_image,
owl_v2_video,
owlv2_sam2_video_tracking,
qwen2_vl_images_vqa,
qwen2_vl_video_vqa,
siglip_classification,
Expand Down Expand Up @@ -624,3 +626,33 @@ def test_flux_image_inpainting_resizing_big_image():

assert result.shape[0] == 512
assert result.shape[1] == 208


def test_video_tracking_with_countgd():
    """CountGD + SAM2 video tracking finds all 24 coins in each of 10 identical
    frames built from the skimage `coins` sample image."""
    frames = [
        np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10)
    ]
    result = countgd_sam2_video_tracking(
        prompt="coin",
        frames=frames,
    )

    # One detection list per input frame.
    assert len(result) == 10
    # The sample image contains 24 coins; every detection must carry both keys.
    assert len(result[0]) == 24
    assert all("label" in det and "mask" in det for det in result[0])


def test_video_tracking_with_owlv2():
    """OWLv2 + SAM2 video tracking finds all 24 coins in each of 10 identical
    frames built from the skimage `coins` sample image."""
    frames = [
        np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10)
    ]
    result = owlv2_sam2_video_tracking(
        prompt="coin",
        frames=frames,
    )

    # One detection list per input frame.
    assert len(result) == 10
    # The sample image contains 24 coins; every detection must carry both keys.
    assert len(result[0]) == 24
    assert all("label" in det and "mask" in det for det in result[0])
2 changes: 2 additions & 0 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
countgd_example_based_counting,
countgd_object_detection,
countgd_sam2_object_detection,
countgd_sam2_video_tracking,
depth_anything_v2,
detr_segmentation,
dpt_hybrid_midas,
Expand Down Expand Up @@ -64,6 +65,7 @@
overlay_segmentation_masks,
owl_v2_image,
owl_v2_video,
owlv2_sam2_video_tracking,
qwen2_vl_images_vqa,
qwen2_vl_video_vqa,
sam2,
Expand Down
182 changes: 182 additions & 0 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import tempfile
import urllib.request
from concurrent.futures import ThreadPoolExecutor, as_completed
from enum import Enum
from functools import lru_cache
from importlib import resources
from pathlib import Path
Expand Down Expand Up @@ -2625,6 +2626,187 @@ def _plot_counting(
return image


class ODModels(str, Enum):
    """Identifiers for the object-detection models that can seed SAM2 video
    tracking; the string value is the wire name sent to the inference API."""

    COUNTGD = "countgd"
    OWLV2 = "owlv2"


def od_sam2_video_tracking(
    od_model: ODModels,
    prompt: str,
    frames: List[np.ndarray],
    chunk_length: Optional[int] = 10,
    fine_tune_id: Optional[str] = None,
) -> List[List[Dict[str, Any]]]:
    """Run an object detector on periodically sampled frames and propagate the
    detections through the full video with SAM2.

    Parameters:
        od_model (ODModels): Which detector produces the seed boxes.
        prompt (str): Comma-separated category names to detect.
        frames (List[np.ndarray]): The video frames, in temporal order.
        chunk_length (Optional[int]): Re-run the detector every ``chunk_length``
            frames; ``None`` runs it on every frame.
        fine_tune_id (Optional[str]): ID of a fine-tuned model; only supported
            for the OWLv2 detector.

    Returns:
        List[List[Dict[str, Any]]]: One inner list per frame; each detection
        dict contains ``label`` (prefixed with a track id), ``mask`` (binary 2D
        array), ``bbox`` and ``score``.

    Raises:
        NotImplementedError: If ``od_model`` is not a supported detector.
        ValueError: If ``chunk_length`` is zero or negative.
    """
    if od_model == ODModels.COUNTGD:
        detection_function = countgd_object_detection
        function_name = "countgd_object_detection"
    elif od_model == ODModels.OWLV2:
        # Bind fine_tune_id into the call; previously the parameter was
        # accepted by this function but silently dropped.
        def detection_function(
            prompt: str, image: np.ndarray
        ) -> List[Dict[str, Any]]:
            return owl_v2_image(prompt=prompt, image=image, fine_tune_id=fine_tune_id)

        function_name = "owl_v2_image"
    else:
        raise NotImplementedError(
            f"Object detection model '{od_model.value}' is not implemented."
        )

    # Detections are only computed on sampled frames; the rest stay None and
    # are filled in by SAM2 propagation on the server side.
    results: List[Optional[List[Dict[str, Any]]]] = [None] * len(frames)

    if chunk_length is None:
        step = 1  # Process every frame
    elif chunk_length <= 0:
        raise ValueError("chunk_length must be a positive integer or None.")
    else:
        step = chunk_length  # Process frames with the specified step size

    for idx in range(0, len(frames), step):
        results[idx] = detection_function(prompt=prompt, image=frames[idx])

    image_size = frames[0].shape[:2]

    def _transform_detections(
        input_list: List[Optional[List[Dict[str, Any]]]]
    ) -> List[Optional[Dict[str, Any]]]:
        """Convert per-frame detection lists into the {labels, bboxes} payload
        shape expected by the SAM2 endpoint, denormalizing box coordinates."""
        output_list: List[Optional[Dict[str, Any]]] = []

        for idx, frame in enumerate(input_list):
            if frame is not None:
                labels = [detection["label"] for detection in frame]
                bboxes = [
                    denormalize_bbox(detection["bbox"], image_size)
                    for detection in frame
                ]

                output_list.append(
                    {
                        "labels": labels,
                        "bboxes": bboxes,
                    }
                )
            else:
                output_list.append(None)

        return output_list

    output = _transform_detections(results)

    buffer_bytes = frames_to_bytes(frames)
    files = [("video", buffer_bytes)]
    payload = {"bboxes": json.dumps(output), "chunk_length": chunk_length}
    metadata = {"function_name": function_name}

    detections = send_task_inference_request(
        payload,
        "sam2",
        files=files,
        metadata=metadata,
    )

    return_data = []
    for frame in detections:
        return_frame_data = []
        for detection in frame:
            mask = rle_decode_array(detection["mask"])
            # Prefix the label with the per-object track id so identical
            # classes remain distinguishable across frames.
            label = str(detection["id"]) + ": " + detection["label"]
            return_frame_data.append({"label": label, "mask": mask, "score": 1.0})
        return_data.append(return_frame_data)
    return_data = add_bboxes_from_masks(return_data)
    # Suppress near-duplicate tracks produced by re-detection at chunk borders.
    return nms(return_data, iou_threshold=0.95)


def countgd_sam2_video_tracking(
    prompt: str,
    frames: List[np.ndarray],
    chunk_length: Optional[int] = 10,
) -> List[List[Dict[str, Any]]]:
    """'countgd_sam2_video_tracking' is a tool that can track and segment
    multiple objects in a video given a text prompt such as category names or
    referring expressions. The categories in the text prompt are separated by
    commas. It returns a list of bounding boxes, label names, masks and
    associated probability scores for each frame.

    Parameters:
        prompt (str): The prompt to ground to the video.
        frames (List[np.ndarray]): The list of frames to ground the prompt to.
        chunk_length (Optional[int]): The number of frames between each run of
            the object detector; None runs detection on every frame.

    Returns:
        List[List[Dict[str, Any]]]: A list of lists of dictionaries, one inner
        list per frame, where each dictionary contains the score, label,
        bounding box, and mask of a detected object with normalized coordinates
        (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
        top-left and xmax and ymax are the coordinates of the bottom-right of
        the bounding box. The mask is binary 2D numpy array where 1 indicates
        the object and 0 indicates the background. The label is prefixed with
        the object's track id.

    Example
    -------
        >>> countgd_sam2_video_tracking("car, dinosaur", frames)
        [
            [
                {
                    'label': '0: dinosaur',
                    'bbox': [0.1, 0.11, 0.35, 0.4],
                    'mask': array([[0, 0, 0, ..., 0, 0, 0],
                        [0, 0, 0, ..., 0, 0, 0],
                        ...,
                        [0, 0, 0, ..., 0, 0, 0],
                        [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
                },
            ],
            ...
        ]
    """

    return od_sam2_video_tracking(
        ODModels.COUNTGD, prompt=prompt, frames=frames, chunk_length=chunk_length
    )


def owlv2_sam2_video_tracking(
    prompt: str,
    frames: List[np.ndarray],
    chunk_length: Optional[int] = 10,
    fine_tune_id: Optional[str] = None,
) -> List[List[Dict[str, Any]]]:
    """'owlv2_sam2_video_tracking' is a tool that can track and segment multiple
    objects in a video given a text prompt such as category names or referring
    expressions. The categories in the text prompt are separated by commas. It
    returns a list of bounding boxes, label names, masks and associated
    probability scores for each frame.

    Parameters:
        prompt (str): The prompt to ground to the video.
        frames (List[np.ndarray]): The list of frames to ground the prompt to.
        chunk_length (Optional[int]): The number of frames between each run of
            the object detector; None runs detection on every frame.
        fine_tune_id (Optional[str]): The ID of a fine-tuned OWLv2 model to use
            for detection instead of the base model.

    Returns:
        List[List[Dict[str, Any]]]: A list of lists of dictionaries, one inner
        list per frame, where each dictionary contains the score, label,
        bounding box, and mask of a detected object with normalized coordinates
        (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
        top-left and xmax and ymax are the coordinates of the bottom-right of
        the bounding box. The mask is binary 2D numpy array where 1 indicates
        the object and 0 indicates the background. The label is prefixed with
        the object's track id.

    Example
    -------
        >>> owlv2_sam2_video_tracking("car, dinosaur", frames)
        [
            [
                {
                    'label': '0: dinosaur',
                    'bbox': [0.1, 0.11, 0.35, 0.4],
                    'mask': array([[0, 0, 0, ..., 0, 0, 0],
                        [0, 0, 0, ..., 0, 0, 0],
                        ...,
                        [0, 0, 0, ..., 0, 0, 0],
                        [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
                },
            ],
            ...
        ]
    """

    return od_sam2_video_tracking(
        ODModels.OWLV2,
        prompt=prompt,
        frames=frames,
        chunk_length=chunk_length,
        fine_tune_id=fine_tune_id,
    )


FUNCTION_TOOLS = [
owl_v2_image,
owl_v2_video,
Expand Down