Skip to content

Commit

Permalink
Add Img Cls Annotator
Browse files Browse the repository at this point in the history
  • Loading branch information
HonzaCuhel committed Jan 17, 2025
1 parent 5833c62 commit 9e4990b
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 200 deletions.
2 changes: 2 additions & 0 deletions datadreamer/dataset_annotation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from .aimv2_annotator import AIMv2Annotator
from .clip_annotator import CLIPAnnotator
from .cls_annotator import ImgClassificationAnnotator
from .image_annotator import BaseAnnotator, TaskList
from .owlv2_annotator import OWLv2Annotator
from .slimsam_annotator import SlimSAMAnnotator
Expand All @@ -11,6 +12,7 @@
"BaseAnnotator",
"TaskList",
"OWLv2Annotator",
"ImgClassificationAnnotator",
"CLIPAnnotator",
"SlimSAMAnnotator",
]
102 changes: 2 additions & 100 deletions datadreamer/dataset_annotation/aimv2_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,17 @@
from __future__ import annotations

import logging
from typing import Dict, List

import numpy as np
import PIL
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

from datadreamer.dataset_annotation.image_annotator import BaseAnnotator, TaskList
from datadreamer.dataset_annotation.cls_annotator import ImgClassificationAnnotator

logger = logging.getLogger(__name__)


class AIMv2Annotator(BaseAnnotator):
class AIMv2Annotator(ImgClassificationAnnotator):
"""A class for image annotation using the AIMv2 model, specializing in image
classification.
Expand All @@ -39,25 +36,6 @@ class AIMv2Annotator(BaseAnnotator):
release(empty_cuda_cache): Releases resources and optionally empties the CUDA cache.
"""

def __init__(
self,
seed: float = 42,
device: str = "cuda",
size: str = "base",
) -> None:
"""Initializes the AIMv2Annotator with a specific seed, device and model size.
Builds the model and the processor through `_init_model()` and
`_init_processor()`, then moves the model to `device`.
Args:
seed (float): Seed for reproducibility. Defaults to 42.
device (str): The device to run the model on. Defaults to 'cuda'.
size (str): Model size selector, stored as `self.size`. Presumably read by
`_init_model()` / `_init_processor()` to choose the checkpoint -- TODO confirm
against those helpers. Defaults to 'base'.
"""
# Hand the seed to the base class and register this annotator as a
# classification task.
super().__init__(seed, task_definition=TaskList.CLASSIFICATION)
# `size` is assigned before the model/processor are built so the init
# helpers can see it.
self.size = size
self.model = self._init_model()
self.processor = self._init_processor()
self.device = device
# The model is created first and only afterwards moved to the target device.
self.model.to(self.device)

