
Commit

fix: escaped symbols
R3gm committed Dec 1, 2024
1 parent 51550f1 commit d8ec318
Showing 3 changed files with 30 additions and 7 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "stablepy"
version = "0.5.1"
version = "0.5.2"
description = "A tool for easy use of stable diffusion"
authors = ["Roger Condori(R3gm) <[email protected]>"]
readme = "README.md"
2 changes: 1 addition & 1 deletion stablepy/__version__.py
@@ -1 +1 @@
__version__ = "0.5.1"
__version__ = "0.5.2"
33 changes: 28 additions & 5 deletions stablepy/diffusers_vanilla/prompt_weights.py
@@ -1,8 +1,22 @@
# =====================================
# Prompt weights
# =====================================
import torch
import re

+ESCAPED_SYNTACTIC_SYMBOLS = [
+    '"',
+    '(',
+    ')',
+    '=',
+    # '-',
+    # '+',
+    # '.',
+    # ',',
+]
+
+TRANSLATION_DICT = {
+    ord(symbol): "\\" + symbol for symbol in ESCAPED_SYNTACTIC_SYMBOLS
+}


def parse_prompt_attention(text):
    re_attention = re.compile(r"""
\\\(|
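
For context on the fix itself: str.translate() accepts a mapping from code points to replacement strings, so the TRANSLATION_DICT added above rewrites every listed symbol as a backslash-escaped copy before the prompt is re-serialized. A minimal sketch of the mechanism, using only the names introduced in this hunk:

    ESCAPED_SYNTACTIC_SYMBOLS = ['"', '(', ')', '=']
    TRANSLATION_DICT = {
        ord(symbol): "\\" + symbol for symbol in ESCAPED_SYNTACTIC_SYMBOLS
    }

    # Every occurrence of ", (, ) and = gains a leading backslash.
    print('a (photo) of "text"'.translate(TRANSLATION_DICT))
    # -> a \(photo\) of \"text\"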
@@ -74,9 +88,12 @@ def multiply_range(start_position, multiplier):

    return res


def prompt_attention_to_invoke_prompt(attention):
    tokens = []
    for text, weight in attention:
+        text = text.translate(TRANSLATION_DICT)

        # Round weight to 2 decimal places
        weight = round(weight, 2)
        if weight == 1.0:
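
The escaped text is then wrapped in Invoke-style attention markup. Assuming parse_prompt_attention yields [text, weight] pairs in the usual A1111 fashion, and that the folded branches handle weights near 1.0 differently (both assumptions, since those lines are collapsed here), a round trip could look like:

    attention = [['a photo of ', 1.0], ['flowers (roses)', 1.5]]
    print(prompt_attention_to_invoke_prompt(attention))
    # -> a photo of (flowers \(roses\))1.5
    # The user's literal parentheses are escaped, so only the outer
    # pair carries the 1.5 weight.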
@@ -93,11 +110,13 @@ def prompt_attention_to_invoke_prompt(attention):
                tokens.append(f"({text}){weight}")
    return "".join(tokens)


def concat_tensor(t):
    t_list = torch.split(t, 1, dim=0)
    t = torch.cat(t_list, dim=1)
    return t


def merge_embeds(prompt_chanks, compel):
    num_chanks = len(prompt_chanks)
    if num_chanks != 0:
@@ -111,6 +130,7 @@ def merge_embeds(prompt_chanks, compel):
        prompt_emb = compel('')
    return prompt_emb


def detokenize(chunk, actual_prompt):
    chunk[-1] = chunk[-1].replace('</w>', '')
    chanked_prompt = ''.join(chunk).strip()
@@ -119,10 +139,11 @@ def detokenize(chunk, actual_prompt):
            chanked_prompt = chanked_prompt.replace('</w>', ' ', 1)
        else:
            chanked_prompt = chanked_prompt.replace('</w>', '', 1)
-    actual_prompt = actual_prompt.replace(chanked_prompt,'')
+    actual_prompt = actual_prompt.replace(chanked_prompt, '')
    return chanked_prompt.strip(), actual_prompt.strip()
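
detokenize rebuilds surface text from CLIP word pieces ('</w>' marks a word boundary in the BPE vocabulary) and strips the rebuilt chunk from the front of the remaining prompt. An illustrative call with hypothetical values, assuming the folded while-loop turns each '</w>' into a space or nothing depending on the original spacing:

    chunk = ['a</w>', 'photo</w>', 'of</w>', 'flowers']
    rebuilt, rest = detokenize(chunk, 'a photo of flowers, high detail')
    # rebuilt -> 'a photo of flowers'
    # rest    -> ', high detail'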

-def tokenize_line(line, tokenizer): # split into chunks

+def tokenize_line(line, tokenizer):  # split into chunks
    actual_prompt = line.lower().strip()
    actual_tokens = tokenizer.tokenize(actual_prompt)
    max_tokens = tokenizer.model_max_length - 2
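
The model_max_length - 2 budget reserves the two slots the text encoder spends on its begin/end tokens, which for CLIP tokenizers leaves the familiar 75 usable tokens per chunk. For example (checkpoint name only illustrative):

    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    print(tokenizer.model_max_length)      # 77
    print(tokenizer.model_max_length - 2)  # 75 tokens of prompt per chunk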
@@ -154,6 +175,7 @@ def tokenize_line(line, tokenizer):  # split into chunks

    return chunks


def get_embed_new(prompt, pipeline, compel, only_convert_string=False, compel_process_sd=False):

    if compel_process_sd:
@@ -203,6 +225,7 @@ def get_embed_new(prompt, pipeline, compel, only_convert_string=False, compel_process_sd=False):

    return merge_embeds([prompt_attention_to_invoke_prompt(i) for i in global_prompt_chanks], compel)
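
get_embed_new converts the prompt to Invoke syntax, chunks it to fit the tokenizer, and merges the per-chunk embeddings through compel. A usage sketch, assuming a standard diffusers pipeline and the Compel constructor from the compel library (neither appears in this diff):

    from compel import Compel
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    compel = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)

    # Full path: weighted prompt -> merged embedding tensor.
    emb = get_embed_new('a (masterpiece:1.3) photo of "roses"', pipe, compel)

    # Conversion only: returns the escaped, Invoke-style prompt string.
    print(get_embed_new('a (masterpiece:1.3) photo', pipe, compel, only_convert_string=True))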


def add_comma_after_pattern_ti(text):
    pattern = re.compile(r'\b\w+_\d+\b')
    modified_text = pattern.sub(lambda x: x.group() + ',', text)
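
The pattern matches name_number identifiers, which is the usual shape of multi-vector textual-inversion tokens, and appends a comma after each match so they stay delimited once weights are rewritten. For example:

    import re

    pattern = re.compile(r'\b\w+_\d+\b')
    print(pattern.sub(lambda x: x.group() + ',', 'portrait easynegative_1 sharp focus'))
    # -> portrait easynegative_1, sharp focus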
