Batch annotation (#35)
* Add batch annotation

* Update docs & add test & fix batched annotation

* Change default batch annotation

* Fix annotation tests

* Fix tests

* [Automated] Updated coverage badge

* Update annotation example & docstrings

* Fix formatting

* Fix docstring

* [Automated] Updated coverage badge

* refactor: replace annotate() with annotate_batch()

* feature: replace owlv2 resize

---------

Co-authored-by: Jan Cuhel <[email protected]>
Co-authored-by: GitHub Actions <[email protected]>
Co-authored-by: Nikita Sokovnin <[email protected]>
4 people authored Feb 23, 2024
1 parent 6d313af commit e8a9e11
Showing 10 changed files with 306 additions and 193 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -116,6 +116,7 @@ datadreamer --save_dir <directory> --class_names <objects> --prompts_number <num
- `--image_tester_patience`: Patience level for image tester. Default is 1.
- `--lm_quantization`: Quantization to use for Mistral language model. Choose between `none` and `4bit`. Default is `none`.
- `--batch_size_prompt`: Batch size for prompt generation. Default is 64.
- `--batch_size_annotation`: Batch size for annotation. Default is 8.
- `--batch_size_image`: Batch size for image generation. Default is 1.
- `--device`: Choose between `cuda` and `cpu`. Default is `cuda`.
- `--seed`: Set a random seed for image and prompt generation. Default is 42.
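For example, a run that annotates generated images in batches of 16 (an illustrative invocation using the placeholder values above, not a line from this commit) could look like:

datadreamer --save_dir <directory> --class_names <objects> --batch_size_annotation 16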
4 changes: 2 additions & 2 deletions datadreamer/dataset_annotation/image_annotator.py
@@ -20,7 +20,7 @@ class BaseAnnotator(ABC):
which can be overridden by subclasses for specific tasks.
Methods:
annotate(): Abstract method to be implemented by subclasses. It should contain
annotate_batch(): Abstract method to be implemented by subclasses. It should contain
the logic for performing annotation based on the task definition.
"""

@@ -31,5 +31,5 @@ def __init__(
self.task_definition = task_definition

@abstractmethod
def annotate(self):
def annotate_batch(self):
pass
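To illustrate the renamed contract, here is a minimal sketch of a concrete annotator (a hypothetical class, not part of this commit; only the abstract method name is taken from the diff above):

from typing import List

import PIL.Image

from datadreamer.dataset_annotation.image_annotator import BaseAnnotator


class NullAnnotator(BaseAnnotator):
    """Hypothetical annotator that returns no detections, for illustration only."""

    def annotate_batch(self, images: List[PIL.Image.Image], prompts: List[str]):
        # One (empty) annotation list per input image.
        return [[] for _ in images]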
211 changes: 153 additions & 58 deletions datadreamer/dataset_annotation/owlv2_annotator.py
@@ -1,3 +1,7 @@
from typing import List, Tuple

import numpy as np
import PIL
import torch
from transformers import Owlv2ForObjectDetection, Owlv2Processor

@@ -18,7 +18,7 @@ class OWLv2Annotator(BaseAnnotator):
Methods:
_init_model(): Initializes the OWLv2 model.
_init_processor(): Initializes the processor for the OWLv2 model.
annotate(image, prompts, conf_threshold, use_tta, synonym_dict): Annotates the given image with bounding boxes and labels.
annotate_batch(images, prompts, conf_threshold, use_tta, synonym_dict): Annotates the given images with bounding boxes and labels.
release(empty_cuda_cache): Releases resources and optionally empties the CUDA cache.
"""

@@ -56,16 +60,90 @@ def _init_processor(self):
Owlv2Processor: The initialized processor.
"""
return Owlv2Processor.from_pretrained(
"google/owlv2-base-patch16-ensemble", do_pad=False
"google/owlv2-base-patch16-ensemble", do_pad=False, do_resize=False
)

def annotate(
self, image, prompts, conf_threshold=0.1, use_tta=False, synonym_dict=None
):
"""Annotates an image using the OWLv2 model.
def _generate_annotations(
self,
images: List[PIL.Image.Image],
prompts: List[str],
conf_threshold: float = 0.1,
) -> List[dict[str, torch.Tensor]]:
"""Generates annotations for the given images and prompts.
Args:
image: The image to be annotated.
images: The images to be annotated.
prompts: Prompts to guide the annotation.
conf_threshold (float, optional): Confidence threshold for the annotations. Defaults to 0.1.
Returns:
list: A list of dictionaries containing the annotations for each image.
"""
n = len(images)
batched_prompts = [prompts] * n
target_sizes = torch.Tensor(images[0].size[::-1]).repeat((n, 1)).to(self.device)

# resize the images to the model's input size
images = [images[i].resize((960, 960)) for i in range(n)]
inputs = self.processor(
text=batched_prompts, images=images, return_tensors="pt"
).to(self.device)
with torch.no_grad():
outputs = self.model(**inputs)
# print(outputs)
preds = self.processor.post_process_object_detection(
outputs=outputs, target_sizes=target_sizes, threshold=conf_threshold
)

return preds

def _get_annotations(
self,
pred: dict[str, torch.Tensor],
use_tta: bool,
img_dim: int,
synonym_dict: dict[str, List[str]] | None,
synonym_dict_rev: dict[int, int] | None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Extracts the annotations from the predictions.
Args:
pred: The predictions from the model.
use_tta (bool): Flag indicating whether test-time augmentation was applied.
img_dim (int): The width of the image, used to mirror boxes back after a horizontal flip.
synonym_dict (dict): Dictionary mapping labels to lists of synonyms.
synonym_dict_rev (dict): Reverse mapping from synonym prompt indices back to the original label indices.
Returns:
tuple: A tuple containing the final bounding boxes, scores, and labels for the annotations.
"""

boxes, scores, labels = (
pred["boxes"],
pred["scores"],
pred["labels"],
)
# Flip boxes back if using TTA
if use_tta:
boxes[:, [0, 2]] = img_dim - boxes[:, [2, 0]]

if synonym_dict is not None:
labels = torch.tensor([synonym_dict_rev[label.item()] for label in labels])

return boxes, scores, labels
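For intuition on the flip-back above (illustrative numbers, not from the commit): in a horizontally flipped image of width 960, a box with x-coordinates [100, 300] maps back to [960 - 300, 960 - 100] = [660, 860]:

import torch

img_dim = 960
boxes = torch.tensor([[100.0, 0.0, 300.0, 50.0]])  # xyxy in the flipped image
boxes[:, [0, 2]] = img_dim - boxes[:, [2, 0]]  # mirror the x-coordinates back
print(boxes)  # tensor([[660.,   0., 860.,  50.]])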

def annotate_batch(
self,
images: List[PIL.Image.Image],
prompts: List[str],
conf_threshold: float = 0.1,
use_tta: bool = False,
synonym_dict: dict[str, List[str]] | None = None,
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
"""Annotates images using the OWLv2 model.
Args:
images: The images to be annotated.
prompts: Prompts to guide the annotation.
conf_threshold (float, optional): Confidence threshold for the annotations. Defaults to 0.1.
use_tta (bool, optional): Flag to apply test-time augmentation. Defaults to False.
@@ -75,9 +153,7 @@ def annotate(
tuple: A tuple containing the final bounding boxes, scores, and labels for the annotations.
"""
if use_tta:
augmented_images = apply_tta(image)
else:
augmented_images = [image]
augmented_images = [apply_tta(image)[0] for image in images]

if synonym_dict is not None:
prompts_syn = []
@@ -93,69 +169,88 @@ def annotate(
synonym_dict_rev[prompts_syn.index(v)] = prompts.index(key)
prompts = prompts_syn

all_boxes = []
all_scores = []
all_labels = []
preds = self._generate_annotations(images, prompts, conf_threshold)
if use_tta:
augmented_preds = self._generate_annotations(
augmented_images, prompts, conf_threshold
)
else:
augmented_preds = [None] * len(images)

target_sizes = torch.Tensor([augmented_images[0].size[::-1]]).to(self.device)
final_boxes = []
final_scores = []
final_labels = []

for aug_image in augmented_images:
inputs = self.processor(
text=prompts, images=aug_image, return_tensors="pt"
).to(self.device)
with torch.no_grad():
outputs = self.model(**inputs)
# print(outputs)
preds = self.processor.post_process_object_detection(
outputs=outputs, target_sizes=target_sizes, threshold=conf_threshold
for i, (pred, aug_pred) in enumerate(zip(preds, augmented_preds)):
boxes, scores, labels = self._get_annotations(
pred,
False,
images[i].size[0],
synonym_dict,
synonym_dict_rev if synonym_dict is not None else None,
)

boxes, scores, labels = (
preds[0]["boxes"],
preds[0]["scores"],
preds[0]["labels"],
)
# Flip boxes back if using TTA
if use_tta and len(all_boxes) == 1:
boxes[:, [0, 2]] = image.size[0] - boxes[:, [2, 0]]
all_boxes = [boxes.to("cpu")]
all_scores = [scores.to("cpu")]
all_labels = [labels.to("cpu")]

if synonym_dict is not None:
labels = torch.tensor(
[synonym_dict_rev[label.item()] for label in labels]
# Flip boxes back if using TTA
if use_tta:
aug_boxes, aug_scores, aug_labels = self._get_annotations(
aug_pred,
True,
images[i].size[0],
synonym_dict,
synonym_dict_rev if synonym_dict is not None else None,
)

all_boxes.append(boxes.to("cpu"))
all_scores.append(scores.to("cpu"))
all_labels.append(labels.to("cpu"))
all_boxes.append(aug_boxes.to("cpu"))
all_scores.append(aug_scores.to("cpu"))
all_labels.append(aug_labels.to("cpu"))

# Convert list of tensors to a single tensor for NMS
all_boxes_cat = torch.cat(all_boxes)
all_scores_cat = torch.cat(all_scores)
all_labels_cat = torch.cat(all_labels)
one_hot_labels = torch.nn.functional.one_hot(
torch.cat(all_labels), num_classes=len(prompts)
)

one_hot_labels = torch.nn.functional.one_hot(
all_labels_cat, num_classes=len(prompts)
)
# Apply NMS
# transform predictions to shape [N, 5 + num_classes], N is the number of bboxes for nms function
all_boxes_cat = torch.cat(
(
torch.cat(all_boxes),
torch.cat(all_scores).unsqueeze(-1),
one_hot_labels,
),
dim=1,
)

# Apply NMS
# transform predictions to shape [N, 5 + num_classes], N is the number of bboxes for nms function
all_boxes_cat = torch.cat(
(all_boxes_cat, all_scores_cat.unsqueeze(-1), one_hot_labels),
dim=1,
)
# output is a list of detections, each item is one tensor with shape (num_boxes, 6), 6 is for [xyxy, conf, cls].
output = non_max_suppression(
all_boxes_cat.unsqueeze(0), conf_thres=conf_threshold, iou_thres=0.2
)

# output is a list of detections, each item is one tensor with shape (num_boxes, 6), 6 is for [xyxy, conf, cls].
output = non_max_suppression(
all_boxes_cat.unsqueeze(0), conf_thres=conf_threshold, iou_thres=0.2
)
output_boxes = output[0][:, :4]
output_scores = output[0][:, 4]
output_local_labels = output[0][:, 5].long()

final_boxes = output[0][:, :4]
final_scores = output[0][:, 4]
final_labels = output[0][:, 5].long()
final_boxes.append(
output_boxes.detach().cpu().numpy()
if not isinstance(output_boxes, np.ndarray)
else output_boxes
)
final_scores.append(
output_scores.detach().cpu().numpy()
if not isinstance(output_scores, np.ndarray)
else output_scores
)
final_labels.append(
output_local_labels.detach().cpu().numpy()
if not isinstance(output_local_labels, np.ndarray)
else output_local_labels
)

return final_boxes, final_scores, final_labels
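To make the NMS input layout concrete (a sketch with made-up values): with two prompt classes, each row passed to non_max_suppression is [x1, y1, x2, y2, conf] followed by a one-hot class vector, i.e. 5 + num_classes columns:

import torch

boxes = torch.tensor([[10.0, 20.0, 50.0, 60.0]])  # xyxy
scores = torch.tensor([0.9])
one_hot = torch.nn.functional.one_hot(
    torch.tensor([1]), num_classes=2
).float()  # detection belongs to class 1 of 2
rows = torch.cat((boxes, scores.unsqueeze(-1), one_hot), dim=1)
print(rows.shape)  # torch.Size([1, 7]) -> [N, 5 + num_classes]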

def release(self, empty_cuda_cache=False) -> None:
def release(self, empty_cuda_cache: bool = False) -> None:
"""Releases the model and optionally empties the CUDA cache.
Args:
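As a usage sketch (the constructor arguments and image paths here are assumptions, not taken from this commit), the batched API might be driven like this:

from PIL import Image

from datadreamer.dataset_annotation.owlv2_annotator import OWLv2Annotator

annotator = OWLv2Annotator(device="cuda")  # assumed constructor signature
images = [Image.open(p).convert("RGB") for p in ["img0.jpg", "img1.jpg"]]
# One shared prompt list annotates every image in the batch.
boxes, scores, labels = annotator.annotate_batch(
    images, prompts=["person", "dog"], conf_threshold=0.1
)
# boxes, scores, labels are lists with one numpy array per input image.
annotator.release(empty_cuda_cache=True)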
4 changes: 2 additions & 2 deletions datadreamer/dataset_annotation/utils.py
@@ -8,15 +8,15 @@ def apply_tta(image):
image: The image to be augmented.
Returns:
list: A list of augmented images, including the original and transformed versions.
list: A list of augmented images.
Note:
Currently, only horizontal flip is enabled. Additional transformations like
vertical flip and color jitter are commented out but can be enabled as needed.
"""
tta_transforms = [
# Original image
transforms.Compose([]),
# transforms.Compose([]),
# Horizontal Flip
transforms.Compose([transforms.RandomHorizontalFlip(p=1)]),
# Vertical Flip
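Note that with the original-image transform commented out, apply_tta now returns only the flipped view, which is why annotate_batch above indexes apply_tta(image)[0]. The remaining transform is equivalent to this standalone sketch (hypothetical image path):

from PIL import Image
from torchvision import transforms

# RandomHorizontalFlip with p=1 mirrors the image deterministically.
flip = transforms.Compose([transforms.RandomHorizontalFlip(p=1)])
flipped = flip(Image.open("img0.jpg").convert("RGB"))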
