Feature/batched image generation (#34)
* feature: add batched image generation

* test: modify image generation tests

* fix: modify examples

* fix: prompt objects weights

* docs: update args description

* docs: update prompt generation docstrings

* [Automated] Updated coverage badge

---------

Co-authored-by: GitHub Actions <[email protected]>
sokovninn and actions-user authored Feb 22, 2024
1 parent 518b197 commit 6d313af
Showing 14 changed files with 484 additions and 342 deletions.
README.md: 3 changes (2 additions, 1 deletion)
@@ -116,6 +116,7 @@ datadreamer --save_dir <directory> --class_names <objects> --prompts_number <num
- `--image_tester_patience`: Patience level for image tester. Default is 1.
- `--lm_quantization`: Quantization to use for Mistral language model. Choose between `none` and `4bit`. Default is `none`.
- `--batch_size_prompt`: Batch size for prompt generation. Default is 64.
- `--batch_size_image`: Batch size for image generation. Default is 1.
- `--device`: Choose between `cuda` and `cpu`. Default is cuda.
- `--seed`: Set a random seed for image and prompt generation. Default is 42.
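To make the new option concrete, here is a sketch of a full command line combining the flags documented above. This example is not taken from the repository: the directory, class names and numeric values are illustrative placeholders, and the `--class_names` argument format simply follows the `<objects>` placeholder shown in this README.

```bash
datadreamer \
  --save_dir generated_dataset \
  --class_names aeroplane bicycle boat \
  --prompts_number 16 \
  --batch_size_prompt 64 \
  --batch_size_image 4 \
  --image_tester_patience 1 \
  --device cuda \
  --seed 42
```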

@@ -130,7 +131,7 @@ datadreamer --save_dir <directory> --class_names <objects> --prompts_number <num
| | Simple random generator | Joins randomly chosen object names |
| Image Generation | [SDXL-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) | Slow and accurate (1024x1024 images) |
| | [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo) | Fast and less accurate (512x512 images) |
| Image Annotation | [OWLv2](https://huggingface.co/google/owlv2-large-patch14-ensemble) | Open-Vocabulary object detector |
| Image Annotation | [OWLv2](https://huggingface.co/google/owlv2-base-patch16-ensemble) | Open-Vocabulary object detector |

<a name="example"></a>

datadreamer/image_generation/clip_image_tester.py: 46 changes (46 additions, 0 deletions)
@@ -54,6 +54,52 @@ def test_image(self, image: Image.Image, objects: List[str], conf_threshold=0.05
# Check if all objects meet the confidence threshold
return passed, probs, num_passed

def test_images_batch(
self,
images: List[Image.Image],
objects: List[List[str]],
conf_threshold=0.05,
) -> tuple:
"""Tests a batch of generated images against their corresponding object lists using the CLIP model.
Args:
images (List[Image.Image]): The images to be tested.
objects (List[List[str]]): A list of object lists (text) to test against the corresponding images.
conf_threshold (float, optional): Confidence threshold for considering an object as present. Defaults to 0.05.
Returns:
tuple: Three per-image lists: a boolean flag indicating whether each image passed the test,
the object probabilities for each image, and the number of objects per image that met the confidence threshold.
"""
# Transform the inputs for the CLIP model
objects_array = []
for obj_list in objects:
objects_array.extend(obj_list)
inputs = self.clip_processor(
text=objects_array, images=images, return_tensors="pt", padding=True
).to(self.device)

outputs = self.clip(**inputs)

logits_per_image = outputs.logits_per_image # image-text similarity score
probs = logits_per_image.softmax(dim=1) # label probabilities

# Reshape probs, passed and num_passed so they correspond to the per-image object lists
probs_list = []
passed_list = []
num_passed_list = []

start_pos = 0
for i, obj_list in enumerate(objects):
end_pos = start_pos + len(obj_list)
probs_list.append(probs[i, start_pos:end_pos])
passed_list.append(torch.all(probs_list[-1] > conf_threshold).item())
num_passed_list.append(torch.sum(probs_list[-1] > conf_threshold).item())
start_pos = end_pos

# Return the per-image pass flags, probabilities and pass counts
return passed_list, probs_list, num_passed_list

def release(self, empty_cuda_cache=False) -> None:
"""Releases the model and optionally empties the CUDA cache.
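As a usage sketch (not part of this commit), the batched tester could be driven directly as shown below. The blank PIL images are stand-ins, the object lists reuse the ones from the SDXL example later in this diff, and the positional device argument mirrors how `ImageGenerator` constructs `ClipImageTester`.

```python
from PIL import Image

from datadreamer.image_generation.clip_image_tester import ClipImageTester

tester = ClipImageTester("cuda")  # same positional device argument as in ImageGenerator

# Two placeholder images, each paired with its own list of expected objects.
images = [Image.new("RGB", (512, 512)), Image.new("RGB", (512, 512))]
objects = [["aeroplane", "boat", "bicycle"], ["bicycle"]]

passed_list, probs_list, num_passed_list = tester.test_images_batch(
    images, objects, conf_threshold=0.05
)
for passed, probs, num_passed in zip(passed_list, probs_list, num_passed_list):
    print(passed, probs.tolist(), num_passed)

tester.release(empty_cuda_cache=True)
```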
datadreamer/image_generation/image_generator.py: 49 changes (33 additions, 16 deletions)
@@ -19,6 +19,7 @@ class ImageGenerator:
negative_prompt (str): A string of negative prompts to guide the generation away from certain features.
use_clip_image_tester (bool): Flag to use CLIP model testing for generated images.
image_tester_patience (int): The number of attempts to generate an image that passes CLIP testing.
batch_size (int): The number of images to generate in each batch.
seed (float): Seed for reproducibility.
clip_image_tester (ClipImageTester): Instance of ClipImageTester if use_clip_image_tester is True.
device (str): The device on which the model will run ('cuda' for GPU, 'cpu' for CPU).
@@ -42,6 +43,7 @@ def __init__(
] = "cartoon, blue skin, painting, scrispture, golden, illustration, worst quality, low quality, normal quality:2, unrealistic dream, low resolution, static, sd character, low quality, low resolution, greyscale, monochrome, nose, cropped, lowres, jpeg artifacts, deformed iris, deformed pupils, bad eyes, semi-realistic worst quality, bad lips, deformed mouth, deformed face, deformed fingers, bad anatomy",
use_clip_image_tester: Optional[bool] = False,
image_tester_patience: Optional[int] = 1,
batch_size: Optional[int] = 1,
seed: Optional[float] = 42,
device: str = "cuda",
) -> None:
@@ -52,6 +54,7 @@ def __init__(
self.seed = seed
self.use_clip_image_tester = use_clip_image_tester
self.image_tester_patience = image_tester_patience
self.batch_size = batch_size
self.device = device
if self.use_clip_image_tester:
self.clip_image_tester = ClipImageTester(self.device)
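To illustrate where the new attribute comes from, here is a hypothetical instantiation of the SDXL subclass with the batching knob set. This assumes `StableDiffusionImageGenerator` forwards these keyword arguments to `ImageGenerator.__init__`; the values are arbitrary.

```python
from datadreamer.image_generation.sdxl_image_generator import (
    StableDiffusionImageGenerator,
)

image_generator = StableDiffusionImageGenerator(
    use_clip_image_tester=True,  # filter each batch with CLIP
    image_tester_patience=2,     # regenerate a failing batch up to 2 times
    batch_size=4,                # new in this PR: images generated per call
    seed=42,
    device="cuda",
)
```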
@@ -81,7 +84,7 @@ def generate_images(
prompt_objects (Optional[List[List[str]]]): Optional list of objects for each prompt for CLIP model testing.
Yields:
Image.Image: Generated images.
List[Image.Image]: A batch of generated images.
"""
if isinstance(prompts, str):
prompts = [prompts]
@@ -92,39 +95,53 @@ def generate_images(
if prompt_objects is None:
prompt_objects = [None] * len(prompts)

for prompt, prompt_objs in tqdm(
zip(prompts, prompt_objects), desc="Generating images", total=len(prompts)
):
progress_bar = tqdm(
desc="Generating images", total=len(prompts), dynamic_ncols=True
)

for i in range(0, len(prompts), self.batch_size):
prompts_batch = prompts[i : i + self.batch_size]
prompt_objs_batch = prompt_objects[i : i + self.batch_size]
if self.use_clip_image_tester:
best_prob = 0
best_image = None
best_images_batch = None
best_num_passed = 0
passed = False

for _ in tqdm(range(self.image_tester_patience), desc="Testing image"):
image = self.generate_image(
prompt, self.negative_prompt, prompt_objs
images_batch = self.generate_images_batch(
prompts_batch, self.negative_prompt, prompt_objs_batch
)
passed, probs, num_passed = self.clip_image_tester.test_image(
image, prompt_objs
(
passed_list,
probs_list,
num_passed_list,
) = self.clip_image_tester.test_images_batch(
images_batch, prompt_objs_batch
)
# Return the first image that passes the test
passed = all(passed_list)
mean_prob = sum(
torch.mean(probs).item() for probs in probs_list
) / len(probs_list)
num_passed = sum(num_passed_list)
if passed:
yield image
yield images_batch
break
mean_prob = probs.mean().item()
if num_passed > best_num_passed or (
num_passed == best_num_passed and mean_prob > best_prob
):
best_image = image
best_images_batch = images_batch
best_prob = mean_prob
best_num_passed = num_passed
# If not every image passed, yield the batch with the highest number of objects that passed the test
if not passed:
yield best_image

yield best_images_batch
else:
yield self.generate_image(prompt, self.negative_prompt, prompt_objs)
yield self.generate_images_batch(
prompts_batch, self.negative_prompt, prompt_objs_batch
)

progress_bar.update(len(prompts_batch))

@abstractmethod
def release(self, empty_cuda_cache=False) -> None:
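The selection policy in the CLIP-testing branch above can be summarised on its own: retry up to `image_tester_patience` times, stop early if every image in the batch passes, otherwise keep the batch with the most passed objects (ties broken by mean probability). Below is a minimal, model-free sketch of that rule, where `generate_and_score` is a hypothetical stand-in for the SDXL and CLIP calls and `probs_list` is simplified to per-image mean probabilities.

```python
from typing import Callable, List, Tuple


def pick_best_batch(
    generate_and_score: Callable[[], Tuple[list, List[bool], List[float], List[int]]],
    patience: int,
):
    """Simplified mirror of the retry loop in ImageGenerator.generate_images."""
    best_batch, best_num_passed, best_prob = None, 0, 0.0
    for _ in range(patience):
        batch, passed_list, probs_list, num_passed_list = generate_and_score()
        if all(passed_list):
            # Every image cleared the confidence threshold: stop early.
            return batch
        mean_prob = sum(probs_list) / len(probs_list)
        num_passed = sum(num_passed_list)
        # Prefer more passed objects; break ties with the higher mean probability.
        if num_passed > best_num_passed or (
            num_passed == best_num_passed and mean_prob > best_prob
        ):
            best_batch, best_num_passed, best_prob = batch, num_passed, mean_prob
    return best_batch
```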
datadreamer/image_generation/sdxl_image_generator.py: 62 changes (35 additions, 27 deletions)
@@ -3,7 +3,6 @@
import torch
from compel import Compel, ReturnedEmbeddingsType
from diffusers import DiffusionPipeline
from PIL import Image

from datadreamer.image_generation.image_generator import ImageGenerator

@@ -21,7 +20,7 @@ class StableDiffusionImageGenerator(ImageGenerator):
Methods:
_init_gen_model(): Initializes the generative models for image generation.
_init_processor(): Initializes the processors for the models.
generate_image(prompt, negative_prompt, prompt_objects): Generates an image based on the provided prompt.
generate_images_batch(prompts, negative_prompt, prompt_objects): Generates a batch of images based on the provided prompts.
release(empty_cuda_cache): Releases resources and optionally empties the CUDA cache.
"""

@@ -38,6 +37,7 @@ def _init_gen_model(self):
tuple: The base and refiner models.
"""
if self.device == "cpu":
print("Loading SDXL on CPU...")
base = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
# variant="fp16",
@@ -55,6 +55,7 @@
)
refiner.to("cpu")
else:
print("Loading SDXL on GPU...")
base = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
@@ -94,35 +95,38 @@ def _init_processor(self):
)
return compel, compel_refiner

def generate_image(
def generate_images_batch(
self,
prompt: str,
prompts: List[str],
negative_prompt: str,
prompt_objects: Optional[List[str]] = None,
) -> Image.Image:
"""Generates an image based on the provided prompt, using Stable Diffusion
models.
prompt_objects: Optional[List[List[str]]] = None,
):
"""Generates a batch of images based on the provided prompts.
Args:
prompt (str): The positive prompt to guide image generation.
prompts (List[str]): A list of positive prompts to guide image generation.
negative_prompt (str): The negative prompt to avoid certain features in the image.
prompt_objects (Optional[List[str]]): Optional list of objects to be used in CLIP model testing.
prompt_objects (Optional[List[List[str]]]): Optional list of objects to be used in CLIP model testing.
Returns:
Image.Image: The generated image.
List[Image.Image]: A list of generated images.
"""
if prompt_objects is not None:
for obj in prompt_objects:
prompt = prompt.replace(obj, f"({obj})1.5", 1)
for i in range(len(prompt_objects)):
for obj in prompt_objects[i]:
prompts[i] = prompts[i].replace(obj, f"({obj})1.5", 1)

conditioning, pooled = self.base_processor(prompt)
conditioning_neg, pooled_neg = self.base_processor(negative_prompt)
conditioning, pooled = self.base_processor(prompts)
conditioning_neg, pooled_neg = self.base_processor(
[negative_prompt] * len(prompts)
)

conditioning_refiner, pooled_refiner = self.refiner_processor(prompt)
conditioning_refiner, pooled_refiner = self.refiner_processor(prompts)
negative_conditioning_refiner, negative_pooled_refiner = self.refiner_processor(
negative_prompt
[negative_prompt] * len(prompts)
)
image = self.base(

images = self.base(
prompt_embeds=conditioning,
pooled_prompt_embeds=pooled,
negative_prompt_embeds=conditioning_neg,
@@ -131,17 +135,18 @@ def generate_image(
denoising_end=0.78,
output_type="latent",
).images
image = self.refiner(

images = self.refiner(
prompt_embeds=conditioning_refiner,
pooled_prompt_embeds=pooled_refiner,
negative_prompt_embeds=negative_conditioning_refiner,
negative_pooled_prompt_embeds=negative_pooled_refiner,
num_inference_steps=65,
denoising_start=0.78,
image=image,
).images[0]
image=images,
).images

return image
return images

def release(self, empty_cuda_cache=False) -> None:
"""Releases the models and optionally empties the CUDA cache."""
@@ -171,11 +176,14 @@ def release(self, empty_cuda_cache=False) -> None:
prompt_objects = [["aeroplane", "boat", "bicycle"], ["bicycle"]]

image_paths = []
for i, generated_image in enumerate(
image_generator.generate_images(prompts, prompt_objects)
counter = 0
for generated_images_batch in image_generator.generate_images(
prompts, prompt_objects
):
image_path = os.path.join("./", f"image_{i}.jpg")
generated_image.save(image_path)
image_paths.append(image_path)
for generated_image in generated_images_batch:
image_path = os.path.join("./", f"image_{counter}.jpg")
generated_image.save(image_path)
image_paths.append(image_path)
counter += 1

image_generator.release(empty_cuda_cache=True)
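One detail of `generate_images_batch` worth calling out: before encoding, the first occurrence of each prompt object is wrapped in compel's `(word)1.5` attention-weighting syntax. Here is a small standalone illustration of that string rewrite (plain Python, no models involved; the prompts are made up for the example).

```python
prompts = [
    "A photo of an aeroplane and a boat next to a bicycle",
    "A bicycle leaning against a wall",
]
prompt_objects = [["aeroplane", "boat", "bicycle"], ["bicycle"]]

for i in range(len(prompt_objects)):
    for obj in prompt_objects[i]:
        # Same replacement as in generate_images_batch: boost the first
        # occurrence of each object name with a 1.5 weight for compel.
        prompts[i] = prompts[i].replace(obj, f"({obj})1.5", 1)

print(prompts[0])
# -> A photo of an (aeroplane)1.5 and a (boat)1.5 next to a (bicycle)1.5
```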