From d8ec318274fd4d14317e9e8dc5434b66ecf75066 Mon Sep 17 00:00:00 2001 From: Roger Condori <114810545+R3gm@users.noreply.github.com> Date: Sun, 1 Dec 2024 03:36:20 +0000 Subject: [PATCH] fix: escaped symbols --- pyproject.toml | 2 +- stablepy/__version__.py | 2 +- stablepy/diffusers_vanilla/prompt_weights.py | 33 +++++++++++++++++--- 3 files changed, 30 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e48b31e..cf397f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "stablepy" -version = "0.5.1" +version = "0.5.2" description = "A tool for easy use of stable diffusion" authors = ["Roger Condori(R3gm) "] readme = "README.md" diff --git a/stablepy/__version__.py b/stablepy/__version__.py index dd9b22c..7225152 100644 --- a/stablepy/__version__.py +++ b/stablepy/__version__.py @@ -1 +1 @@ -__version__ = "0.5.1" +__version__ = "0.5.2" diff --git a/stablepy/diffusers_vanilla/prompt_weights.py b/stablepy/diffusers_vanilla/prompt_weights.py index 74d92b7..8bda45f 100644 --- a/stablepy/diffusers_vanilla/prompt_weights.py +++ b/stablepy/diffusers_vanilla/prompt_weights.py @@ -1,8 +1,22 @@ -# ===================================== -# Prompt weights -# ===================================== import torch import re + +ESCAPED_SYNTACTIC_SYMBOLS = [ + '"', + '(', + ')', + '=', + # '-', + # '+', + # '.', + # ',', +] + +TRANSLATION_DICT = { + ord(symbol): "\\" + symbol for symbol in ESCAPED_SYNTACTIC_SYMBOLS +} + + def parse_prompt_attention(text): re_attention = re.compile(r""" \\\(| @@ -74,9 +88,12 @@ def multiply_range(start_position, multiplier): return res + def prompt_attention_to_invoke_prompt(attention): tokens = [] for text, weight in attention: + text = text.translate(TRANSLATION_DICT) + # Round weight to 2 decimal places weight = round(weight, 2) if weight == 1.0: @@ -93,11 +110,13 @@ def prompt_attention_to_invoke_prompt(attention): tokens.append(f"({text}){weight}") return "".join(tokens) + def concat_tensor(t): t_list = torch.split(t, 1, dim=0) t = torch.cat(t_list, dim=1) return t + def merge_embeds(prompt_chanks, compel): num_chanks = len(prompt_chanks) if num_chanks != 0: @@ -111,6 +130,7 @@ def merge_embeds(prompt_chanks, compel): prompt_emb = compel('') return prompt_emb + def detokenize(chunk, actual_prompt): chunk[-1] = chunk[-1].replace('', '') chanked_prompt = ''.join(chunk).strip() @@ -119,10 +139,11 @@ def detokenize(chunk, actual_prompt): chanked_prompt = chanked_prompt.replace('', ' ', 1) else: chanked_prompt = chanked_prompt.replace('', '', 1) - actual_prompt = actual_prompt.replace(chanked_prompt,'') + actual_prompt = actual_prompt.replace(chanked_prompt, '') return chanked_prompt.strip(), actual_prompt.strip() -def tokenize_line(line, tokenizer): # split into chunks + +def tokenize_line(line, tokenizer): # split into chunks actual_prompt = line.lower().strip() actual_tokens = tokenizer.tokenize(actual_prompt) max_tokens = tokenizer.model_max_length - 2 @@ -154,6 +175,7 @@ def tokenize_line(line, tokenizer): # split into chunks return chunks + def get_embed_new(prompt, pipeline, compel, only_convert_string=False, compel_process_sd=False): if compel_process_sd: @@ -203,6 +225,7 @@ def get_embed_new(prompt, pipeline, compel, only_convert_string=False, compel_pr return merge_embeds([prompt_attention_to_invoke_prompt(i) for i in global_prompt_chanks], compel) + def add_comma_after_pattern_ti(text): pattern = re.compile(r'\b\w+_\d+\b') modified_text = pattern.sub(lambda x: x.group() + ',', text)