Remove caption reads sentences
HonzaCuhel committed Jan 30, 2024
1 parent 103e997 commit 059bf62
Showing 5 changed files with 238 additions and 12 deletions.
6 changes: 5 additions & 1 deletion datadreamer/pipelines/generate_dataset_from_scratch.py
@@ -21,7 +21,11 @@
TinyLlamaLMPromptGenerator,
)

prompt_generators = {"simple": SimplePromptGenerator, "lm": LMPromptGenerator, 'tiny': TinyLlamaLMPromptGenerator}
prompt_generators = {
"simple": SimplePromptGenerator,
"lm": LMPromptGenerator,
"tiny": TinyLlamaLMPromptGenerator,
}

image_generators = {
"sdxl": StableDiffusionImageGenerator,
7 changes: 6 additions & 1 deletion datadreamer/prompt_generation/__init__.py
@@ -3,4 +3,9 @@
from .synonym_generator import SynonymGenerator
from .tinyllama_lm_prompt_generator import TinyLlamaLMPromptGenerator

__all__ = ["SimplePromptGenerator", "LMPromptGenerator", "SynonymGenerator", "TinyLlamaLMPromptGenerator"]
__all__ = [
"SimplePromptGenerator",
"LMPromptGenerator",
"SynonymGenerator",
"TinyLlamaLMPromptGenerator",
]
28 changes: 20 additions & 8 deletions datadreamer/prompt_generation/tinyllama_lm_prompt_generator.py
@@ -1,9 +1,7 @@
import random
import re
from typing import List, Optional

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from datadreamer.prompt_generation.lm_prompt_generator import LMPromptGenerator
@@ -48,25 +46,38 @@ def _init_lang_model(self):
model = AutoModelForCausalLM.from_pretrained(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
torch_dtype="auto",
low_cpu_mem_usage=True
device_map="cpu",
low_cpu_mem_usage=True,
)
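        # low_cpu_mem_usage reduces peak RAM while loading; device_map="cpu" keeps all weights on the CPU.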
else:
print("Loading language model on GPU...")
model = AutoModelForCausalLM.from_pretrained(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.float16
)

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", trust_remote_code=True
)
print("Done!")
return model.to(self.device), tokenizer

def remove_incomplete_sentence(self, text):
# Define the regex pattern to capture up to the last sentence-ending punctuation
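        # Example (illustrative): "A dog runs. It is chas" -> "A dog runs."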
pattern = r'^(.*[.!?])'
# pattern = r'^(.*[.!?;:])'
pattern = r"^(.*[.!?])"
match = re.search(pattern, text)
return match.group(0) if match else text

def remove_caption_sentences(self, text):
# Pattern to find sentences that start with "Caption reads: "
        # \s* matches any whitespace characters immediately before the phrase (including none)
# re.IGNORECASE makes the search case-insensitive
# [^\.!?]* matches any sequence of characters that are not a period, exclamation mark, or question mark
# [\.\!?] matches a period, exclamation mark, or question mark, indicating the end of a sentence
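        # Example (illustrative): "A photo of a cat. Caption reads: a cute cat." -> "A photo of a cat."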
pattern = re.compile(r"\s*Caption reads: [^\.!?]*[\.\!?]", re.IGNORECASE)
# Replace the matched sentences with an empty string
cleaned_text = re.sub(pattern, "", text)
return cleaned_text

def _create_lm_prompt_text(self, selected_objects: List[str]) -> str:
"""Creates a language model text prompt based on selected objects.
@@ -77,7 +88,6 @@ def _create_lm_prompt_text(self, selected_objects: List[str]) -> str:
str: A text prompt for the language model.
"""
return f"<|system|>\nYou are a chatbot who describes content of images!</s>\n<|user|>\nGenerate a short and concise caption for an image. Follow this template: 'A photo of {', '.join(selected_objects)}', where the objects interact in a meaningful way within a scene, complete with a short scene description. The caption must be short in length and start with the words: 'A photo of '! Do not use the phrase 'Caption reads'.</s>\n<|assistant|>\n"


def generate_prompt(self, prompt_text: str) -> str:
"""Generates a single prompt using the language model.
@@ -110,7 +120,9 @@ def generate_prompt(self, prompt_text: str) -> str:
.replace("'", "")
)

return self.remove_incomplete_sentence(decoded_prompt)
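        # Trim the unfinished trailing sentence first, then strip any leftover "Caption reads" sentences.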
return self.remove_caption_sentences(
self.remove_incomplete_sentence(decoded_prompt)
)


if __name__ == "__main__":
186 changes: 186 additions & 0 deletions tests/integration/test_pipeline.py
@@ -356,6 +356,97 @@ def test_cuda_lm_sdxl_detection_pipeline():
_check_detection_pipeline(cmd, target_folder)


# =========================================================
# DETECTION - TinyLlama LLM
# =========================================================
@pytest.mark.skipif(
total_memory < 16 or total_disk_space < 35,
reason="Test requires at least 16GB of RAM and 35GB of HDD",
)
def test_cpu_tiny_sdxl_turbo_detection_pipeline():
# Define target folder
target_folder = "data/data-det-cpu-tiny-sdxl-turbo/"
# Define the command to run the datadreamer
cmd = (
f"datadreamer --save_dir {target_folder} "
f"--class_names alien mars cat "
f"--prompts_number 1 "
f"--prompt_generator tiny "
f"--num_objects_range 1 2 "
f"--image_generator sdxl-turbo "
f"--use_image_tester "
f"--device cpu"
)
# Check the run of the pipeline
_check_detection_pipeline(cmd, target_folder)


@pytest.mark.skipif(
not torch.cuda.is_available() or total_memory < 16 or total_disk_space < 35,
reason="Test requires GPU, at least 16GB of RAM and 35GB of HDD",
)
def test_cuda_tiny_sdxl_turbo_detection_pipeline():
# Define target folder
target_folder = "data/data-det-cuda-tiny-sdxl-turbo/"
# Define the command to run the datadreamer
cmd = (
f"datadreamer --save_dir {target_folder} "
f"--class_names alien mars cat "
f"--prompts_number 1 "
f"--prompt_generator tiny "
f"--num_objects_range 1 2 "
f"--image_generator sdxl-turbo "
f"--use_image_tester "
f"--device cuda"
)
# Check the run of the pipeline
_check_detection_pipeline(cmd, target_folder)


@pytest.mark.skipif(
total_memory < 16 or total_disk_space < 35,
reason="Test requires at least 16GB of RAM and 35GB of HDD",
)
def test_cpu_tiny_sdxl_detection_pipeline():
# Define target folder
target_folder = "data/data-det-cpu-tiny-sdxl/"
# Define the command to run the datadreamer
cmd = (
f"datadreamer --save_dir {target_folder} "
f"--class_names alien mars cat "
f"--prompts_number 1 "
f"--prompt_generator tiny "
f"--num_objects_range 1 2 "
f"--image_generator sdxl "
f"--use_image_tester "
f"--device cpu"
)
# Check the run of the pipeline
_check_detection_pipeline(cmd, target_folder)


@pytest.mark.skipif(
not torch.cuda.is_available() or total_memory < 16 or total_disk_space < 35,
reason="Test requires GPU, at least 16GB of RAM and 35GB of HDD",
)
def test_cuda_tiny_sdxl_detection_pipeline():
# Define target folder
target_folder = "data/data-det-cuda-tiny-sdxl/"
# Define the command to run the datadreamer
cmd = (
f"datadreamer --save_dir {target_folder} "
f"--class_names alien mars cat "
f"--prompts_number 1 "
f"--prompt_generator tiny "
f"--num_objects_range 1 2 "
f"--image_generator sdxl "
f"--use_image_tester "
f"--device cuda"
)
# Check the run of the pipeline
_check_detection_pipeline(cmd, target_folder)


# =========================================================
# CLASSIFICATION - SIMPLE LM
# =========================================================
@@ -544,3 +635,98 @@ def test_cuda_lm_sdxl_classification_pipeline():
)
# Check the run of the pipeline
_check_detection_pipeline(cmd, target_folder)


# =========================================================
# CLASSIFICATION - TinyLlama LLM
# =========================================================
@pytest.mark.skipif(
total_memory < 16 or total_disk_space < 35,
reason="Test requires at least 16GB of RAM and 35GB of HDD",
)
def test_cpu_tiny_sdxl_turbo_classification_pipeline():
# Define target folder
target_folder = "data/data-cls-cpu-tiny-sdxl-turbo/"
# Define the command to run the datadreamer
cmd = (
f"datadreamer --task classification "
f"--save_dir {target_folder} "
f"--class_names alien mars cat "
f"--prompts_number 1 "
f"--prompt_generator tiny "
f"--num_objects_range 1 2 "
f"--image_generator sdxl-turbo "
f"--use_image_tester "
f"--device cpu"
)
# Check the run of the pipeline
_check_detection_pipeline(cmd, target_folder)


@pytest.mark.skipif(
not torch.cuda.is_available() or total_memory < 16 or total_disk_space < 35,
reason="Test requires GPU, at least 16GB of RAM and 35GB of HDD",
)
def test_cuda_tiny_sdxl_turbo_classification_pipeline():
# Define target folder
target_folder = "data/data-cls-cuda-tiny-sdxl-turbo/"
# Define the command to run the datadreamer
cmd = (
f"datadreamer --task classification "
f"--save_dir {target_folder} "
f"--class_names alien mars cat "
f"--prompts_number 1 "
f"--prompt_generator tiny "
f"--num_objects_range 1 2 "
f"--image_generator sdxl-turbo "
f"--use_image_tester "
f"--device cuda"
)
# Check the run of the pipeline
_check_detection_pipeline(cmd, target_folder)


@pytest.mark.skipif(
total_memory < 16 or total_disk_space < 35,
reason="Test requires at least 16GB of RAM and 35GB of HDD",
)
def test_cpu_tiny_sdxl_classification_pipeline():
# Define target folder
target_folder = "data/data-cls-cpu-tiny-sdxl/"
# Define the command to run the datadreamer
cmd = (
f"datadreamer --task classification "
f"--save_dir {target_folder} "
f"--class_names alien mars cat "
f"--prompts_number 1 "
f"--prompt_generator tiny "
f"--num_objects_range 1 2 "
f"--image_generator sdxl "
f"--use_image_tester "
f"--device cpu"
)
# Check the run of the pipeline
_check_detection_pipeline(cmd, target_folder)


@pytest.mark.skipif(
not torch.cuda.is_available() or total_memory < 16 or total_disk_space < 35,
reason="Test requires GPU, at least 16GB of RAM and 35GB of HDD",
)
def test_cuda_tiny_sdxl_classification_pipeline():
# Define target folder
target_folder = "data/data-cls-cuda-tiny-sdxl/"
# Define the command to run the datadreamer
cmd = (
f"datadreamer --task classification "
f"--save_dir {target_folder} "
f"--class_names alien mars cat "
f"--prompts_number 1 "
f"--prompt_generator tiny "
f"--num_objects_range 1 2 "
f"--image_generator sdxl "
f"--use_image_tester "
f"--device cuda"
)
# Check the run of the pipeline
_check_detection_pipeline(cmd, target_folder)
23 changes: 21 additions & 2 deletions tests/unittests/test_prompt_generation.py
@@ -5,6 +5,9 @@
from datadreamer.prompt_generation.lm_prompt_generator import LMPromptGenerator
from datadreamer.prompt_generation.simple_prompt_generator import SimplePromptGenerator
from datadreamer.prompt_generation.synonym_generator import SynonymGenerator
from datadreamer.prompt_generation.tinyllama_lm_prompt_generator import (
TinyLlamaLMPromptGenerator,
)

# Get the total memory in GB
total_memory = psutil.virtual_memory().total / (1024**3)
@@ -32,9 +35,9 @@ def test_simple_prompt_generator():
assert prompt_text == f"A photo of a {', a '.join(selected_objects)}"


def _check_lm_prompt_generator(device: str):
def _check_lm_prompt_generator(device: str, prompt_generator_class=LMPromptGenerator):
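    # prompt_generator_class defaults to LMPromptGenerator, so the existing LM tests keep their behavior.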
object_names = ["aeroplane", "bicycle", "bird", "boat"]
prompt_generator = LMPromptGenerator(
prompt_generator = prompt_generator_class(
class_names=object_names, prompts_number=2, device=device
)
prompts = prompt_generator.generate_prompts()
@@ -73,6 +76,22 @@ def test_cpu_lm_prompt_generator():
_check_lm_prompt_generator("cpu")


@pytest.mark.skipif(
total_memory < 8 or not torch.cuda.is_available() or total_disk_space < 12,
reason="Test requires at least 8GB of RAM, 12GB of HDD and CUDA support",
)
def test_cuda_tinyllama_lm_prompt_generator():
_check_lm_prompt_generator("cuda", TinyLlamaLMPromptGenerator)


@pytest.mark.skipif(
total_memory < 12 or total_disk_space < 12,
reason="Test requires at least 12GB of RAM and 12GB of HDD for running on CPU",
)
def test_cpu_tinyllama_lm_prompt_generator():
_check_lm_prompt_generator("cpu", TinyLlamaLMPromptGenerator)


def _check_synonym_generator(device: str):
synonyms_num = 3
generator = SynonymGenerator(synonyms_number=synonyms_num, device=device)
