Skip to content

Commit

Permalink
added wiki import support
Browse files Browse the repository at this point in the history
  • Loading branch information
thijsi123 committed Oct 5, 2024
1 parent 96e9841 commit 9362b9d
Show file tree
Hide file tree
Showing 5 changed files with 415 additions and 12 deletions.
14 changes: 11 additions & 3 deletions .idea/workspace.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

54 changes: 53 additions & 1 deletion app 2/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from image_generation import *
from import_export import *
from utils import *
from imagecaption import get_sorted_general_strings # Adjusted import
from imagecaption import get_sorted_general_strings
from wiki import generate_character_summary_from_fandom


def find_image_path():
Expand All @@ -18,6 +19,7 @@ def find_image_path():

return None # Return None if no image is found


image_path = find_image_path()

if image_path:
Expand Down Expand Up @@ -47,6 +49,7 @@ def generate_tags_and_set_prompt(image):
tags = get_sorted_general_strings(image)
return tags


def create_webui():
with gr.Blocks() as webui:
gr.Markdown("# Character Factory WebUI")
Expand Down Expand Up @@ -199,6 +202,54 @@ def create_webui():
],
outputs=image_input,
)

with gr.Tab("Wiki Character"):
with gr.Column():
wiki_url = gr.Textbox(label="Fandom Wiki URL", placeholder="Enter the URL of the character's wiki page")
wiki_character_name = gr.Textbox(label="Character Name", placeholder="Enter the character's name")
wiki_topic = gr.Textbox(label="Topic/Series",
placeholder="Enter the series or topic (e.g., 'The Legend of Zelda')")
wiki_gender = gr.Textbox(label="Gender (optional)", placeholder="Enter the character's gender if known")
wiki_appearance = gr.Textbox(label="Appearance (optional)",
placeholder="Enter any specific appearance details")
wiki_nsfw = gr.Checkbox(label="Include NSFW content", value=False)

wiki_generate_button = gr.Button("Generate Character Summary from Wiki")

wiki_summary_output = gr.Textbox(label="Generated Character Summary", lines=10)

wiki_generate_button.click(
generate_character_summary_from_fandom,
inputs=[wiki_url, wiki_character_name, wiki_topic, wiki_gender, wiki_appearance, wiki_nsfw],
outputs=wiki_summary_output
)

wiki_update_button = gr.Button("Update Character with Wiki Summary")

wiki_update_button.click(
lambda wiki_name, wiki_summary: (wiki_name, wiki_summary),
inputs=[wiki_character_name, wiki_summary_output],
outputs=[name, summary]
)

def handle_wiki_generate(fandom_url, character_name, topic, gender, appearance, nsfw):
if not fandom_url:
return gr.Textbox.update(value="Error: Please provide a valid Fandom Wiki URL.", visible=True)

summary = generate_character_summary_from_fandom(fandom_url, character_name, topic, gender,
appearance, nsfw)

if summary.startswith("Error:") or summary.startswith("An error occurred"):
return gr.Textbox.update(value=summary, visible=True)
else:
return gr.Textbox.update(value=summary, visible=True)

wiki_generate_button.click(
handle_wiki_generate,
inputs=[wiki_url, wiki_character_name, wiki_topic, wiki_gender, wiki_appearance, wiki_nsfw],
outputs=wiki_summary_output
)

with gr.Tab("Import character"):
with gr.Column():
with gr.Row():
Expand Down Expand Up @@ -282,6 +333,7 @@ def create_webui():

return webui


# Add this at the end of the file to launch the interface
webui = create_webui()
webui.launch(debug=True)
9 changes: 1 addition & 8 deletions app 2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,15 @@
from diffusers import StableDiffusionPipeline
import torch
import sys

llm = None
sd = None
safety_checker_sd = None
global_avatar_prompt = ""
processed_image_path = None
global_url = None
global_url = "http://localhost:5001/api/v1/completions"

folder_path = "models" # Base directory for models