def _init_processor(self) -> AutoProcessor:
"""Initializes the AIMv2 processor.
Expand All @@ -77,82 +55,6 @@ def _init_model(self) -> AutoModel:
"apple/aimv2-large-patch14-224-lit", trust_remote_code=True
)

def annotate_batch(
    self,
    images: "List[PIL.Image.Image]",
    objects: List[str],
    conf_threshold: float = 0.1,
    synonym_dict: "Dict[str, List[str]] | None" = None,
) -> List[np.ndarray]:
    """Annotates images using the AIMv2 model.

    Every image is scored against every text prompt; the scores are turned
    into probabilities with a softmax over the prompts, and the indices of
    all prompts whose probability exceeds ``conf_threshold`` are returned.

    When ``synonym_dict`` is given, the prompt list is extended with the
    synonyms of each object, and hits on a synonym are mapped back to the
    index of the original object in ``objects``.

    Args:
        images: The images to be annotated.
        objects: A list of objects (text) to test against the images.
        conf_threshold (float, optional): Confidence threshold for the
            annotations. Defaults to 0.1.
        synonym_dict (dict, optional): Maps an object name to its synonyms.
            Must contain an entry for every item of ``objects``.
            Defaults to None.

    Returns:
        List[np.ndarray]: One array per image holding the indices (into the
        ``objects`` list passed by the caller) of the detected labels. With
        synonyms the indices are sorted and de-duplicated.

    Raises:
        KeyError: If ``synonym_dict`` is given but lacks one of ``objects``.
    """
    prompt_to_label = None
    if synonym_dict is not None:
        # Position of each object in the caller's list; the first occurrence
        # wins for duplicates, matching ``list.index``.
        label_index = {}
        for position, name in enumerate(objects):
            label_index.setdefault(name, position)

        # Expand objects with their synonyms. ``dict.fromkeys`` removes
        # duplicates while keeping a deterministic order (a ``set`` would not).
        prompts = list(
            dict.fromkeys(
                term for obj in objects for term in (obj, *synonym_dict[obj])
            )
        )
        prompt_index = {term: position for position, term in enumerate(prompts)}

        # Map prompt positions back to original label positions. Iterating the
        # synonym dict keeps the "last definition wins" rule for a synonym
        # shared by several objects.
        prompt_to_label = {}
        for key, synonyms in synonym_dict.items():
            if key not in label_index:
                continue
            target = label_index[key]
            prompt_to_label[prompt_index[key]] = target
            for synonym in synonyms:
                prompt_to_label[prompt_index[synonym]] = target
        objects = prompts

    inputs = self.processor(
        text=objects, images=images, return_tensors="pt", padding=True
    ).to(self.device)

    # Pure inference: do not build an autograd graph for the forward pass.
    with torch.no_grad():
        outputs = self.model(**inputs)

    # Image-text similarity scores -> per-image label probabilities.
    probs = outputs.logits_per_image.softmax(dim=1).cpu()

    labels = []
    for prob in probs:
        hits = torch.where(prob > conf_threshold)[0].numpy()
        if prompt_to_label is not None:
            # Explicit integer dtype so an image without hits still yields an
            # integer index array instead of an empty float array.
            hits = np.unique(
                np.array([prompt_to_label[int(hit)] for hit in hits], dtype=np.int64)
            )
        labels.append(hits)

    return labels

def release(self, empty_cuda_cache: bool = False) -> None:
    """Releases the model and optionally empties the CUDA cache.

    The model is moved to the CPU and stored back on ``self.model``.

    Args:
        empty_cuda_cache (bool, optional): Whether to empty the CUDA cache.
            Ignored when CUDA is not available. Defaults to False.
    """
    self.model = self.model.to("cpu")
    # Emptying the allocator cache involves no autograd work, so no
    # ``torch.no_grad()`` context is needed; skip the call on CPU-only hosts.
    if empty_cuda_cache and torch.cuda.is_available():
        torch.cuda.empty_cache()


if __name__ == "__main__":
import requests
Expand Down
102 changes: 2 additions & 100 deletions datadreamer/dataset_annotation/clip_annotator.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
from __future__ import annotations

import logging
from typing import Dict, List

import numpy as np
import PIL
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

from datadreamer.dataset_annotation.image_annotator import BaseAnnotator, TaskList
from datadreamer.dataset_annotation.cls_annotator import ImgClassificationAnnotator

logger = logging.getLogger(__name__)


class CLIPAnnotator(BaseAnnotator):
class CLIPAnnotator(ImgClassificationAnnotator):
"""A class for image annotation using the CLIP model, specializing in image
classification.
Expand All @@ -31,25 +28,6 @@ class CLIPAnnotator(BaseAnnotator):
release(empty_cuda_cache): Releases resources and optionally empties the CUDA cache.
"""

def __init__(
self,
seed: float = 42,
device: str = "cuda",
size: str = "base",
) -> None:
"""Initializes the CLIPAnnotator with a specific seed, device and model size.
Builds the model and the processor through `_init_model()` and
`_init_processor()`, then moves the model to `device`.
Args:
seed (float): Seed for reproducibility. Defaults to 42.
device (str): The device to run the model on. Defaults to 'cuda'.
size (str): Model size selector, stored as `self.size`. Presumably read by
`_init_model()` / `_init_processor()` to choose the checkpoint -- TODO confirm
against those helpers. Defaults to 'base'.
"""
# Hand the seed to the base class and register this annotator as a
# classification task.
super().__init__(seed, task_definition=TaskList.CLASSIFICATION)
# `size` is assigned before the model/processor are built so the init
# helpers can see it.
self.size = size
self.model = self._init_model()
self.processor = self._init_processor()
self.device = device
# The model is created first and only afterwards moved to the target device.
self.model.to(self.device)

def _init_processor(self) -> CLIPProcessor:
"""Initializes the CLIP processor.
Expand All @@ -71,82 +49,6 @@ def _init_model(self) -> CLIPModel:
return CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
return CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

