Feat/add instance segmentation #67

Merged · 29 commits · Oct 28, 2024

Commits
3d3c88e  Add FastSAM (HonzaCuhel, Oct 17, 2024)
f2dbf33  Update (HonzaCuhel, Oct 19, 2024)
535d09a  Update Colab notebook (HonzaCuhel, Oct 19, 2024)
454d749  Add vizualization (HonzaCuhel, Oct 20, 2024)
7bb93e9  Update README.md and tests (HonzaCuhel, Oct 21, 2024)
3fcb736  Update COCO converter (HonzaCuhel, Oct 21, 2024)
5a0795d  Refactor YOLO converter (HonzaCuhel, Oct 21, 2024)
c0cf6ab  Refactor visualize function (HonzaCuhel, Oct 21, 2024)
a1c6b6a  [Automated] Updated coverage badge (actions-user, Oct 21, 2024)
7879220  fix: different color for different classes in the segmenetation visua… (sokovninn, Oct 21, 2024)
4fae718  Switch to SlimSAM (HonzaCuhel, Oct 24, 2024)
f40e5a0  Switch to SlimSAM (HonzaCuhel, Oct 24, 2024)
853d5ad  Update instance segmentation example (HonzaCuhel, Oct 24, 2024)
04e91fd  Update tests (HonzaCuhel, Oct 24, 2024)
ff771ad  Fix: annotator tests (HonzaCuhel, Oct 24, 2024)
335cc05  [Automated] Updated coverage badge (actions-user, Oct 24, 2024)
f887910  Update docs & luxonis dataset creation (HonzaCuhel, Oct 25, 2024)
b8151cb  fix: return SliamSAM processor (sokovninn, Oct 25, 2024)
af08e4b  fix: handle empty polygon list (sokovninn, Oct 25, 2024)
c566bea  Fix: remove long outputs from Jupyter Notebook (HonzaCuhel, Oct 25, 2024)
07a58f0  Fix: README.md (HonzaCuhel, Oct 25, 2024)
057a9b4  Add OWLv2 non-square pixel fix (HonzaCuhel, Oct 25, 2024)
437d067  Rename vars (HonzaCuhel, Oct 25, 2024)
cd819c4  Fix: correct all SlimSAM mentions (HonzaCuhel, Oct 25, 2024)
5e45347  fix: different image sizes for owlv2 postprocessing (sokovninn, Oct 25, 2024)
3b915ba  Update OWLv2 bbox correction (HonzaCuhel, Oct 25, 2024)
68487e4  fix: pass segmentation annotator size (sokovninn, Oct 25, 2024)
5401431  fix: shifted annotations when tta is used (sokovninn, Oct 25, 2024)
d47253a  Fix OWLv2 device (HonzaCuhel, Oct 28, 2024)
22 changes: 20 additions & 2 deletions README.md
@@ -157,13 +157,13 @@ datadreamer --config <path-to-config>

### 🔧 Additional Parameters

- `--task`: Choose between detection and classification. Default is `detection`.
- `--task`: Choose between detection, classification and instance segmentation. Default is `detection`.
- `--dataset_format`: Format of the dataset. Defaults to `raw`. Supported values: `raw`, `yolo`, `coco`, `luxonis-dataset`, `cls-single`.
- `--split_ratios`: Split ratios for train, validation, and test sets. Defaults to `[0.8, 0.1, 0.1]`.
- `--num_objects_range`: Range of objects in a prompt. Default is 1 to 3.
- `--prompt_generator`: Choose between `simple`, `lm` (Mistral-7B), `tiny` (tiny LM), and `qwen2` (Qwen2.5 LM). Default is `qwen2`.
- `--image_generator`: Choose image generator, e.g., `sdxl`, `sdxl-turbo` or `sdxl-lightning`. Default is `sdxl-turbo`.
- `--image_annotator`: Specify the image annotator, like `owlv2` for object detection or `clip` for image classification. Default is `owlv2`.
- `--image_annotator`: Specify the image annotator, like `owlv2` for object detection, `clip` for image classification, or `owlv2-slimsam` for instance segmentation. Default is `owlv2`.
- `--conf_threshold`: Confidence threshold for annotation. Default is `0.15`.
- `--annotation_iou_threshold`: Intersection over Union (IoU) threshold for annotation. Default is `0.2`.
- `--prompt_prefix`: Prefix to add to every image generation prompt. Default is `""`.
@@ -199,6 +199,7 @@ datadreamer --config <path-to-config>
| | [SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning) | Fast and accurate (1024x1024 images) |
| Image Annotation | [OWLv2](https://huggingface.co/google/owlv2-base-patch16-ensemble) | Open-Vocabulary object detector |
| | [CLIP](https://huggingface.co/openai/clip-vit-base-patch32) | Zero-shot-image-classification |
| | [SlimSAM](https://huggingface.co/Zigeng/SlimSAM-uniform-50) | Zero-shot-instance-segmentation |

<a name="example"></a>

@@ -271,6 +272,23 @@ save_dir/
}
```

3. Instance Segmentation Annotations (instance_segmentation_annotations.json):

- Each entry corresponds to an image and contains bounding boxes, masks, and labels for objects in the image.
- Format:

```bash
{
  "image_path": {
    "boxes": [[x_min, y_min, x_max, y_max], ...],
    "masks": [[[x0, y0], [x1, y1], ...], [[x0, y0], [x1, y1], ...], ...],
    "labels": [label_index, ...]
  },
  ...
  "class_names": ["class1", "class2", ...]
}
```
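
For readers who want to consume this file programmatically, here is a minimal sketch of loading it, assuming only the layout documented above; the `save_dir` path is illustrative.

```python
import json

# Load the instance segmentation annotations produced by datadreamer
# (the path below is illustrative -- use your own save_dir).
with open("save_dir/instance_segmentation_annotations.json") as f:
    annotations = json.load(f)

class_names = annotations.pop("class_names")  # top-level list of class names

for image_path, entry in annotations.items():
    for box, polygon, label in zip(entry["boxes"], entry["masks"], entry["labels"]):
        x_min, y_min, x_max, y_max = box  # bounding box corners in pixels
        # polygon is a list of [x, y] vertices outlining the instance mask
        print(image_path, class_names[label], len(polygon), "vertices")
```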

<a name="limitations"></a>

## ⚠️ Limitations
Expand Down
9 changes: 8 additions & 1 deletion datadreamer/dataset_annotation/__init__.py
@@ -3,5 +3,12 @@
from .clip_annotator import CLIPAnnotator
from .image_annotator import BaseAnnotator, TaskList
from .owlv2_annotator import OWLv2Annotator
from .slimsam_annotator import SlimSAMAnnotator

__all__ = ["BaseAnnotator", "TaskList", "OWLv2Annotator", "CLIPAnnotator"]
__all__ = [
    "BaseAnnotator",
    "TaskList",
    "OWLv2Annotator",
    "CLIPAnnotator",
    "SlimSAMAnnotator",
]
20 changes: 18 additions & 2 deletions datadreamer/dataset_annotation/owlv2_annotator.py
@@ -236,6 +236,22 @@ def annotate_batch(
                torch.cat(all_labels), num_classes=len(prompts)
            )

            # Fix the bounding boxes
            width_ratio = 1
            height_ratio = 1
            width = images[i].width
            height = images[i].height
            if width > height:
                height_ratio = height / width
            elif height > width:
                width_ratio = width / height

            all_boxes = [
                box
                / torch.tensor([width_ratio, height_ratio, width_ratio, height_ratio])
                for box in all_boxes
            ]

            # Apply NMS
            # transform predictions to shape [N, 5 + num_classes], N is the number of bboxes for nms function
            all_boxes_cat = torch.cat(
@@ -294,8 +310,8 @@ def release(self, empty_cuda_cache: bool = False) -> None:

url = "https://ultralytics.com/images/bus.jpg"
im = Image.open(requests.get(url, stream=True).raw)
annotator = OWLv2Annotator(device="cpu", size="large")
annotator = OWLv2Annotator(device="cpu", size="base")
final_boxes, final_scores, final_labels = annotator.annotate_batch(
[im], ["robot", "horse"]
[im], ["bus", "person"]
)
annotator.release()
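
As a side note on the box-correction hunk above, here is a small, self-contained sketch of the same ratio logic, useful for checking its effect on concrete numbers. `correct_boxes` is a hypothetical helper written for illustration, not part of the PR.

```python
import torch

def correct_boxes(boxes: torch.Tensor, width: int, height: int) -> torch.Tensor:
    # Mirror of the hunk above: divide the shorter dimension's coordinates by the
    # height/width (or width/height) ratio so boxes match the non-square image.
    width_ratio, height_ratio = 1.0, 1.0
    if width > height:
        height_ratio = height / width
    elif height > width:
        width_ratio = width / height
    return boxes / torch.tensor([width_ratio, height_ratio, width_ratio, height_ratio])

# For a 640x480 image, y-coordinates are scaled up by 640/480:
boxes = torch.tensor([[100.0, 75.0, 300.0, 225.0]])
print(correct_boxes(boxes, width=640, height=480))  # tensor([[100., 100., 300., 300.]])
```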
153 changes: 153 additions & 0 deletions datadreamer/dataset_annotation/slimsam_annotator.py
@@ -0,0 +1,153 @@
from __future__ import annotations

import logging
from typing import List

import numpy as np
import PIL
import torch
from transformers import SamModel, SamProcessor

from datadreamer.dataset_annotation.image_annotator import BaseAnnotator
from datadreamer.dataset_annotation.utils import mask_to_polygon

logger = logging.getLogger(__name__)


class SlimSAMAnnotator(BaseAnnotator):
    """A class for image annotation using the SlimSAM model, specializing in instance
    segmentation.

    Attributes:
        model (SamModel): The SlimSAM model for instance segmentation.
        processor (SamProcessor): The processor for the SlimSAM model.
        device (str): The device on which the model will run ('cuda' for GPU, 'cpu' for CPU).
        size (str): The size of the SlimSAM model to use ('base' or 'large').

    Methods:
        _init_model(): Initializes the SlimSAM model.
        _init_processor(): Initializes the processor for the SlimSAM model.
        annotate_batch(images, boxes_batch, iou_threshold): Annotates the given images with segmentation masks for the provided boxes.
        release(empty_cuda_cache): Releases resources and optionally empties the CUDA cache.
    """

    def __init__(
        self,
        seed: float = 42,
        device: str = "cuda",
        size: str = "base",
    ) -> None:
        """Initializes the SlimSAMAnnotator with a specific seed, device and model size.

        Args:
            seed (float): Seed for reproducibility. Defaults to 42.
            device (str): The device to run the model on. Defaults to 'cuda'.
            size (str): The size of the SlimSAM model to use ('base' or 'large'). Defaults to 'base'.
        """
        super().__init__(seed)
        self.size = size
        self.model = self._init_model()
        self.processor = self._init_processor()
        self.device = device
        self.model.to(self.device)

    def _init_model(self) -> SamModel:
        """Initializes the SlimSAM model for instance segmentation.

        Returns:
            SamModel: The initialized SlimSAM model.
        """
        logger.info(f"Initializing SlimSAM {self.size} model...")
        if self.size == "large":
            return SamModel.from_pretrained("Zigeng/SlimSAM-uniform-50")
        return SamModel.from_pretrained("Zigeng/SlimSAM-uniform-77")

    def _init_processor(self) -> SamProcessor:
        """Initializes the processor for the SAM model.

        Returns:
            SamProcessor: The initialized processor.
        """
        if self.size == "large":
            return SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-50")
        return SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-77")

    def annotate_batch(
        self,
        images: List[PIL.Image.Image],
        boxes_batch: List[np.ndarray],
        iou_threshold: float = 0.2,
    ) -> List[List[List[float]]]:
        """Annotates images for the task of instance segmentation using the SlimSAM
        model.

        Args:
            images: The images to be annotated.
            boxes_batch: The bounding boxes of found objects.
            iou_threshold (float, optional): Intersection over union threshold for non-maximum suppression. Defaults to 0.2.

        Returns:
            List: A list containing the final segment masks represented as a polygon.
        """
        final_segments = []

        n = len(images)

        for i in range(n):
            boxes = boxes_batch[i].tolist()
            if len(boxes) == 0:
                final_segments.append([])
                continue

            inputs = self.processor(
                images[i], input_boxes=[boxes], return_tensors="pt"
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs, return_dict=True)
Comment on lines +93 to +106

sokovninn (Member): SlimSAM doesn't support batched inference?

HonzaCuhel (Contributor, Author): The thing is that each image can have a different number of detected objects, so batched inference isn't possible straight away; that's why I implemented it per image. But now that you've mentioned it, I thought about it again and realized that we could "pad" the bboxes with dummy bboxes so that we can run batched inference. I'm currently testing it. Let me know @sokovninn if you'd find this small hack better.

sokovninn (Member): Oh, I see. Dummy bboxes are a good solution. However, I'm not sure it will bring any boost in inference speed, but I think it's worth a try.

HonzaCuhel (Contributor, Author): Exactly, I'll test it and let you know.

HonzaCuhel (Contributor, Author): It turned out not to be faster, so I'm not going to use it.


            masks = self.processor.image_processor.post_process_masks(
                outputs.pred_masks.cpu(),
                inputs["original_sizes"].cpu(),
                inputs["reshaped_input_sizes"].cpu(),
            )[0]

            iou_scores = outputs.iou_scores.cpu()

            image_masks = []
            for j in range(len(boxes)):
                keep_idx = iou_scores[0, j] >= iou_threshold
                filtered_masks = masks[j, keep_idx].cpu().float()
                final_masks = filtered_masks.permute(1, 2, 0)
                final_masks = final_masks.mean(axis=-1)
                final_masks = (final_masks > 0).int()
                final_masks = final_masks.numpy().astype(np.uint8)
                polygon = mask_to_polygon(final_masks)
                if len(polygon) != 0:
                    image_masks.append(polygon)

            final_segments.append(image_masks)

        return final_segments

    def release(self, empty_cuda_cache: bool = False) -> None:
        """Releases the model and optionally empties the CUDA cache.

        Args:
            empty_cuda_cache (bool, optional): Whether to empty the CUDA cache. Defaults to False.
        """
        self.model = self.model.to("cpu")
        if empty_cuda_cache:
            with torch.no_grad():
                torch.cuda.empty_cache()


if __name__ == "__main__":
    import requests
    from PIL import Image

    url = "https://ultralytics.com/images/bus.jpg"
    im = Image.open(requests.get(url, stream=True).raw)
    annotator = SlimSAMAnnotator(device="cpu", size="large")
    final_segments = annotator.annotate_batch([im], [np.array([[3, 229, 559, 650]])])
    print(len(final_segments), len(final_segments[0]))
    print(final_segments[0][0][:5])
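
For context, a hedged sketch of how the two annotators in this PR might be chained (OWLv2 proposes boxes, SlimSAM turns them into polygon masks). Converting OWLv2's per-image boxes to a numpy array is an assumption based on the `__main__` examples above, not something the PR prescribes.

```python
import numpy as np
import requests
from PIL import Image

from datadreamer.dataset_annotation import OWLv2Annotator, SlimSAMAnnotator

image = Image.open(
    requests.get("https://ultralytics.com/images/bus.jpg", stream=True).raw
)

# Open-vocabulary detection first...
detector = OWLv2Annotator(device="cpu", size="base")
boxes, scores, labels = detector.annotate_batch([image], ["bus", "person"])
detector.release()

# ...then per-box instance segmentation with SlimSAM.
segmenter = SlimSAMAnnotator(device="cpu", size="base")
segments = segmenter.annotate_batch([image], [np.asarray(boxes[0])])
segmenter.release()

print(len(segments[0]), "polygon masks for the first image")
```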
26 changes: 26 additions & 0 deletions datadreamer/dataset_annotation/utils.py
@@ -2,6 +2,8 @@

from typing import List

import cv2
import numpy as np
from torchvision import transforms


@@ -32,3 +34,27 @@ def apply_tta(image) -> List[transforms.Compose]:

    augmented_images = [t(image) for t in tta_transforms]
    return augmented_images


def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
    """Converts a binary mask to a polygon.

    Args:
        mask: The binary mask to be converted.

    Returns:
        List: A list of vertices of the polygon.
    """
    # Find contours in the binary mask
    contours, _ = cv2.findContours(
        mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    if len(contours) == 0:
        return []
    # Find the contour with the largest area
    largest_contour = max(contours, key=cv2.contourArea)

    # Extract the vertices of the contour
    polygon = largest_contour.reshape(-1, 2).tolist()

    return polygon
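
A quick, illustrative check of `mask_to_polygon` on a toy mask; the exact vertex order returned by OpenCV may differ.

```python
import numpy as np

from datadreamer.dataset_annotation.utils import mask_to_polygon

# A filled 4x4 square inside an 8x8 image; its contour corners come back
# as a short list of [x, y] vertices, e.g. [[2, 2], [2, 5], [5, 5], [5, 2]].
mask = np.zeros((8, 8), dtype=np.uint8)
mask[2:6, 2:6] = 1
print(mask_to_polygon(mask))
```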