Feature/sdxl lightning #36

Merged: 6 commits, Feb 27, 2024
3 changes: 2 additions & 1 deletion README.md
@@ -107,7 +107,7 @@ datadreamer --save_dir <directory> --class_names <objects> --prompts_number <num
- `--task`: Choose between `detection` and `classification`. Default is `detection`.
- `--num_objects_range`: Range of objects in a prompt. Default is 1 to 3.
- `--prompt_generator`: Choose between `simple`, `lm` (language model) and `tiny` (tiny LM). Default is `simple`.
- `--image_generator`: Choose image generator, e.g., `sdxl` or `sdxl-turbo`. Default is `sdxl-turbo`.
- `--image_generator`: Choose image generator, e.g., `sdxl`, `sdxl-turbo` or `sdxl-lightning`. Default is `sdxl-turbo`.
- `--image_annotator`: Specify the image annotator, like `owlv2`. Default is `owlv2`.
- `--conf_threshold`: Confidence threshold for object detection. Default is 0.15.
- `--use_tta`: Toggle test time augmentation for object detection. Default is True.
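For example (illustrative values, not taken from the diff), a detection run with the new backend could be launched as `datadreamer --save_dir gen_data/ --class_names person bicycle car --prompts_number 10 --image_generator sdxl-lightning --device cuda`.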
@@ -132,6 +132,7 @@ datadreamer --save_dir <directory> --class_names <objects> --prompts_number <num
| | Simple random generator | Joins randomly chosen object names |
| Image Generation | [SDXL-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) | Slow and accurate (1024x1024 images) |
| | [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo) | Fast and less accurate (512x512 images) |
| | [SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning) | Fast and accurate (1024x1024 images) |
| Image Annotation | [OWLv2](https://huggingface.co/google/owlv2-base-patch16-ensemble) | Open-Vocabulary object detector |

<a name="example"></a>
7 changes: 6 additions & 1 deletion datadreamer/image_generation/__init__.py
@@ -1,4 +1,9 @@
from .sdxl_image_generator import StableDiffusionImageGenerator
from .sdxl_lightning_image_generator import StableDiffusionLightningImageGenerator
from .sdxl_turbo_image_generator import StableDiffusionTurboImageGenerator

__all__ = ["StableDiffusionImageGenerator", "StableDiffusionTurboImageGenerator"]
__all__ = [
"StableDiffusionImageGenerator",
"StableDiffusionTurboImageGenerator",
"StableDiffusionLightningImageGenerator",
]
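With this export in place, the Lightning generator can be used straight from the package namespace. A minimal usage sketch (not part of the diff; the constructor arguments mirror the module's own `__main__` demo further down and are illustrative, not exhaustive):

```python
from datadreamer.image_generation import StableDiffusionLightningImageGenerator

# Build the Lightning-based generator (arguments as in the demo included in this PR).
generator = StableDiffusionLightningImageGenerator(
    seed=42,
    use_clip_image_tester=False,
    image_tester_patience=1,
    batch_size=2,
    device="cuda",
)

# generate_images yields batches of PIL images for the given prompts and objects.
for batch in generator.generate_images(
    ["A photo of a dog walking in the park."], [["dog"]]
):
    for image in batch:
        image.save("dog.jpg")

generator.release(empty_cuda_cache=True)
```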
2 changes: 1 addition & 1 deletion datadreamer/image_generation/sdxl_image_generator.py
@@ -173,7 +173,7 @@ def release(self, empty_cuda_cache=False) -> None:
"A photo of a bicycle pedaling alongside an aeroplane taking off, showcasing the harmony between human-powered and mechanical transportation.",
"A photo of bicycles along a scenic mountain path, where the riders seem to have taken a moment to appreciate the stunning views.",
]
prompt_objects = [["aeroplane", "boat", "bicycle"], ["bicycle"]]
prompt_objects = [["aeroplane", "bicycle"], ["bicycle"]]

image_paths = []
counter = 0
173 changes: 173 additions & 0 deletions datadreamer/image_generation/sdxl_lightning_image_generator.py
@@ -0,0 +1,173 @@
from typing import List, Optional

import torch
from compel import Compel, ReturnedEmbeddingsType
from diffusers import (
    EulerDiscreteScheduler,
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
)
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file

from datadreamer.image_generation.image_generator import ImageGenerator


class StableDiffusionLightningImageGenerator(ImageGenerator):
    """A subclass of ImageGenerator specifically designed to use the Stable Diffusion
    Lightning model for faster image generation.

    Attributes:
        pipe (StableDiffusionXLPipeline): The Stable Diffusion Lightning model for image generation.

    Methods:
        _init_gen_model(): Initializes the Stable Diffusion Lightning model.
        _init_compel(): Initializes the Compel model for text prompt weighting.
        generate_images_batch(prompts, negative_prompt, prompt_objects): Generates a batch of images based on the provided prompts.
        release(empty_cuda_cache): Releases resources and optionally empties the CUDA cache.
    """

    def __init__(self, *args, **kwargs):
        """Initializes the StableDiffusionLightningImageGenerator with the given
        arguments."""
        super().__init__(*args, **kwargs)
        self.pipe = self._init_gen_model()
        self.compel = self._init_compel()

    def _init_gen_model(self):
        """Initializes the Stable Diffusion Lightning model for image generation.

        Returns:
            StableDiffusionXLPipeline: The initialized Stable Diffusion Lightning model.
        """
        base = "stabilityai/stable-diffusion-xl-base-1.0"
        repo = "ByteDance/SDXL-Lightning"
        ckpt = "sdxl_lightning_4step_unet.safetensors"  # Use the correct ckpt for your step setting!

        # Load model.
        if self.device == "cpu":
            print("Loading SDXL Lightning on CPU...")
            unet = UNet2DConditionModel.from_config(base, subfolder="unet")
            unet.load_state_dict(load_file(hf_hub_download(repo, ckpt)))
            pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet)
        else:
            print("Loading SDXL Lightning on GPU...")
            unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(
                self.device, torch.float16
            )
            unet.load_state_dict(
                load_file(hf_hub_download(repo, ckpt), device=self.device)
            )
            pipe = StableDiffusionXLPipeline.from_pretrained(
                base, unet=unet, torch_dtype=torch.float16, variant="fp16"
            ).to(self.device)
            pipe.enable_model_cpu_offload()

        # Ensure sampler uses "trailing" timesteps.
        pipe.scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config, timestep_spacing="trailing"
        )

        return pipe

    def _init_compel(self):
        """Initializes the Compel model for text prompt weighting.

        Returns:
            Compel: The initialized Compel model.
        """
        compel = Compel(
            tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2],
            text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2],
            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
            requires_pooled=[False, True],
        )
        return compel

    def generate_images_batch(
        self,
        prompts: List[str],
        negative_prompt: str,
        prompt_objects: Optional[List[List[str]]] = None,
        batch_size: int = 1,
    ) -> List[Image.Image]:
        """Generates a batch of images using the Stable Diffusion Lightning model based
        on the provided prompts.

        Args:
            prompts (List[str]): A list of positive prompts to guide image generation.
            negative_prompt (str): The negative prompt to avoid certain features in the image.
            prompt_objects (Optional[List[List[str]]]): Optional list of objects for each prompt for CLIP model testing.
            batch_size (int): The number of images to generate in each batch.

        Returns:
            List[Image.Image]: A list of generated images.
        """

        if prompt_objects is not None:
            for i in range(len(prompt_objects)):
                for obj in prompt_objects[i]:
                    prompts[i] = prompts[i].replace(obj, f"({obj})1.5", 1)

        conditioning, pooled = self.compel(prompts)
        conditioning_neg, pooled_neg = self.compel([negative_prompt] * len(prompts))
        images = self.pipe(
            prompt_embeds=conditioning,
            pooled_prompt_embeds=pooled,
            negative_prompt_embeds=conditioning_neg,
            negative_pooled_prompt_embeds=pooled_neg,
            guidance_scale=0.0,
            num_inference_steps=4,
        ).images

        return images

    def release(self, empty_cuda_cache=False) -> None:
        """Releases the model and optionally empties the CUDA cache."""
        self.pipe = self.pipe.to("cpu")
        if self.use_clip_image_tester:
            self.clip_image_tester.release()
        if empty_cuda_cache:
            with torch.no_grad():
                torch.cuda.empty_cache()


if __name__ == "__main__":
    import os

    # Create the generator
    image_generator = StableDiffusionLightningImageGenerator(
        seed=42,
        use_clip_image_tester=False,
        image_tester_patience=1,
        batch_size=4,
        device="cpu",
    )
    prompts = [
        "A photo of a bicycle pedaling alongside an aeroplane.",
        "A photo of a dragonfly flying in the sky.",
        "A photo of a dog walking in the park.",
        "A photo of an alien exploring the galaxy.",
        "A photo of a robot working on a computer.",
    ]
    prompt_objects = [
        ["aeroplane", "bicycle"],
        ["dragonfly"],
        ["dog"],
        ["alien"],
        ["robot", "computer"],
    ]

    image_paths = []
    counter = 0
    for generated_images_batch in image_generator.generate_images(
        prompts, prompt_objects
    ):
        for generated_image in generated_images_batch:
            image_path = os.path.join("./", f"image_lightning_{counter}.jpg")
            generated_image.save(image_path)
            image_paths.append(image_path)
            counter += 1

    image_generator.release(empty_cuda_cache=True)
2 changes: 1 addition & 1 deletion datadreamer/image_generation/sdxl_turbo_image_generator.py
@@ -105,7 +105,7 @@ def release(self, empty_cuda_cache=False) -> None:
    prompts = [
        "A photo of a bicycle pedaling alongside an aeroplane taking off, showcasing the harmony between human-powered and mechanical transportation.",
    ] * 16
    prompt_objects = [["aeroplane", "boat", "bicycle"]] * 16
    prompt_objects = [["aeroplane", "bicycle"]] * 16

    image_paths = []
    counter = 0
4 changes: 3 additions & 1 deletion datadreamer/pipelines/generate_dataset_from_scratch.py
@@ -12,6 +12,7 @@
from datadreamer.dataset_annotation import OWLv2Annotator
from datadreamer.image_generation import (
    StableDiffusionImageGenerator,
    StableDiffusionLightningImageGenerator,
    StableDiffusionTurboImageGenerator,
)
from datadreamer.prompt_generation import (
@@ -30,6 +31,7 @@
image_generators = {
"sdxl": StableDiffusionImageGenerator,
"sdxl-turbo": StableDiffusionTurboImageGenerator,
"sdxl-lightning": StableDiffusionLightningImageGenerator,
}

annotators = {"owlv2": OWLv2Annotator}
@@ -84,7 +86,7 @@ def parse_args():
"--image_generator",
type=str,
default="sdxl-turbo",
choices=["sdxl", "sdxl-turbo"],
choices=["sdxl", "sdxl-turbo", "sdxl-lightning"],
help="Image generator to use",
)
parser.add_argument(
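The `image_generators` dictionary added near the top of this file is what ties the new `sdxl-lightning` CLI choice to the class exported by the package. A rough sketch of the dispatch, assuming the parsed arguments follow the flags shown in this diff (the actual wiring in the rest of `generate_dataset_from_scratch.py` is not shown here):

```python
# Hypothetical illustration only; argument names are assumptions based on the CLI flags.
args = parse_args()
generator_class = image_generators[args.image_generator]  # e.g. "sdxl-lightning"
image_generator = generator_class(device=args.device)
```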
4 changes: 2 additions & 2 deletions media/coverage_badge.svg
44 changes: 44 additions & 0 deletions tests/integration/test_pipeline.py
@@ -318,6 +318,50 @@ def test_cuda_simple_sdxl_detection_pipeline():
    _check_detection_pipeline(cmd, target_folder)


@pytest.mark.skipif(
    total_memory < 16 or total_disk_space < 35,
    reason="Test requires at least 16GB of RAM and 35GB of HDD",
)
def test_cpu_simple_sdxl_lightning_detection_pipeline():
    # Define target folder
    target_folder = "data/data-det-cpu-simple-sdxl-lightning/"
    # Define the command to run the datadreamer
    cmd = (
        f"datadreamer --save_dir {target_folder} "
        f"--class_names alien mars cat "
        f"--prompts_number 1 "
        f"--prompt_generator simple "
        f"--num_objects_range 1 2 "
        f"--image_generator sdxl-lightning "
        f"--use_image_tester "
        f"--device cpu"
    )
    # Check the run of the pipeline
    _check_detection_pipeline(cmd, target_folder)


@pytest.mark.skipif(
    not torch.cuda.is_available() or total_memory < 16 or total_disk_space < 35,
    reason="Test requires GPU, at least 16GB of RAM and 35GB of HDD",
)
def test_cuda_simple_sdxl_lightning_detection_pipeline():
    # Define target folder
    target_folder = "data/data-det-cuda-simple-sdxl-lightning/"
    # Define the command to run the datadreamer
    cmd = (
        f"datadreamer --save_dir {target_folder} "
        f"--class_names alien mars cat "
        f"--prompts_number 1 "
        f"--prompt_generator simple "
        f"--num_objects_range 1 2 "
        f"--image_generator sdxl-lightning "
        f"--use_image_tester "
        f"--device cuda"
    )
    # Check the run of the pipeline
    _check_detection_pipeline(cmd, target_folder)


# =========================================================
# DETECTION - LLM
# =========================================================
29 changes: 24 additions & 5 deletions tests/unittests/test_image_generation.py
@@ -6,13 +6,12 @@
import torch
from PIL import Image

from datadreamer.image_generation.clip_image_tester import ClipImageTester
from datadreamer.image_generation.sdxl_image_generator import (
from datadreamer.image_generation import (
    StableDiffusionImageGenerator,
)
from datadreamer.image_generation.sdxl_turbo_image_generator import (
    StableDiffusionLightningImageGenerator,
    StableDiffusionTurboImageGenerator,
)
from datadreamer.image_generation.clip_image_tester import ClipImageTester

# Get the total memory in GB
total_memory = psutil.virtual_memory().total / (1024**3)
@@ -55,7 +54,11 @@ def test_cpu_clip_image_tester():

def _check_image_generator(
    image_generator_class: Type[
        Union[StableDiffusionImageGenerator, StableDiffusionTurboImageGenerator]
        Union[
            StableDiffusionImageGenerator,
            StableDiffusionTurboImageGenerator,
            StableDiffusionLightningImageGenerator,
        ]
    ],
    device: str,
):
@@ -101,3 +104,19 @@ def test_cuda_sdxl_turbo_image_generator():
)
def test_cpu_sdxl_turbo_image_generator():
    _check_image_generator(StableDiffusionTurboImageGenerator, "cpu")


@pytest.mark.skipif(
    not torch.cuda.is_available() or total_memory < 16 or total_disk_space < 25,
    reason="Test requires GPU, at least 16GB of RAM and 25GB of HDD",
)
def test_cuda_sdxl_lightning_image_generator():
    _check_image_generator(StableDiffusionLightningImageGenerator, "cuda")


@pytest.mark.skipif(
    total_memory < 16 or total_disk_space < 25,
    reason="Test requires at least 16GB of RAM and 25GB of HDD",
)
def test_cpu_sdxl_lightning_image_generator():
    _check_image_generator(StableDiffusionLightningImageGenerator, "cpu")