diff --git a/.github/workflows/autogenerate_requirements.yaml b/.github/workflows/autogenerate_requirements.yaml
new file mode 100644
index 0000000..f56e483
--- /dev/null
+++ b/.github/workflows/autogenerate_requirements.yaml
@@ -0,0 +1,43 @@
+name: Autogenerate Requirements
+
+on:
+ pull_request:
+ branches: [dev, main]
+ paths:
+ - 'pyproject.toml'
+ - 'tools/autogenerate_requirements.py'
+ - '.github/workflows/autogenerate_requirements.yaml'
+
+jobs:
+ update-requirements:
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Check out repository
+ uses: actions/checkout@v3
+ with:
+ ref: ${{ github.head_ref }}
+
+ - name: Set up Python
+ uses: actions/setup-python@v3
+ with:
+ python-version: '3.10'
+
+ - name: Run autogeneration script
+ run: |
+ pip install toml
+ python tools/autogenerate_requirements.py
+
+ - name: Commit files
+ run: |
+ git config --global user.name 'GitHub Actions'
+ git config --global user.email 'actions@github.com'
+ git diff --quiet requirements.txt || {
+ git add requirements.txt
+ git commit -m "[Automated] Updated requirements.txt"
+ }
+
+ - name: Push changes
+ uses: ad-m/github-push-action@master
+ with:
+ branch: ${{ github.head_ref }}
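
The job above regenerates `requirements.txt` and commits it only when `git diff --quiet` reports a change. A minimal local sketch of the same check (assuming `toml` is installed and the commands are run from the repository root):

```python
# Mirror of the CI steps: regenerate requirements.txt, then check whether the
# tracked copy changed. Assumes `pip install toml` and a repository-root cwd.
import subprocess

subprocess.run(["python", "tools/autogenerate_requirements.py"], check=True)
result = subprocess.run(["git", "diff", "--quiet", "requirements.txt"])
if result.returncode == 0:
    print("requirements.txt is already up to date")
else:
    print("requirements.txt changed; commit it (CI would do this automatically)")
```
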
diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml
new file mode 100644
index 0000000..24d1d43
--- /dev/null
+++ b/.github/workflows/pre-commit.yaml
@@ -0,0 +1,15 @@
+name: pre-commit
+
+on:
+ pull_request:
+ branches: [dev, main]
+
+jobs:
+ pre-commit:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ - uses: actions/setup-python@v3
+ with:
+ python-version: '3.8'
+ - uses: pre-commit/action@v3.0.0
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 51b4e5c..fd0d203 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,11 +1,34 @@
repos:
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.1.2
+ hooks:
+ - id: ruff
+ args: [--fix, --exit-non-zero-on-fix]
+ types_or: [python, pyi]
+
- repo: https://github.com/ambv/black
rev: 23.3.0
hooks:
- id: black
- language_version: python3
+ language_version: python3.8
+
+ - repo: https://github.com/PyCQA/docformatter
+ rev: v1.7.5
+ hooks:
+ - id: docformatter
+ additional_dependencies: [tomli]
+ args: [--in-place]
+
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: no-commit-to-branch
- args: ['--branch', 'main', '--branch', 'dev']
\ No newline at end of file
+ args: ['--branch', 'main', '--branch', 'dev']
+
+ - repo: https://github.com/executablebooks/mdformat
+ rev: 0.7.10
+ hooks:
+ - id: mdformat
+ additional_dependencies:
+ - mdformat-gfm
+ - mdformat-toc
diff --git a/README.md b/README.md
index 3339db8..6c8c9a3 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,10 @@
-![DataDreamer examples](images/grid_image_3x2_generated_dataset.jpg)
+# DataDreamer
+[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/luxonis/datadreamer/blob/main/examples/generate_dataset_and_train_yolo.ipynb)
-# DataDreamer
+![DataDreamer examples](images/grid_image_3x2_generated_dataset.jpg)
+
`DataDreamer` is an advanced toolkit engineered to facilitate the development of edge AI models, irrespective of initial data availability. Distinctive features of DataDreamer include:
- **Synthetic Data Generation**: Eliminate the dependency on extensive datasets for AI training. DataDreamer empowers users to generate synthetic datasets from the ground up, utilizing advanced AI algorithms capable of producing high-quality, diverse images.
@@ -18,11 +20,28 @@ The rationale behind the name `DataDreamer`:
In essence, `DataDreamer` is designed to transform the AI development process, making it more accessible, efficient, and effective, turning visionary ideas into reality.
+## Table of Contents
+
+- [Features](#features)
+- [Installation](#installation)
+- [Hardware Requirements](#hardware-requirements)
+- [Usage](#usage)
+ - [Main Parameters](#main-parameters)
+ - [Additional Parameters](#additional-parameters)
+ - [Example](#example)
+ - [Output](#output)
+ - [Annotations Format](#annotations-format)
+ - [Note](#note)
+- [Limitations](#limitations)
+- [License](#license)
+- [Acknowledgements](#acknowledgements)
+
## Features
- **Prompt Generation**: Automate the creation of image prompts using powerful language models.
- *Provided class names: ["horse", "robot"]* -> *Generated prompt: "A photo of a horse and a robot coexisting peacefully in the midst of a serene pasture."*
+ *Provided class names: \["horse", "robot"\]* -> *Generated prompt: "A photo of a horse and a robot coexisting peacefully in the midst of a serene pasture."*
+
- **Image Generation**: Generate synthetic datasets with state-of-the-art generative models.
@@ -31,7 +50,6 @@ In essence, `DataDreamer` is designed to transform the AI development process, m
-
- **Edge Model Training**: Train efficient small-scale neural networks for edge deployment. (not part of this library)
[Example](https://github.com/luxonis/datadreamer/blob/main/examples/generate_dataset_and_train_yolo.ipynb)
@@ -50,6 +68,7 @@ pip install -e .
```
## Hardware Requirements
+
To ensure optimal performance and compatibility with the libraries used in this project, the following hardware specifications are recommended:
- `GPU`: A CUDA-compatible GPU with a minimum of 16 GB memory. This is essential for libraries like `torch`, `torchvision`, `transformers`, and `diffusers`, which leverage CUDA for accelerated computing in machine learning and image processing tasks.
@@ -57,10 +76,8 @@ To ensure optimal performance and compatibility with the libraries used in this
## Usage
-### Overview
The `datadreamer/pipelines/generate_dataset_from_scratch.py` (`datadreamer` command) script is a powerful tool for generating and annotating images with specific objects. It uses advanced models to both create images and accurately annotate them with bounding boxes for designated objects.
-### Usage
Run the following command in your terminal to use the script:
```bash
 datadreamer --save_dir <save_dir> --class_names <class_names> --prompts_number <prompts_number>
 ```
diff --git a/pyproject.toml b/pyproject.toml
--- a/pyproject.toml
+++ b/pyproject.toml
 requires-python = ">=3.8"
license = { file = "LICENSE" }
+maintainers = [{ name = "Luxonis", email = "support@luxonis.com"}]
+keywords = ["computer vision", "AI", "machine learning", "generative models"]
classifiers = [
+ "License :: Apache License 2.0",
"Development Status :: 3 - Alpha",
+ "Programming Language :: Python :: 3.8",
"Intended Audience :: Developers",
- "Programming Language :: Python :: 3",
- "Topic :: Software Development :: Libraries"
+ "Topic :: Software Development :: Libraries",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ "Topic :: Scientific/Engineering :: Image Processing",
+ "Topic :: Scientific/Engineering :: Image Recognition",
]
-keywords = ["computer vision", "AI", "machine learning", "generative models"]
dependencies = [
- "torch",
- "torchvision",
- "transformers",
- "diffusers",
- "compel",
- "tqdm",
- "Pillow",
- "numpy",
- "matplotlib",
- "opencv-python",
- "accelerate",
- "scipy"
+ "torch>=2.0.0",
+ "torchvision>=0.16.0",
+ "transformers>=4.0.0",
+ "diffusers>=0.24.0",
+ "compel>=2.0.0",
+ "tqdm>=4.0.0",
+ "Pillow>=10.0.0",
+ "numpy>=1.22.0",
+ "matplotlib>=3.6.0",
+ "opencv-python>=4.7.0",
+ "accelerate>=0.25.0",
+ "scipy>=1.10.0",
+]
+[project.optional-dependencies]
+dev = [
+ "datadreamer",
+ "pre-commit>=3.2.1",
+ "toml>=0.10.2",
]
-requires-python = ">=3.8"
[project.urls]
Homepage = "https://github.com/luxonis/datadreamer"
+[project.scripts]
+datadreamer = "datadreamer.pipelines.generate_dataset_from_scratch:main"
+
[tool.setuptools.packages.find]
where = ["src"]
-[project.scripts]
-datadreamer = "datadreamer.pipelines.generate_dataset_from_scratch:main"
\ No newline at end of file
+[tool.ruff]
+target-version = "py38"
+
+[tool.ruff.lint]
+ignore = ["F403", "B028", "B905", "D1"]
+select = ["E4", "E7", "E9", "F", "W", "B", "I"]
+
+[tool.ruff.pydocstyle]
+convention = "google"
+
+[tool.docformatter]
+black = true
+
+[tool.mypy]
+python_version = "3.8"
+ignore_missing_imports = true
+
+[tool.pyright]
+typeCheckingMode = "basic"
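
The `[project.scripts]` table maps the `datadreamer` console command to `generate_dataset_from_scratch:main`. A sketch of the programmatic equivalent, assuming the package is installed (e.g. `pip install -e .`):

```python
# Programmatic equivalent of the `datadreamer` console script declared in
# [project.scripts]; not part of this patch.
from datadreamer.pipelines.generate_dataset_from_scratch import main

if __name__ == "__main__":
    main()  # parses the same CLI arguments as the `datadreamer` command
```
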
diff --git a/requirements.txt b/requirements.txt
index 8e2b0ee..16a0e9c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,12 +1,16 @@
-torch
-torchvision
-transformers
-diffusers
-compel
-tqdm
-Pillow
-numpy
-matplotlib
-opencv-python
-accelerate
-scipy
\ No newline at end of file
+torch>=2.0.0
+torchvision>=0.16.0
+transformers>=4.0.0
+diffusers>=0.24.0
+compel>=2.0.0
+tqdm>=4.0.0
+Pillow>=10.0.0
+numpy>=1.22.0
+matplotlib>=3.6.0
+opencv-python>=4.7.0
+accelerate>=0.25.0
+scipy>=1.10.0
+
+# dev
+pre-commit>=3.2.1
+toml>=0.10.2
diff --git a/src/datadreamer/dataset_annotation/__init__.py b/src/datadreamer/dataset_annotation/__init__.py
index d3c82eb..0789761 100644
--- a/src/datadreamer/dataset_annotation/__init__.py
+++ b/src/datadreamer/dataset_annotation/__init__.py
@@ -1,3 +1,5 @@
-from .owlv2_annotator import OWLv2Annotator
from .image_annotator import BaseAnnotator, TaskList
from .kosmos2_annotator import Kosmos2Annotator
+from .owlv2_annotator import OWLv2Annotator
+
+__all__ = ["BaseAnnotator", "TaskList", "OWLv2Annotator", "Kosmos2Annotator"]
diff --git a/src/datadreamer/dataset_annotation/image_annotator.py b/src/datadreamer/dataset_annotation/image_annotator.py
index 6bb56ae..8ba4935 100644
--- a/src/datadreamer/dataset_annotation/image_annotator.py
+++ b/src/datadreamer/dataset_annotation/image_annotator.py
@@ -1,6 +1,5 @@
-from abc import ABC, abstractmethod
-from typing import List, Generic, TypeVar
import enum
+from abc import ABC, abstractmethod
# Enum for different labeling tasks
@@ -13,8 +12,7 @@ class TaskList(enum.Enum):
# Abstract base class for data labeling
class BaseAnnotator(ABC):
- """
- Abstract base class for creating annotators.
+ """Abstract base class for creating annotators.
Attributes:
seed (float): A seed value to ensure reproducibility in annotation processes.
diff --git a/src/datadreamer/dataset_annotation/kosmos2_annotator.py b/src/datadreamer/dataset_annotation/kosmos2_annotator.py
index 4469986..0479060 100644
--- a/src/datadreamer/dataset_annotation/kosmos2_annotator.py
+++ b/src/datadreamer/dataset_annotation/kosmos2_annotator.py
@@ -1,14 +1,13 @@
-import torch
-import torchvision.ops as ops
import numpy as np
+import torch
+from transformers import AutoProcessor, Kosmos2ForConditionalGeneration
from datadreamer.dataset_annotation.image_annotator import BaseAnnotator
-from transformers import AutoProcessor, Kosmos2ForConditionalGeneration
class Kosmos2Annotator(BaseAnnotator):
- """
- An image annotator class that utilizes the Kosmos2 model for conditional image generation.
+ """An image annotator class that utilizes the Kosmos2 model for conditional image
+ generation.
Attributes:
model (Kosmos2ForConditionalGeneration): The Kosmos2 model for conditional image generation.
@@ -27,8 +26,7 @@ def __init__(
seed: float,
device: str = "cuda",
) -> None:
- """
- Initializes the Kosmos2Annotator with a given seed and device.
+ """Initializes the Kosmos2Annotator with a given seed and device.
Args:
seed (float): Seed for reproducibility.
@@ -41,8 +39,7 @@ def __init__(
self.model.to(self.device)
def _init_model(self):
- """
- Initializes the Kosmos2 model.
+ """Initializes the Kosmos2 model.
Returns:
Kosmos2ForConditionalGeneration: The initialized Kosmos2 model.
@@ -52,8 +49,7 @@ def _init_model(self):
)
def _init_processor(self):
- """
- Initializes the processor for the Kosmos2 model.
+ """Initializes the processor for the Kosmos2 model.
Returns:
AutoProcessor: The initialized processor.
@@ -61,8 +57,7 @@ def _init_processor(self):
return AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")
def annotate(self, image, prompts, conf_threshold=0.1, use_tta=False):
- """
- Annotates an image using the Kosmos2 model.
+ """Annotates an image using the Kosmos2 model.
Args:
image: The image to be annotated.
@@ -93,9 +88,6 @@ def annotate(self, image, prompts, conf_threshold=0.1, use_tta=False):
generated_text = self.processor.batch_decode(
generated_ids, skip_special_tokens=True
)[0]
- processed_text = self.processor.post_process_generation(
- generated_text, cleanup_and_extract=False
- )
caption, entities = self.processor.post_process_generation(generated_text)
@@ -122,8 +114,7 @@ def annotate(self, image, prompts, conf_threshold=0.1, use_tta=False):
return final_boxes, final_scores, final_labels
def release(self, empty_cuda_cache=False) -> None:
- """
- Releases the model and optionally empties the CUDA cache.
+ """Releases the model and optionally empties the CUDA cache.
Args:
empty_cuda_cache (bool, optional): Whether to empty the CUDA cache. Defaults to False.
diff --git a/src/datadreamer/dataset_annotation/owlv2_annotator.py b/src/datadreamer/dataset_annotation/owlv2_annotator.py
index 470abbf..f6d571e 100644
--- a/src/datadreamer/dataset_annotation/owlv2_annotator.py
+++ b/src/datadreamer/dataset_annotation/owlv2_annotator.py
@@ -1,16 +1,14 @@
import torch
-import torchvision.ops as ops
+from transformers import Owlv2ForObjectDetection, Owlv2Processor
from datadreamer.dataset_annotation.image_annotator import BaseAnnotator
-from transformers import Owlv2Processor, Owlv2ForObjectDetection
from datadreamer.dataset_annotation.utils import apply_tta
-
from datadreamer.utils.nms import non_max_suppression
class OWLv2Annotator(BaseAnnotator):
- """
- A class for image annotation using the OWLv2 model, specializing in object detection.
+ """A class for image annotation using the OWLv2 model, specializing in object
+ detection.
Attributes:
model (Owlv2ForObjectDetection): The OWLv2 model for object detection.
@@ -29,8 +27,7 @@ def __init__(
seed: float = 42,
device: str = "cuda",
) -> None:
- """
- Initializes the OWLv2Annotator with a specific seed and device.
+ """Initializes the OWLv2Annotator with a specific seed and device.
Args:
seed (float): Seed for reproducibility. Defaults to 42.
@@ -43,8 +40,7 @@ def __init__(
self.model.to(self.device)
def _init_model(self):
- """
- Initializes the OWLv2 model for object detection.
+ """Initializes the OWLv2 model for object detection.
Returns:
Owlv2ForObjectDetection: The initialized OWLv2 model.
@@ -54,8 +50,7 @@ def _init_model(self):
)
def _init_processor(self):
- """
- Initializes the processor for the OWLv2 model.
+ """Initializes the processor for the OWLv2 model.
Returns:
Owlv2Processor: The initialized processor.
@@ -67,8 +62,7 @@ def _init_processor(self):
def annotate(
self, image, prompts, conf_threshold=0.1, use_tta=False, synonym_dict=None
):
- """
- Annotates an image using the OWLv2 model.
+ """Annotates an image using the OWLv2 model.
Args:
image: The image to be annotated.
@@ -162,8 +156,7 @@ def annotate(
return final_boxes, final_scores, final_labels
def release(self, empty_cuda_cache=False) -> None:
- """
- Releases the model and optionally empties the CUDA cache.
+ """Releases the model and optionally empties the CUDA cache.
Args:
empty_cuda_cache (bool, optional): Whether to empty the CUDA cache. Defaults to False.
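
For reference, an illustrative use of `OWLv2Annotator` based on the signatures above (a sketch, not part of this patch; `image.jpg` and the class names are placeholders, and a CUDA device is assumed):

```python
from PIL import Image

from datadreamer.dataset_annotation import OWLv2Annotator

annotator = OWLv2Annotator(seed=42, device="cuda")
image = Image.open("image.jpg")  # placeholder input image
boxes, scores, labels = annotator.annotate(
    image, prompts=["horse", "robot"], conf_threshold=0.1, use_tta=False
)
annotator.release(empty_cuda_cache=True)
```
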
diff --git a/src/datadreamer/dataset_annotation/utils.py b/src/datadreamer/dataset_annotation/utils.py
index 742a226..acf9734 100644
--- a/src/datadreamer/dataset_annotation/utils.py
+++ b/src/datadreamer/dataset_annotation/utils.py
@@ -2,8 +2,7 @@
def apply_tta(image):
- """
- Apply test-time augmentation (TTA) to the given image.
+ """Apply test-time augmentation (TTA) to the given image.
Args:
image: The image to be augmented.
diff --git a/src/datadreamer/image_generation/__init__.py b/src/datadreamer/image_generation/__init__.py
index 04e8196..6462c2d 100644
--- a/src/datadreamer/image_generation/__init__.py
+++ b/src/datadreamer/image_generation/__init__.py
@@ -1,2 +1,4 @@
from .sdxl_image_generator import StableDiffusionImageGenerator
from .sdxl_turbo_image_generator import StableDiffusionTurboImageGenerator
+
+__all__ = ["StableDiffusionImageGenerator", "StableDiffusionTurboImageGenerator"]
diff --git a/src/datadreamer/image_generation/clip_image_tester.py b/src/datadreamer/image_generation/clip_image_tester.py
index d1f767a..faf4ceb 100644
--- a/src/datadreamer/image_generation/clip_image_tester.py
+++ b/src/datadreamer/image_generation/clip_image_tester.py
@@ -1,12 +1,12 @@
from typing import List
-from PIL import Image
+
import torch
+from PIL import Image
from transformers import CLIPModel, CLIPProcessor
class ClipImageTester:
- """
- A class for testing images against a set of textual objects using the CLIP model.
+ """A class for testing images against a set of textual objects using the CLIP model.
Attributes:
clip (CLIPModel): The CLIP model for image-text similarity evaluation.
@@ -18,17 +18,14 @@ class ClipImageTester:
"""
def __init__(self) -> None:
- """
- Initializes the ClipImageTester with the CLIP model and processor.
- """
+ """Initializes the ClipImageTester with the CLIP model and processor."""
self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
self.clip_processor = CLIPProcessor.from_pretrained(
"openai/clip-vit-base-patch32"
)
def test_image(self, image: Image.Image, objects: List[str], conf_threshold=0.05):
- """
- Tests the generated image against a set of objects using the CLIP model.
+ """Tests the generated image against a set of objects using the CLIP model.
Args:
image (Image.Image): The image to be tested.
@@ -55,8 +52,7 @@ def test_image(self, image: Image.Image, objects: List[str], conf_threshold=0.05
return passed, probs, num_passed
def release(self, empty_cuda_cache=False) -> None:
- """
- Releases the model and optionally empties the CUDA cache.
+ """Releases the model and optionally empties the CUDA cache.
Args:
empty_cuda_cache (bool, optional): Whether to empty the CUDA cache. Defaults to False.
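
An illustrative use of `ClipImageTester` following the `test_image` signature above (a sketch, not part of this patch; `generated.jpg` is a placeholder path):

```python
from PIL import Image

from datadreamer.image_generation.clip_image_tester import ClipImageTester

tester = ClipImageTester()
image = Image.open("generated.jpg")  # placeholder image to validate
passed, probs, num_passed = tester.test_image(
    image, ["horse", "robot"], conf_threshold=0.05
)
print(passed, num_passed)
tester.release()
```
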
diff --git a/src/datadreamer/image_generation/image_generator.py b/src/datadreamer/image_generation/image_generator.py
index e001244..7b5ffb0 100644
--- a/src/datadreamer/image_generation/image_generator.py
+++ b/src/datadreamer/image_generation/image_generator.py
@@ -1,17 +1,17 @@
-from abc import ABC, abstractmethod
-from typing import Optional, Union, List
-import enum
-from PIL import Image
+import random
+from abc import abstractmethod
+from typing import List, Optional, Union
+
import torch
+from PIL import Image
from tqdm import tqdm
-import random
from datadreamer.image_generation.clip_image_tester import ClipImageTester
class ImageGenerator:
- """
- A class for generating images based on textual prompts, with optional CLIP model testing.
+ """A class for generating images based on textual prompts, with optional CLIP model
+ testing.
Attributes:
prompt_prefix (str): Optional prefix to add to every prompt.
@@ -43,9 +43,7 @@ def __init__(
image_tester_patience: Optional[int] = 1,
seed: Optional[float] = 42,
) -> None:
- """
- Initializes the ImageGenerator with the specified settings.
- """
+ """Initializes the ImageGenerator with the specified settings."""
self.prompt_prefix = prompt_prefix
self.prompt_suffix = prompt_suffix
self.negative_prompt = negative_prompt
@@ -59,8 +57,7 @@ def __init__(
@staticmethod
def set_seed(seed: int):
- """
- Sets the seed for random number generators in Python and PyTorch.
+ """Sets the seed for random number generators in Python and PyTorch.
Args:
seed (int): The seed value to set.
@@ -74,8 +71,7 @@ def generate_images(
prompts: Union[str, List[str]],
prompt_objects: Optional[List[List[str]]] = None,
):
- """
- Generates images based on the provided prompts and optional object prompts.
+ """Generates images based on the provided prompts and optional object prompts.
Args:
prompts (Union[str, List[str]]): Single prompt or a list of prompts to guide the image generation.
@@ -129,9 +125,7 @@ def generate_images(
@abstractmethod
def release(self, empty_cuda_cache=False) -> None:
- """
- Releases resources and optionally empties the CUDA cache.
- """
+ """Releases resources and optionally empties the CUDA cache."""
pass
@abstractmethod
@@ -141,8 +135,7 @@ def generate_image(
negative_prompt: str,
prompt_objects: Optional[List[str]] = None,
) -> Image.Image:
- """
- Generates a single image based on the provided prompt.
+ """Generates a single image based on the provided prompt.
Args:
prompt (str): The positive prompt to guide image generation.
diff --git a/src/datadreamer/image_generation/sdxl_image_generator.py b/src/datadreamer/image_generation/sdxl_image_generator.py
index 23e8a40..c2d57c7 100644
--- a/src/datadreamer/image_generation/sdxl_image_generator.py
+++ b/src/datadreamer/image_generation/sdxl_image_generator.py
@@ -1,15 +1,16 @@
-from PIL import Image
+from typing import List, Optional
+
import torch
-from diffusers import DiffusionPipeline
from compel import Compel, ReturnedEmbeddingsType
-from typing import List, Optional
+from diffusers import DiffusionPipeline
+from PIL import Image
from datadreamer.image_generation.image_generator import ImageGenerator
class StableDiffusionImageGenerator(ImageGenerator):
- """
- A subclass of ImageGenerator that uses the Stable Diffusion model for image generation.
+ """A subclass of ImageGenerator that uses the Stable Diffusion model for image
+ generation.
Attributes:
base (DiffusionPipeline): The base Stable Diffusion model for initial image generation.
@@ -25,16 +26,13 @@ class StableDiffusionImageGenerator(ImageGenerator):
"""
def __init__(self, *args, **kwargs):
- """
- Initializes the StableDiffusionImageGenerator with the given arguments.
- """
+ """Initializes the StableDiffusionImageGenerator with the given arguments."""
super().__init__(*args, **kwargs)
self.base, self.refiner = self._init_gen_model()
self.base_processor, self.refiner_processor = self._init_processor()
def _init_gen_model(self):
- """
- Initializes the base and refiner models of Stable Diffusion.
+ """Initializes the base and refiner models of Stable Diffusion.
Returns:
tuple: The base and refiner models.
@@ -59,8 +57,7 @@ def _init_gen_model(self):
return base, refiner
def _init_processor(self):
- """
- Initializes the processors for the base and refiner models.
+ """Initializes the processors for the base and refiner models.
Returns:
tuple: The processors for the base and refiner models.
@@ -85,8 +82,8 @@ def generate_image(
negative_prompt: str,
prompt_objects: Optional[List[str]] = None,
) -> Image.Image:
- """
- Generates an image based on the provided prompt, using Stable Diffusion models.
+ """Generates an image based on the provided prompt, using Stable Diffusion
+ models.
Args:
prompt (str): The positive prompt to guide image generation.
@@ -129,9 +126,7 @@ def generate_image(
return image
def release(self, empty_cuda_cache=False) -> None:
- """
- Releases the models and optionally empties the CUDA cache.
- """
+ """Releases the models and optionally empties the CUDA cache."""
self.base = self.base.to("cpu")
self.refiner = self.refiner.to("cpu")
if self.use_clip_image_tester:
diff --git a/src/datadreamer/image_generation/sdxl_turbo_image_generator.py b/src/datadreamer/image_generation/sdxl_turbo_image_generator.py
index 69a2f45..dedd8df 100644
--- a/src/datadreamer/image_generation/sdxl_turbo_image_generator.py
+++ b/src/datadreamer/image_generation/sdxl_turbo_image_generator.py
@@ -1,14 +1,15 @@
-from PIL import Image
+from typing import List, Optional
+
import torch
from diffusers import AutoPipelineForText2Image
-from typing import List, Optional
+from PIL import Image
from datadreamer.image_generation.image_generator import ImageGenerator
class StableDiffusionTurboImageGenerator(ImageGenerator):
- """
- A subclass of ImageGenerator specifically designed to use the Stable Diffusion Turbo model for faster image generation.
+ """A subclass of ImageGenerator specifically designed to use the Stable Diffusion
+ Turbo model for faster image generation.
Attributes:
base (AutoPipelineForText2Image): The Stable Diffusion Turbo model for image generation.
@@ -20,15 +21,13 @@ class StableDiffusionTurboImageGenerator(ImageGenerator):
"""
def __init__(self, *args, **kwargs):
- """
- Initializes the StableDiffusionTurboImageGenerator with the given arguments.
- """
+ """Initializes the StableDiffusionTurboImageGenerator with the given
+ arguments."""
super().__init__(*args, **kwargs)
self.base = self._init_gen_model()
def _init_gen_model(self):
- """
- Initializes the Stable Diffusion Turbo model for image generation.
+ """Initializes the Stable Diffusion Turbo model for image generation.
Returns:
AutoPipelineForText2Image: The initialized Stable Diffusion Turbo model.
@@ -49,8 +48,8 @@ def generate_image(
negative_prompt: str,
prompt_objects: Optional[List[str]] = None,
) -> Image.Image:
- """
- Generates an image using the Stable Diffusion Turbo model based on the provided prompt.
+ """Generates an image using the Stable Diffusion Turbo model based on the
+ provided prompt.
Args:
prompt (str): The positive prompt to guide image generation.
@@ -70,9 +69,7 @@ def generate_image(
return image
def release(self, empty_cuda_cache=False) -> None:
- """
- Releases the model and optionally empties the CUDA cache.
- """
+ """Releases the model and optionally empties the CUDA cache."""
self.base = self.base.to("cpu")
if self.use_clip_image_tester:
self.clip_image_tester.release()
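
An illustrative single-image call to `StableDiffusionTurboImageGenerator`, based on the `ImageGenerator` constructor and the `generate_image` signature above (a sketch, not part of this patch; the prompts are placeholders and a CUDA-capable GPU is assumed):

```python
from datadreamer.image_generation import StableDiffusionTurboImageGenerator

generator = StableDiffusionTurboImageGenerator(seed=42)
image = generator.generate_image(
    prompt="A photo of a horse and a robot in a pasture",
    negative_prompt="blurry, low quality",
    prompt_objects=["horse", "robot"],
)
image.save("generated.jpg")
generator.release(empty_cuda_cache=True)
```
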
diff --git a/src/datadreamer/pipelines/generate_dataset_from_scratch.py b/src/datadreamer/pipelines/generate_dataset_from_scratch.py
index 3ffb625..376e862 100644
--- a/src/datadreamer/pipelines/generate_dataset_from_scratch.py
+++ b/src/datadreamer/pipelines/generate_dataset_from_scratch.py
@@ -1,24 +1,23 @@
-import matplotlib.pyplot as plt
+import argparse
+import json
+import os
+
import matplotlib.patches as patches
-import torch
-from PIL import Image
+import matplotlib.pyplot as plt
import numpy as np
-import os
-import json
-import argparse
+from PIL import Image
from tqdm import tqdm
+from datadreamer.dataset_annotation import OWLv2Annotator
+from datadreamer.image_generation import (
+ StableDiffusionImageGenerator,
+ StableDiffusionTurboImageGenerator,
+)
from datadreamer.prompt_generation import (
- SimplePromptGenerator,
LMPromptGenerator,
+ SimplePromptGenerator,
SynonymGenerator,
)
-from datadreamer.image_generation import (
- StableDiffusionTurboImageGenerator,
- StableDiffusionImageGenerator,
-)
-from datadreamer.dataset_annotation import OWLv2Annotator
-
prompt_generators = {"simple": SimplePromptGenerator, "lm": LMPromptGenerator}
@@ -149,11 +148,11 @@ def save_det_annotations_to_json(
file_name="annotations.json",
):
annotations = {}
- for image_path, bboxes, labels_list in zip(image_paths, boxes_list, labels_list):
+ for image_path, bboxes, labels in zip(image_paths, boxes_list, labels_list):
image_name = os.path.basename(image_path)
annotations[image_name] = {
"boxes": bboxes.tolist(),
- "labels": labels_list.tolist(),
+ "labels": labels.tolist(),
}
annotations["class_names"] = class_names
@@ -166,10 +165,10 @@ def save_clf_annotations_to_json(
image_paths, labels_list, class_names, save_dir, file_name="annotations.json"
):
annotations = {}
- for image_path, labels_list in zip(image_paths, labels_list):
+ for image_path, labels in zip(image_paths, labels_list):
image_name = os.path.basename(image_path)
annotations[image_name] = {
- "labels": labels_list.tolist(),
+ "labels": labels.tolist(),
}
annotations["class_names"] = class_names
@@ -242,7 +241,7 @@ def main():
if args.task == "classification":
# Classification annotation
labels_list = []
- for i, (image_path, prompt_objs) in enumerate(zip(image_paths, prompt_objects)):
+ for prompt_objs in prompt_objects:
labels = []
for obj in prompt_objs:
labels.append(args.class_names.index(obj))
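
With the shadowed loop variables renamed, `save_det_annotations_to_json` writes one entry per image plus the class list. The resulting `annotations.json` has roughly this shape (illustrative values only):

```python
# Illustrative shape of the detection annotations.json written above
# (numbers are made up; boxes are [x_min, y_min, x_max, y_max]).
annotations = {
    "image_0.jpg": {
        "boxes": [[12.0, 34.0, 256.0, 300.0]],
        "labels": [0],  # indices into class_names
    },
    "class_names": ["horse", "robot"],
}
```
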
diff --git a/src/datadreamer/prompt_generation/__init__.py b/src/datadreamer/prompt_generation/__init__.py
index 2216420..c2096d7 100644
--- a/src/datadreamer/prompt_generation/__init__.py
+++ b/src/datadreamer/prompt_generation/__init__.py
@@ -1,3 +1,5 @@
-from .simple_prompt_generator import SimplePromptGenerator
from .lm_prompt_generator import LMPromptGenerator
+from .simple_prompt_generator import SimplePromptGenerator
from .synonym_generator import SynonymGenerator
+
+__all__ = ["SimplePromptGenerator", "LMPromptGenerator", "SynonymGenerator"]
diff --git a/src/datadreamer/prompt_generation/lm_prompt_generator.py b/src/datadreamer/prompt_generation/lm_prompt_generator.py
index e963806..13b01f5 100644
--- a/src/datadreamer/prompt_generation/lm_prompt_generator.py
+++ b/src/datadreamer/prompt_generation/lm_prompt_generator.py
@@ -1,16 +1,16 @@
import random
-from transformers import AutoModelForCausalLM, AutoTokenizer
-import torch
-from tqdm import tqdm
import re
from typing import List, Optional
+import torch
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
from datadreamer.prompt_generation.prompt_generator import PromptGenerator
class LMPromptGenerator(PromptGenerator):
- """
- A language model-based prompt generator class, extending PromptGenerator.
+ """A language model-based prompt generator class, extending PromptGenerator.
Attributes:
device (str): Device to run the language model on ('cuda' for GPU, 'cpu' for CPU).
@@ -30,20 +30,18 @@ def __init__(
self,
class_names: List[str],
prompts_number: int = 10,
- num_objects_range: Optional[List[int]] = [1, 3],
+ num_objects_range: Optional[List[int]] = None,
seed: Optional[float] = 42,
device: str = "cuda",
) -> None:
- """
- Initializes the LMPromptGenerator with class names and other settings.
- """
+ """Initializes the LMPromptGenerator with class names and other settings."""
+ num_objects_range = num_objects_range or [1, 3]
super().__init__(class_names, prompts_number, num_objects_range, seed)
self.device = device
self.model, self.tokenizer = self._init_lang_model()
def _init_lang_model(self):
- """
- Initializes the language model and tokenizer for prompt generation.
+ """Initializes the language model and tokenizer for prompt generation.
Returns:
tuple: The initialized language model and tokenizer.
@@ -57,8 +55,7 @@ def _init_lang_model(self):
return model, tokenizer
def generate_prompts(self) -> List[str]:
- """
- Generates a list of text prompts based on the class names.
+ """Generates a list of text prompts based on the class names.
Returns:
List[str]: A list of generated prompts.
@@ -78,8 +75,7 @@ def generate_prompts(self) -> List[str]:
return prompts
def _create_lm_prompt_text(self, selected_objects: List[str]) -> str:
- """
- Creates a language model text prompt based on selected objects.
+ """Creates a language model text prompt based on selected objects.
Args:
selected_objects (List[str]): Objects to include in the prompt.
@@ -90,8 +86,7 @@ def _create_lm_prompt_text(self, selected_objects: List[str]) -> str:
return f"[INST] Generate a short and consice caption for an image. Follow this template: 'A photo of {', '.join(selected_objects)}', where the objects interact in a meaningful way within a scene, complete with a short scene description. [/INST]"
def generate_prompt(self, prompt_text: str) -> str:
- """
- Generates a single prompt using the language model.
+ """Generates a single prompt using the language model.
Args:
prompt_text (str): The text prompt for the language model.
@@ -120,8 +115,7 @@ def generate_prompt(self, prompt_text: str) -> str:
return decoded_prompt
def _test_prompt(self, prompt: str, selected_objects: List[str]) -> bool:
- """
- Tests if the generated prompt is valid based on selected objects.
+ """Tests if the generated prompt is valid based on selected objects.
Args:
prompt (str): The generated prompt.
@@ -135,9 +129,7 @@ def _test_prompt(self, prompt: str, selected_objects: List[str]) -> bool:
) # and all(obj.lower() in prompt.lower() for obj in selected_objects)
def release(self, empty_cuda_cache=False) -> None:
- """
- Releases the model and optionally empties the CUDA cache.
- """
+ """Releases the model and optionally empties the CUDA cache."""
self.model = self.model.to("cpu")
if empty_cuda_cache:
with torch.no_grad():
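
Replacing the `[1, 3]` default with `None` (materialized inside the body) avoids Python's shared-mutable-default pitfall. A generic sketch of the failure mode the change guards against (not DataDreamer code):

```python
def buggy(values=[]):          # one list object shared across every call
    values.append(1)
    return values

print(buggy())  # [1]
print(buggy())  # [1, 1]  <- state leaks between calls

def fixed(values=None):        # the pattern adopted in this patch
    values = values or [1, 3]
    return values
```
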
diff --git a/src/datadreamer/prompt_generation/prompt_generator.py b/src/datadreamer/prompt_generation/prompt_generator.py
index 4cbe144..e0cc84e 100644
--- a/src/datadreamer/prompt_generation/prompt_generator.py
+++ b/src/datadreamer/prompt_generation/prompt_generator.py
@@ -1,16 +1,14 @@
import json
import random
-import torch
+from abc import ABC, abstractmethod
from typing import List, Optional
-
-from abc import ABC, abstractmethod
+import torch
# Abstract base class for prompt generation
class PromptGenerator(ABC):
- """
- Abstract base class for prompt generation.
+ """Abstract base class for prompt generation.
Attributes:
class_names (List[str]): List of class names or objects for prompt generation.
@@ -29,23 +27,20 @@ def __init__(
self,
class_names: List[str],
prompts_number: int = 10,
- num_objects_range: Optional[List[int]] = [1, 3],
+ num_objects_range: Optional[List[int]] = None,
seed: Optional[float] = None,
) -> None:
- """
- Initializes the PromptGenerator with class names and other settings.
- """
+ """Initializes the PromptGenerator with class names and other settings."""
self.class_names = class_names
self.prompts_number = prompts_number
- self.num_objects_range = num_objects_range
+ self.num_objects_range = num_objects_range or [1, 3]
self.seed = seed
if seed is not None:
self.set_seed(seed)
@staticmethod
def set_seed(seed: int):
- """
- Sets the random seed for consistent prompt generation.
+ """Sets the random seed for consistent prompt generation.
Args:
seed (int): The random seed.
@@ -55,8 +50,7 @@ def set_seed(seed: int):
torch.cuda.manual_seed_all(seed)
def save_prompts(self, prompts: List[str], save_path: str) -> None:
- """
- Saves generated prompts to a JSON file.
+ """Saves generated prompts to a JSON file.
Args:
prompts (List[str]): List of generated prompts.
@@ -67,8 +61,7 @@ def save_prompts(self, prompts: List[str], save_path: str) -> None:
@abstractmethod
def generate_prompts(self) -> List[str]:
- """
- Abstract method to generate prompts (must be implemented in subclasses).
+ """Abstract method to generate prompts (must be implemented in subclasses).
Returns:
List[str]: A list of generated prompts.
@@ -77,7 +70,5 @@ def generate_prompts(self) -> List[str]:
@abstractmethod
def release(self, empty_cuda_cache=False) -> None:
- """
- Abstract method to release resources (must be implemented in subclasses).
- """
+ """Abstract method to release resources (must be implemented in subclasses)."""
pass
diff --git a/src/datadreamer/prompt_generation/simple_prompt_generator.py b/src/datadreamer/prompt_generation/simple_prompt_generator.py
index 7572016..c701024 100644
--- a/src/datadreamer/prompt_generation/simple_prompt_generator.py
+++ b/src/datadreamer/prompt_generation/simple_prompt_generator.py
@@ -1,12 +1,11 @@
import random
-from typing import List, Optional
+from typing import List
from datadreamer.prompt_generation.prompt_generator import PromptGenerator
class SimplePromptGenerator(PromptGenerator):
- """
- Prompt generator that creates simple prompts for text generation tasks.
+ """Prompt generator that creates simple prompts for text generation tasks.
Args:
class_names (List[str]): List of class names or objects for prompt generation.
@@ -25,14 +24,11 @@ def __init__(
*args,
**kwargs,
) -> None:
- """
- Initializes the SimplePromptGenerator with class names and other settings.
- """
+ """Initializes the SimplePromptGenerator with class names and other settings."""
super().__init__(*args, **kwargs)
def generate_prompts(self) -> List[str]:
- """
- Generates a list of simple prompts.
+ """Generates a list of simple prompts.
Returns:
List[str]: A list of generated prompts in the form of "A photo of a {selected_objects}".
@@ -47,8 +43,7 @@ def generate_prompts(self) -> List[str]:
return prompts
def generate_prompt(self, selected_objects: List[str]) -> str:
- """
- Generates a single simple prompt based on selected objects.
+ """Generates a single simple prompt based on selected objects.
Args:
selected_objects (List[str]): List of selected objects to include in the prompt.
@@ -59,9 +54,7 @@ def generate_prompt(self, selected_objects: List[str]) -> str:
return f"A photo of a {', a '.join(selected_objects)}"
def release(self, empty_cuda_cache=False) -> None:
- """
- Releases resources (no action is taken in this implementation).
- """
+ """Releases resources (no action is taken in this implementation)."""
pass
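
An illustrative use of `SimplePromptGenerator`, following the `PromptGenerator` constructor and the methods above (a sketch, not part of this patch; `prompts.json` is a placeholder path):

```python
from datadreamer.prompt_generation import SimplePromptGenerator

generator = SimplePromptGenerator(
    class_names=["horse", "robot"], prompts_number=5, seed=42
)
prompts = generator.generate_prompts()  # e.g. "A photo of a horse, a robot"
generator.save_prompts(prompts, "prompts.json")
generator.release()
```
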
diff --git a/src/datadreamer/prompt_generation/synonym_generator.py b/src/datadreamer/prompt_generation/synonym_generator.py
index 8b28ae8..75eae0d 100644
--- a/src/datadreamer/prompt_generation/synonym_generator.py
+++ b/src/datadreamer/prompt_generation/synonym_generator.py
@@ -1,19 +1,15 @@
-from typing import List, Optional
-import random
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import re
-from tqdm import tqdm
-
from typing import List, Optional
+
import torch
+from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
class SynonymGenerator:
- """
- Synonym generator that generates synonyms for a list of words using a language model.
+ """Synonym generator that generates synonyms for a list of words using a language
+ model.
Args:
synonyms_number (int): Number of synonyms to generate for each word.
@@ -33,18 +29,14 @@ def __init__(
seed: Optional[float] = 42,
device: str = "cuda",
) -> None:
- """
- Initializes the SynonymGenerator with parameters.
- """
+ """Initializes the SynonymGenerator with parameters."""
self.synonyms_number = synonyms_number
self.seed = seed
self.device = device
self.model, self.tokenizer = self._init_lang_model()
def _init_lang_model(self):
- """
- Initializes the language model and tokenizer for synonym generation.
- """
+ """Initializes the language model and tokenizer for synonym generation."""
print("Initializing language model for synonym generation")
model = AutoModelForCausalLM.from_pretrained(
"mistralai/Mistral-7B-Instruct-v0.1", torch_dtype=torch.float16
@@ -53,8 +45,7 @@ def _init_lang_model(self):
return model, tokenizer
def generate_synonyms_for_list(self, words: List[str]) -> dict:
- """
- Generates synonyms for a list of words and returns them in a dictionary.
+ """Generates synonyms for a list of words and returns them in a dictionary.
Args:
words (List[str]): List of words for which synonyms are generated.
@@ -70,8 +61,7 @@ def generate_synonyms_for_list(self, words: List[str]) -> dict:
return synonyms_dict
def generate_synonyms(self, word: str) -> List[str]:
- """
- Generates synonyms for a single word and returns them in a list.
+ """Generates synonyms for a single word and returns them in a list.
Args:
word (str): The word for which synonyms are generated.
@@ -84,8 +74,7 @@ def generate_synonyms(self, word: str) -> List[str]:
return generated_synonyms
def _create_prompt_text(self, word: str) -> str:
- """
- Creates a prompt text for generating synonyms for a given word.
+ """Creates a prompt text for generating synonyms for a given word.
Args:
word (str): The word for which synonyms are generated.
@@ -96,8 +85,7 @@ def _create_prompt_text(self, word: str) -> str:
return f"[INST] List {self.synonyms_number} most common synonyms for the word '{word}'. Write only synonyms separated by commas. [/INST]"
def _generate_synonyms(self, prompt_text: str) -> List[str]:
- """
- Generates synonyms based on a given prompt text.
+ """Generates synonyms based on a given prompt text.
Args:
prompt_text (str): The prompt text for generating synonyms.
@@ -131,8 +119,7 @@ def _generate_synonyms(self, prompt_text: str) -> List[str]:
return synonyms
def _extract_synonyms(self, text: str) -> List[str]:
- """
- Extracts synonyms from a text containing synonyms.
+ """Extracts synonyms from a text containing synonyms.
Args:
text (str): The text containing synonyms.
@@ -146,8 +133,7 @@ def _extract_synonyms(self, text: str) -> List[str]:
return synonyms[: self.synonyms_number]
def save_synonyms(self, synonyms, save_path: str) -> None:
- """
- Saves the generated synonyms to a JSON file.
+ """Saves the generated synonyms to a JSON file.
Args:
synonyms: The synonyms to save (typically a dictionary).
@@ -157,8 +143,7 @@ def save_synonyms(self, synonyms, save_path: str) -> None:
json.dump(synonyms, f)
def release(self, empty_cuda_cache=False) -> None:
- """
- Releases resources and optionally empties the CUDA cache.
+ """Releases resources and optionally empties the CUDA cache.
Args:
empty_cuda_cache (bool): Whether to empty the CUDA cache (default is False).
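
An illustrative use of `SynonymGenerator` based on the signatures above (a sketch, not part of this patch; assumes a GPU able to hold Mistral-7B in fp16, and `synonyms.json` is a placeholder path):

```python
from datadreamer.prompt_generation import SynonymGenerator

generator = SynonymGenerator(synonyms_number=3, seed=42, device="cuda")
synonyms = generator.generate_synonyms_for_list(["horse", "robot"])
generator.save_synonyms(synonyms, "synonyms.json")
generator.release(empty_cuda_cache=True)
```
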
diff --git a/src/datadreamer/utils/convert_dataset_to_yolo.py b/src/datadreamer/utils/convert_dataset_to_yolo.py
index 61e582f..9e16823 100644
--- a/src/datadreamer/utils/convert_dataset_to_yolo.py
+++ b/src/datadreamer/utils/convert_dataset_to_yolo.py
@@ -1,14 +1,14 @@
+import argparse
import json
import os
import shutil
+
import numpy as np
-import argparse
from PIL import Image
def read_annotations(annotation_path):
- """
- Reads annotations from a JSON file located at the specified path.
+ """Reads annotations from a JSON file located at the specified path.
Args:
- annotation_path (str): The path to the JSON file containing annotations.
@@ -22,8 +22,7 @@ def read_annotations(annotation_path):
def convert_to_yolo_format(box, image_width, image_height):
- """
- Converts bounding box coordinates to YOLO format.
+ """Converts bounding box coordinates to YOLO format.
Args:
- box (list of float): A list containing the bounding box coordinates [x_min, y_min, x_max, y_max].
@@ -41,8 +40,8 @@ def convert_to_yolo_format(box, image_width, image_height):
def process_data(data, image_dir, output_dir, split_ratio):
- """
- Processes the data by dividing it into training and validation sets, and saves the images and labels in YOLO format.
+ """Processes the data by dividing it into training and validation sets, and saves
+ the images and labels in YOLO format.
Args:
- data (dict): The dictionary containing image annotations.
@@ -93,8 +92,7 @@ def process_data(data, image_dir, output_dir, split_ratio):
def create_data_yaml(root_dir, class_names):
- """
- Creates a YAML file for dataset configuration, specifying paths and class names.
+ """Creates a YAML file for dataset configuration, specifying paths and class names.
Args:
- root_dir (str): The root directory where the dataset is located.
@@ -113,8 +111,8 @@ def create_data_yaml(root_dir, class_names):
def convert(dataset_dir, output_dir, train_val_split_ratio):
- """
- Converts a dataset into a format suitable for training with YOLO, including creating training and validation splits.
+ """Converts a dataset into a format suitable for training with YOLO, including
+ creating training and validation splits.
Args:
- dataset_dir (str): The directory where the source dataset is located.
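
`convert_to_yolo_format` maps an `[x_min, y_min, x_max, y_max]` box to normalized `[x_center, y_center, width, height]`. The arithmetic below follows the standard YOLO convention; it is stated as an assumption for illustration, since the function body is not shown in this diff:

```python
box = [100.0, 50.0, 300.0, 250.0]        # [x_min, y_min, x_max, y_max]
image_width, image_height = 640, 480

x_center = (box[0] + box[2]) / 2 / image_width    # 0.3125
y_center = (box[1] + box[3]) / 2 / image_height   # 0.3125
width = (box[2] - box[0]) / image_width           # 0.3125
height = (box[3] - box[1]) / image_height         # ~0.4167
print([x_center, y_center, width, height])
```
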
diff --git a/src/datadreamer/utils/nms.py b/src/datadreamer/utils/nms.py
index b6c38fe..4bc3992 100644
--- a/src/datadreamer/utils/nms.py
+++ b/src/datadreamer/utils/nms.py
@@ -5,12 +5,12 @@
import os
import time
-import numpy as np
+
import cv2
+import numpy as np
import torch
import torchvision
-
# Settings
torch.set_printoptions(linewidth=320, precision=5, profile="long")
np.set_printoptions(
@@ -23,7 +23,8 @@
def xywh2xyxy(x):
- """Convert boxes with shape [n, 4] from [x, y, w, h] to [x1, y1, x2, y2] where x1y1 is top-left, x2y2=bottom-right."""
+ """Convert boxes with shape [n, 4] from [x, y, w, h] to [x1, y1, x2, y2] where x1y1
+ is top-left, x2y2=bottom-right."""
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
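
A quick numeric check of `xywh2xyxy` on a single box (illustrative, not part of this patch):

```python
import numpy as np

from datadreamer.utils.nms import xywh2xyxy

boxes_xywh = np.array([[200.0, 150.0, 200.0, 200.0]])  # [x_center, y_center, w, h]
print(xywh2xyxy(boxes_xywh))                            # [[100.  50. 300. 250.]]
```
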
diff --git a/tools/autogenerate_requirements.py b/tools/autogenerate_requirements.py
new file mode 100644
index 0000000..427384a
--- /dev/null
+++ b/tools/autogenerate_requirements.py
@@ -0,0 +1,23 @@
+import toml
+
+
+def main():
+ with open("pyproject.toml", "r") as f:
+ pyproject = toml.load(f)
+
+ with open("requirements.txt", "w") as f:
+ for dep in pyproject["project"]["dependencies"]:
+ if dep.startswith("python"):
+ continue
+ f.write(dep + "\n")
+
+ for name, deps in pyproject["project"]["optional-dependencies"].items():
+ f.write(f"\n# {name}\n")
+ for dep in deps:
+ if dep.startswith("datadreamer"):
+ continue
+ f.write(dep + "\n")
+
+
+if __name__ == "__main__":
+ main()
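
For context, `toml.load` returns plain nested dicts, so the script walks `project.dependencies` and `project.optional-dependencies`; an illustrative check run from the repository root:

```python
import toml

pyproject = toml.load("pyproject.toml")
print(pyproject["project"]["dependencies"][0])              # e.g. "torch>=2.0.0"
print(list(pyproject["project"]["optional-dependencies"]))  # e.g. ["dev"]
```
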