def annotate_batch(
    self,
    images: "List[PIL.Image.Image]",
    objects: List[str],
    conf_threshold: float = 0.1,
    synonym_dict: "Dict[str, List[str]] | None" = None,
) -> List[np.ndarray]:
    """Annotates images using the CLIP model.

    Every image is scored against every text prompt; the scores are turned
    into probabilities with a softmax over the prompts, and the indices of
    all prompts whose probability exceeds ``conf_threshold`` are returned.

    When ``synonym_dict`` is given, the prompt list is extended with the
    synonyms of each object, and hits on a synonym are mapped back to the
    index of the original object in ``objects``.

    Args:
        images: The images to be annotated.
        objects: A list of objects (text) to test against the images.
        conf_threshold (float, optional): Confidence threshold for the
            annotations. Defaults to 0.1.
        synonym_dict (dict, optional): Maps an object name to its synonyms.
            Must contain an entry for every item of ``objects``.
            Defaults to None.

    Returns:
        List[np.ndarray]: One array per image holding the indices (into the
        ``objects`` list passed by the caller) of the detected labels. With
        synonyms the indices are sorted and de-duplicated.

    Raises:
        KeyError: If ``synonym_dict`` is given but lacks one of ``objects``.
    """
    prompt_to_label = None
    if synonym_dict is not None:
        # Position of each object in the caller's list; the first occurrence
        # wins for duplicates, matching ``list.index``.
        label_index = {}
        for position, name in enumerate(objects):
            label_index.setdefault(name, position)

        # Expand objects with their synonyms. ``dict.fromkeys`` removes
        # duplicates while keeping a deterministic order (a ``set`` would not).
        prompts = list(
            dict.fromkeys(
                term for obj in objects for term in (obj, *synonym_dict[obj])
            )
        )
        prompt_index = {term: position for position, term in enumerate(prompts)}

        # Map prompt positions back to original label positions. Iterating the
        # synonym dict keeps the "last definition wins" rule for a synonym
        # shared by several objects.
        prompt_to_label = {}
        for key, synonyms in synonym_dict.items():
            if key not in label_index:
                continue
            target = label_index[key]
            prompt_to_label[prompt_index[key]] = target
            for synonym in synonyms:
                prompt_to_label[prompt_index[synonym]] = target
        objects = prompts

    inputs = self.processor(
        text=objects, images=images, return_tensors="pt", padding=True
    ).to(self.device)

    # Pure inference: do not build an autograd graph for the forward pass.
    with torch.no_grad():
        outputs = self.model(**inputs)

    # Image-text similarity scores -> per-image label probabilities.
    probs = outputs.logits_per_image.softmax(dim=1).cpu()

    labels = []
    for prob in probs:
        hits = torch.where(prob > conf_threshold)[0].numpy()
        if prompt_to_label is not None:
            # Explicit integer dtype so an image without hits still yields an
            # integer index array instead of an empty float array.
            hits = np.unique(
                np.array([prompt_to_label[int(hit)] for hit in hits], dtype=np.int64)
            )
        labels.append(hits)

    return labels

def release(self, empty_cuda_cache: bool = False) -> None:
    """Releases the model and optionally empties the CUDA cache.

    The model is moved to the CPU and stored back on ``self.model``.

    Args:
        empty_cuda_cache (bool, optional): Whether to empty the CUDA cache.
            Ignored when CUDA is not available. Defaults to False.
    """
    self.model = self.model.to("cpu")
    # Emptying the allocator cache involves no autograd work, so no
    # ``torch.no_grad()`` context is needed; skip the call on CPU-only hosts.
    if empty_cuda_cache and torch.cuda.is_available():
        torch.cuda.empty_cache()


if __name__ == "__main__":
import requests
Expand Down
Loading

0 comments on commit 9e4990b

Please sign in to comment.