Batch annotation #35

Merged · 13 commits · Feb 23, 2024
README.md (1 addition, 0 deletions)
@@ -116,6 +116,7 @@ datadreamer --save_dir <directory> --class_names <objects> --prompts_number <num
- `--image_tester_patience`: Patience level for image tester. Default is 1.
- `--lm_quantization`: Quantization to use for Mistral language model. Choose between `none` and `4bit`. Default is `none`.
- `--batch_size_prompt`: Batch size for prompt generation. Default is 64.
- `--batch_size_annotation`: Batch size for annotation. Default is 8.
- `--batch_size_image`: Batch size for image generation. Default is 1.
- `--device`: Choose between `cuda` and `cpu`. Default is `cuda`.
- `--seed`: Set a random seed for image and prompt generation. Default is 42.
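
As an illustrative example (values are placeholders, mirroring the command shown above), the new annotation batch size is passed like any other flag:

```bash
datadreamer --save_dir my_dataset \
    --class_names person car dog \
    --prompts_number 100 \
    --batch_size_annotation 16 \
    --device cuda
```

A larger `--batch_size_annotation` trades GPU memory for annotation throughput; the default of 8 is a conservative starting point.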
datadreamer/dataset_annotation/image_annotator.py (1 addition, 1 deletion)
@@ -31,5 +31,5 @@ def __init__(
self.task_definition = task_definition

@abstractmethod
def annotate(self):
def annotate_batch(self):
pass
datadreamer/dataset_annotation/owlv2_annotator.py (150 additions, 57 deletions)
@@ -1,3 +1,7 @@
from typing import List, Tuple

import numpy as np
import PIL
import torch
from transformers import Owlv2ForObjectDetection, Owlv2Processor

@@ -18,7 +18,7 @@ class OWLv2Annotator(BaseAnnotator):
Methods:
_init_model(): Initializes the OWLv2 model.
_init_processor(): Initializes the processor for the OWLv2 model.
annotate(image, prompts, conf_threshold, use_tta, synonym_dict): Annotates the given image with bounding boxes and labels.
annotate_batch(images, prompts, conf_threshold, use_tta, synonym_dict): Annotates the given images with bounding boxes and labels.
release(empty_cuda_cache): Releases resources and optionally empties the CUDA cache.
"""

@@ -59,13 +63,85 @@ def _init_processor(self):
"google/owlv2-base-patch16-ensemble", do_pad=False
)

def annotate(
self, image, prompts, conf_threshold=0.1, use_tta=False, synonym_dict=None
):
"""Annotates an image using the OWLv2 model.
def _generate_annotations(
self,
images: List[PIL.Image.Image],
prompts: List[str],
conf_threshold: float = 0.1,
) -> List[dict[str, torch.Tensor]]:
"""Generates annotations for the given images and prompts.

Args:
image: The image to be annotated.
images: The images to be annotated.
prompts: Prompts to guide the annotation.
conf_threshold (float, optional): Confidence threshold for the annotations. Defaults to 0.1.

Returns:
list: A list of dictionaries, one per image, each containing the boxes, scores, and labels for that image.
"""
n = len(images)
batched_prompts = [prompts] * n
target_sizes = torch.Tensor(images[0].size[::-1]).repeat((n, 1)).to(self.device)

inputs = self.processor(
text=batched_prompts, images=images, return_tensors="pt"
).to(self.device)
with torch.no_grad():
outputs = self.model(**inputs)
preds = self.processor.post_process_object_detection(
outputs=outputs, target_sizes=target_sizes, threshold=conf_threshold
)

return preds
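
For readers unfamiliar with the `transformers` post-processing API, here is a sketch of the structure `_generate_annotations` returns, with made-up shapes (the field names follow the documented `post_process_object_detection` contract):

```python
from typing import Dict, List
import torch

# Illustrative structure of `preds` for a batch of two images:
preds: List[Dict[str, torch.Tensor]] = [
    {
        "boxes": torch.zeros(3, 4),                  # xyxy boxes scaled via target_sizes
        "scores": torch.zeros(3),                    # confidences above conf_threshold
        "labels": torch.zeros(3, dtype=torch.long),  # indices into `prompts`
    },
    {
        "boxes": torch.zeros(1, 4),
        "scores": torch.zeros(1),
        "labels": torch.zeros(1, dtype=torch.long),
    },
]
```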

def _get_annotations(
self,
pred: dict[str, torch.Tensor],
use_tta: bool,
img_dim: int,
synonym_dict: dict[str, List[str]] | None,
synonym_dict_rev: dict[int, int] | None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Extracts the annotations from the predictions.

Args:
pred: The predictions from the model.
use_tta (bool): Whether test-time augmentation was applied to the image.
img_dim (int): The dimension of the image.
synonym_dict (dict): Dictionary for handling synonyms in labels.
synonym_dict_rev (dict): Reverse mapping from expanded prompt indices back to the original label indices.

Returns:
tuple: A tuple containing the final bounding boxes, scores, and labels for the annotations.
"""

boxes, scores, labels = (
pred["boxes"],
pred["scores"],
pred["labels"],
)
# Flip boxes back if using TTA
if use_tta:
boxes[:, [0, 2]] = img_dim - boxes[:, [2, 0]]

if synonym_dict is not None:
labels = torch.tensor([synonym_dict_rev[label.item()] for label in labels])

return boxes, scores, labels
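
The un-flip `boxes[:, [0, 2]] = img_dim - boxes[:, [2, 0]]` can be checked with a worked example (values invented for illustration):

```python
import torch

img_dim = 640  # image width in pixels
# One xyxy box detected on the horizontally flipped image:
boxes = torch.tensor([[100.0, 50.0, 200.0, 150.0]])
# Mirror x1 and x2 around the image width, swapping them to keep x1 < x2:
boxes[:, [0, 2]] = img_dim - boxes[:, [2, 0]]
print(boxes)  # tensor([[440.,  50., 540., 150.]])
```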

def annotate_batch(
self,
images: List[PIL.Image.Image],
prompts: List[str],
conf_threshold: float = 0.1,
use_tta: bool = False,
synonym_dict: dict[str, List[str]] | None = None,
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
"""Annotates images using the OWLv2 model.

Args:
images: The images to be annotated.
prompts: Prompts to guide the annotation.
conf_threshold (float, optional): Confidence threshold for the annotations. Defaults to 0.1.
use_tta (bool, optional): Flag to apply test-time augmentation. Defaults to False.
@@ -75,9 +151,7 @@ def annotate(
tuple: A tuple of three lists (boxes, scores, labels), each containing one NumPy array per image.
"""
if use_tta:
augmented_images = apply_tta(image)
else:
augmented_images = [image]
augmented_images = [apply_tta(image)[0] for image in images]

if synonym_dict is not None:
prompts_syn = []
@@ -93,69 +167,88 @@
synonym_dict_rev[prompts_syn.index(v)] = prompts.index(key)
prompts = prompts_syn
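
The fold above elides the body of the synonym expansion; as a hedged illustration only (not the PR's exact code), the idea is to flatten every class and its synonyms into one prompt list and keep a reverse map back to the original class indices:

```python
# Hypothetical sketch of the synonym expansion:
prompts = ["person", "car"]
synonym_dict = {"person": ["human"], "car": ["automobile"]}

prompts_syn = []
synonym_dict_rev = {}
for key in prompts:
    prompts_syn.append(key)
    synonym_dict_rev[prompts_syn.index(key)] = prompts.index(key)
    for v in synonym_dict.get(key, []):
        prompts_syn.append(v)
        synonym_dict_rev[prompts_syn.index(v)] = prompts.index(key)

# prompts_syn      -> ["person", "human", "car", "automobile"]
# synonym_dict_rev -> {0: 0, 1: 0, 2: 1, 3: 1}
```

This is what lets `_get_annotations` later collapse a synonym hit back to its canonical class id.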

all_boxes = []
all_scores = []
all_labels = []
preds = self._generate_annotations(images, prompts, conf_threshold)
if use_tta:
augmented_preds = self._generate_annotations(
augmented_images, prompts, conf_threshold
)
else:
augmented_preds = [None] * len(images)

target_sizes = torch.Tensor([augmented_images[0].size[::-1]]).to(self.device)
final_boxes = []
final_scores = []
final_labels = []

for aug_image in augmented_images:
inputs = self.processor(
text=prompts, images=aug_image, return_tensors="pt"
).to(self.device)
with torch.no_grad():
outputs = self.model(**inputs)
# print(outputs)
preds = self.processor.post_process_object_detection(
outputs=outputs, target_sizes=target_sizes, threshold=conf_threshold
for i, (pred, aug_pred) in enumerate(zip(preds, augmented_preds)):
boxes, scores, labels = self._get_annotations(
pred,
False,
images[i].size[0],
synonym_dict,
synonym_dict_rev if synonym_dict is not None else None,
)

boxes, scores, labels = (
preds[0]["boxes"],
preds[0]["scores"],
preds[0]["labels"],
)
# Flip boxes back if using TTA
if use_tta and len(all_boxes) == 1:
boxes[:, [0, 2]] = image.size[0] - boxes[:, [2, 0]]
all_boxes = [boxes.to("cpu")]
all_scores = [scores.to("cpu")]
all_labels = [labels.to("cpu")]

if synonym_dict is not None:
labels = torch.tensor(
[synonym_dict_rev[label.item()] for label in labels]
# Flip boxes back if using TTA
if use_tta:
aug_boxes, aug_scores, aug_labels = self._get_annotations(
aug_pred,
True,
images[i].size[0],
synonym_dict,
synonym_dict_rev if synonym_dict is not None else None,
)

all_boxes.append(boxes.to("cpu"))
all_scores.append(scores.to("cpu"))
all_labels.append(labels.to("cpu"))
all_boxes.append(aug_boxes.to("cpu"))
all_scores.append(aug_scores.to("cpu"))
all_labels.append(aug_labels.to("cpu"))

# Convert list of tensors to a single tensor for NMS
all_boxes_cat = torch.cat(all_boxes)
all_scores_cat = torch.cat(all_scores)
all_labels_cat = torch.cat(all_labels)
one_hot_labels = torch.nn.functional.one_hot(
torch.cat(all_labels), num_classes=len(prompts)
)

one_hot_labels = torch.nn.functional.one_hot(
all_labels_cat, num_classes=len(prompts)
)
# Apply NMS
# transform predictions to shape [N, 5 + num_classes], where N is the number of boxes, as expected by the NMS function
all_boxes_cat = torch.cat(
(
torch.cat(all_boxes),
torch.cat(all_scores).unsqueeze(-1),
one_hot_labels,
),
dim=1,
)

# Apply NMS
# transform predictions to shape [N, 5 + num_classes], N is the number of bboxes for nms function
all_boxes_cat = torch.cat(
(all_boxes_cat, all_scores_cat.unsqueeze(-1), one_hot_labels),
dim=1,
)
# output is a list of detections; each item is a tensor of shape (num_boxes, 6), columns are [x1, y1, x2, y2, conf, cls].
output = non_max_suppression(
all_boxes_cat.unsqueeze(0), conf_thres=conf_threshold, iou_thres=0.2
)

# output is a list of detections, each item is one tensor with shape (num_boxes, 6), 6 is for [xyxy, conf, cls].
output = non_max_suppression(
all_boxes_cat.unsqueeze(0), conf_thres=conf_threshold, iou_thres=0.2
)
output_boxes = output[0][:, :4]
output_scores = output[0][:, 4]
output_local_labels = output[0][:, 5].long()

final_boxes = output[0][:, :4]
final_scores = output[0][:, 4]
final_labels = output[0][:, 5].long()
final_boxes.append(
output_boxes.detach().cpu().numpy()
if not isinstance(output_boxes, np.ndarray)
else output_boxes
)
final_scores.append(
output_scores.detach().cpu().numpy()
if not isinstance(output_scores, np.ndarray)
else output_scores
)
final_labels.append(
output_local_labels.detach().cpu().numpy()
if not isinstance(output_local_labels, np.ndarray)
else output_local_labels
)

return final_boxes, final_scores, final_labels
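
A hedged end-to-end sketch of the new batch API (file names are placeholders, the constructor arguments are assumed from the repository defaults, and the model weights download on first use):

```python
import PIL.Image
from datadreamer.dataset_annotation.owlv2_annotator import OWLv2Annotator

annotator = OWLv2Annotator(device="cuda")  # constructor args assumed
images = [PIL.Image.open(p).convert("RGB") for p in ["img0.jpg", "img1.jpg"]]

boxes, scores, labels = annotator.annotate_batch(
    images,
    prompts=["person", "car"],
    conf_threshold=0.1,
    use_tta=True,
)
# boxes[i], scores[i], labels[i] are NumPy arrays for the i-th image.
annotator.release(empty_cuda_cache=True)
```

Note that `_generate_annotations` builds `target_sizes` from `images[0].size`, so this batch path implicitly assumes all images in a batch share the same dimensions.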

def release(self, empty_cuda_cache=False) -> None:
def release(self, empty_cuda_cache: bool = False) -> None:
"""Releases the model and optionally empties the CUDA cache.

Args:
datadreamer/dataset_annotation/utils.py (2 additions, 2 deletions)
@@ -8,15 +8,15 @@ def apply_tta(image):
image: The image to be augmented.
Returns:
list: A list of augmented images, including the original and transformed versions.
list: A list of augmented images.
Note:
Currently, only horizontal flip is enabled; the identity transform, vertical
flip, and color jitter are commented out but can be re-enabled as needed.
"""
tta_transforms = [
# Original image
transforms.Compose([]),
# transforms.Compose([]),
# Horizontal Flip
transforms.Compose([transforms.RandomHorizontalFlip(p=1)]),
# Vertical Flip
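
With the identity transform now commented out, `apply_tta` returns only the flipped view, which is why `annotate_batch` takes `apply_tta(image)[0]`. A minimal re-creation of the updated behaviour, assuming the torchvision import already present in `utils.py`:

```python
import PIL.Image
from torchvision import transforms

tta_transforms = [
    # Horizontal flip is the only active TTA transform:
    transforms.Compose([transforms.RandomHorizontalFlip(p=1)]),
]

def apply_tta(image):
    """Return the augmented views; index 0 is the horizontal flip."""
    return [t(image) for t in tta_transforms]

flipped = apply_tta(PIL.Image.new("RGB", (640, 480)))[0]
print(flipped.size)  # (640, 480): same size, mirrored content
```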