def load_model(model_name, use_safetensors=False, use_local=False):
global sd
# Enable TensorFloat-32 for matrix multiplications
Expand Down Expand Up @@ -44,13 +42,11 @@ def load_model(model_name, use_safetensors=False, use_local=False):
else:
print(f"Loaded {model_name} to CPU.")


def process_url(url):
global global_url
global_url = url.rstrip("/") + "/api/v1/completions" # Append '/v1/completions' to the URL
return f"URL Set: {global_url}" # Return the modified URL


def send_message(prompt):
global global_url
if not global_url:
Expand Down Expand Up @@ -113,15 +109,13 @@ def send_message(prompt):
except requests.RequestException as e:
return f"Error sending request: {e}"


def input_none(text):
user_input = text
if user_input == "":
return None
else:
return user_input


def combined_avatar_prompt_action(prompt):
# First, update the avatar prompt
global global_avatar_prompt
Expand All @@ -134,7 +128,6 @@ def combined_avatar_prompt_action(prompt):
# Return both messages or just one, depending on how you want to display the outcome
return update_message, use_message


# Load the model
load_model("oof.safetensors", use_safetensors=True, use_local=True)

Expand Down
120 changes: 120 additions & 0 deletions app 2/vector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import requests
from bs4 import BeautifulSoup
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import nltk
from nltk.tokenize import sent_tokenize
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Download necessary NLTK data
nltk.download('punkt', quiet=True)


class AdvancedVectorStorage:
def __init__(self):
self.model = SentenceTransformer('all-MiniLM-L6-v2')
self.vectors = []
self.texts = []

def add_text(self, text):
chunks = self.chunk_text(text)
for chunk in chunks:
self.texts.append(chunk)
self.vectors.append(self.model.encode(chunk))

def chunk_text(self, text, max_length=200):
words = text.split()
chunks = []
current_chunk = []
current_length = 0
for word in words:
if current_length + len(word) > max_length and current_chunk:
chunks.append(' '.join(current_chunk))
current_chunk = []
current_length = 0
current_chunk.append(word)
current_length += len(word) + 1 # +1 for space
if current_chunk:
chunks.append(' '.join(current_chunk))
return chunks

def query(self, query_text, top_k=5):
query_vector = self.model.encode(query_text)
similarities = cosine_similarity([query_vector], self.vectors)[0]
top_indices = similarities.argsort()[-top_k:][::-1]
return [(self.texts[i], similarities[i]) for i in top_indices]


def scrape_website(url):
response = requests.get(url)
soup = BeautifulSoup(response.content, 'html.parser')

main_content = soup.find('div', class_='mw-parser-output')

if main_content:
text_content = []
for element in main_content.find_all(['p', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6']):
text_content.append(element.get_text())
return ' '.join(text_content)
else:
return soup.get_text()


def answer_question_with_llm(full_context, embedding_context, question):
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = "jinaai/reader-lm-0.5b"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

prompt = f"""Please answer the following question based solely on the given context. If the answer is not in the context, say "I don't have enough information to answer this question."
Context: {embedding_context}
Question: {question}
Answer:"""

inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048).to(device)

outputs = model.generate(
**inputs,
max_new_tokens=50,
do_sample=True,
temperature=0.3,
top_p=0.95,
repetition_penalty=1
)

return tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip()


def answer_question(storage, full_context, question):
results = storage.query(question, top_k=5)
embedding_context = " ".join([text for text, _ in results])

print("Relevant context from embedding model:")
for text, similarity in results:
print(f"- {text} (Similarity: {similarity:.2f})")

llm_answer = answer_question_with_llm(full_context, embedding_context, question)

return llm_answer


# Example usage
storage = AdvancedVectorStorage()

url = 'https://zelda.fandom.com/wiki/Link'
content = scrape_website(url)

print(f"Content length: {len(content)}")

storage.add_text(content)

question = "Who is Link? What is Link personality? What is link summary? Link gender? What does Link look like? What is Link appearance?"
answer = answer_question(storage, content, question)

print(f"\nQuestion: {question}")
print(f"LLM Answer: {answer}")
Loading

0 comments on commit 9362b9d

Please sign in to comment.