-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial commit for Building Art Deco RAG ChatBot using PulseJet
- Loading branch information
0 parents
commit f47e22f
Showing
21 changed files
with
2,149 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
__pycache__/ | ||
indexing.log | ||
embeddings_data/ | ||
evaluation/questions2.csv | ||
secrets.yaml | ||
.idea/ | ||
rag_files/* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
|
||
ollama: | ||
docker run -d -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama |
Large diffs are not rendered by default.
Oops, something went wrong.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import os | ||
import time | ||
from litellm import completion | ||
import logging | ||
from file_utils import get_config, read_questions | ||
from data_saving import save_answers_json, save_answers_csv, save_answers_html, save_answers_markdown | ||
import rag | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def generate_answers(questions, config, clients): | ||
answers_data = [] | ||
total_questions = len(questions) | ||
for idx, question in enumerate(questions, 1): | ||
print(f"Processing question {idx}/{total_questions}: '{question}'") | ||
question_answers = {'question': question, 'answers': []} | ||
for model_name, client in clients.items(): | ||
print(f"Querying {model_name}...") | ||
result = client(question) | ||
answer = result['response'] | ||
llm_duration = max(int(result['llm_duration'] * 1000), -1) | ||
rag_duration = max(int(result['rag_duration'] * 1000), -1) | ||
question_answers['answers'].append({'model': model_name, 'answer': answer, | ||
'llm_duration': llm_duration, 'rag_duration': rag_duration}) | ||
answers_data.append(question_answers) | ||
print(f"Completed question {idx}/{total_questions}.\n") | ||
return answers_data | ||
|
||
|
||
def ask_llm(model, query): | ||
base_url = None | ||
if model.startswith('ollama'): | ||
base_url = "http://localhost:11434" | ||
|
||
start_time = time.perf_counter() | ||
response = completion( | ||
model=model, | ||
messages=[ | ||
{"role": "user", "content": query}, | ||
], | ||
api_base=base_url | ||
) | ||
end_time = time.perf_counter() | ||
duration = end_time - start_time | ||
|
||
return {"response": response.choices[0].message.content, "llm_duration": duration, "rag_duration": -1} | ||
|
||
|
||
def print_and_return(result): | ||
print("RAG Response:") | ||
print(result['response']) | ||
print(f"LLM Duration: {result['llm_duration']:.2f} seconds") | ||
print(f"RAG Duration: {result['rag_duration']:.2f} seconds") | ||
print("--------------------") | ||
return result | ||
|
||
|
||
def main(): | ||
config = get_config() | ||
file_path = config['questions_file_path'] | ||
questions = read_questions(file_path) | ||
|
||
os.environ["OPENAI_API_KEY"] = config['openai_key'] | ||
os.environ['GROQ_API_KEY'] = config['groq_key'] | ||
|
||
all_models = config['all_models'] | ||
selected_models = config['selected_models'] | ||
|
||
clients = {} | ||
for model in selected_models: | ||
clients[model] = lambda q, m=model: print_and_return( | ||
ask_llm(all_models[m], q)) | ||
clients['ollama_rag'] = lambda q: print_and_return( | ||
rag.rag(config, q)) | ||
|
||
try: | ||
answers_data = generate_answers(questions, config, clients) | ||
save_answers_json(answers_data, os.path.join( | ||
config['evaluation_path'], 'answers.json')) | ||
save_answers_csv(answers_data, os.path.join( | ||
config['evaluation_path'], 'answers.csv')) | ||
save_answers_html(answers_data, os.path.join( | ||
config['evaluation_path'], 'answers.html')) | ||
save_answers_markdown(answers_data, os.path.join( | ||
config['evaluation_path'], 'answers.md')) | ||
except Exception as e: | ||
logger.exception("An error occurred during execution:") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
#api_keys: | ||
openai_key: "" | ||
groq_key: "" | ||
|
||
#models: | ||
main_model: "llama3.1" | ||
embed_model: "nomic-embed-text" | ||
|
||
#vector_db: | ||
vector_db: "pulsejet" | ||
|
||
#pulsejet: | ||
pulsejet_location: "remote" | ||
pulsejet_collection_name: "art_deco" | ||
|
||
#paths: | ||
rag_files_path: "rag_files/" | ||
questions_file_path: "evaluation/questions.csv" | ||
evaluation_path: "evaluation/" | ||
rag_prompt_path: "evaluation/rag_prompt.txt" | ||
metrics_file_path: "evaluation/metrics.json" | ||
|
||
#embeddings: | ||
embeddings_file_path: "embeddings_data/all_embeddings_HSNW.h5" | ||
use_precalculated_embeddings: true | ||
|
||
#llm_models: | ||
all_models: | ||
gpt-4o: "gpt-4o" | ||
groq-llama3.1-8b: "groq/llama-3.1-8b-instant" | ||
groq-llama3.1-70b: "groq/llama-3.1-70b-versatile" | ||
ollama-llama3.1: "ollama/llama3.1" | ||
ollama-llama3.1-70b: "ollama/llama3.1:70b" | ||
|
||
selected_models: | ||
- "gpt-4o" | ||
- "groq-llama3.1-70b" | ||
- "ollama-llama3.1" | ||
|
||
#rag_parameters: | ||
sentences_per_chunk: 10 | ||
chunk_overlap: 2 | ||
file_extension: ".txt" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
import json | ||
import csv | ||
|
||
|
||
def save_answers_json(answers, output_path): | ||
with open(output_path, 'w') as file: | ||
json.dump(answers, file, indent=4) | ||
|
||
|
||
def save_answers_csv(json_data, output_path): | ||
with open(output_path, mode='w', newline='') as file: | ||
writer = csv.writer(file) | ||
header = ['Question'] | ||
# Add a column for each model in the first entry's answers | ||
if json_data: | ||
for answer in json_data[0]['answers']: | ||
header.append(answer['model']) | ||
if answer['llm_duration'] != -1: | ||
header.append(answer['model'] + ' LLM Duration') | ||
if answer['rag_duration'] != -1: | ||
header.append(answer['model'] + ' RAG Duration') | ||
writer.writerow(header) | ||
|
||
for item in json_data: | ||
row = [item['question']] | ||
for answer in item['answers']: | ||
row.append(answer['answer']) | ||
if answer['llm_duration'] != -1: | ||
row.append(answer['llm_duration']) | ||
if answer['rag_duration'] != -1: | ||
row.append(answer['rag_duration']) | ||
writer.writerow(row) | ||
|
||
|
||
def save_answers_html(json_data, output_path): | ||
if json_data: | ||
# Start with one column for questions | ||
num_columns = 1 + len(json_data[0]['answers']) | ||
for answer in json_data[0]['answers']: | ||
if answer['llm_duration'] != -1: | ||
num_columns += 1 | ||
if answer['rag_duration'] != -1: | ||
num_columns += 1 | ||
|
||
# Calculate the percentage width for each column | ||
col_width = 100 / num_columns | ||
|
||
table_style = 'width: 100%; border="1" style="border-collapse: collapse;"' | ||
th_style = f'style="padding: 8px; vertical-align: top; width: {col_width}%;"' | ||
td_style = f'style="padding: 8px; vertical-align: top; width: {col_width}%;"' | ||
|
||
html_content = f'<table {table_style}>\n<tr><th {th_style}>Questions</th>' | ||
if json_data: | ||
for answer in json_data[0]['answers']: | ||
html_content += f'<th {th_style}>{answer["model"]}</th>' | ||
if answer['llm_duration'] != -1: | ||
html_content += f'<th {th_style}>{answer["model"]} LLM Duration</th>' | ||
if answer['rag_duration'] != -1: | ||
html_content += f'<th {th_style}>{answer["model"]} RAG Duration</th>' | ||
html_content += '</tr>\n' | ||
|
||
for item in json_data: | ||
row = f'<tr><td {td_style}>{item["question"]}</td>' | ||
for answer in item['answers']: | ||
row += f'<td {td_style}>{format_html(answer["answer"])}</td>' | ||
if answer['llm_duration'] != -1: | ||
row += f'<td {td_style}>{answer["llm_duration"]}</td>' | ||
if answer['rag_duration'] != -1: | ||
row += f'<td {td_style}>{answer["rag_duration"]}</td>' | ||
row += '</tr>\n' | ||
html_content += row | ||
html_content += '</table>' | ||
|
||
with open(output_path, 'w') as file: | ||
file.write(html_content) | ||
|
||
|
||
def format_html(text): | ||
"A more comprehensive function to format text with HTML tags based on Markdown syntax including lists." | ||
# Define replacements for simple Markdown syntax | ||
replacements = { | ||
'**': '<b>', | ||
'__': '<b>', | ||
'*': '<i>', | ||
'_': '<i>', | ||
'```': '<code>', | ||
'`': '<code>', | ||
'> ': '<blockquote>', | ||
'\n': '<br>', | ||
'# ': '<h1>', | ||
'## ': '<h2>', | ||
'### ': '<h3>', | ||
'#### ': '<h4>', | ||
'##### ': '<h5>', | ||
'###### ': '<h6>', | ||
} | ||
|
||
# Apply replacements | ||
for md, html in replacements.items(): | ||
text = text.replace(md, html) | ||
|
||
# Handle unordered lists | ||
lines = text.split('<br>') | ||
in_list = False | ||
for i, line in enumerate(lines): | ||
if line.startswith('* ') or line.startswith('- ') or line.startswith('+ '): | ||
if not in_list: | ||
lines[i] = '<ul><li>' + line[2:] + '</li>' | ||
in_list = True | ||
else: | ||
lines[i] = '<li>' + line[2:] + '</li>' | ||
else: | ||
if in_list: | ||
lines[i - 1] = lines[i - 1] + '</ul>' | ||
in_list = False | ||
|
||
if in_list: | ||
lines[-1] += '</ul>' | ||
|
||
# Handle ordered lists | ||
in_list = False | ||
for i, line in enumerate(lines): | ||
if line.lstrip().startswith(tuple(f'{num}.' for num in range(1, 10))): | ||
if not in_list: | ||
lines[i] = '<ol><li>' + line.split('. ', 1)[1] + '</li>' | ||
in_list = True | ||
else: | ||
lines[i] = '<li>' + line.split('. ', 1)[1] + '</li>' | ||
else: | ||
if in_list: | ||
lines[i - 1] = lines[i - 1] + '</ol>' | ||
in_list = False | ||
|
||
if in_list: | ||
lines[-1] += '</ol>' | ||
|
||
return '<br>'.join(lines) | ||
|
||
|
||
def save_answers_markdown(json_data, output_path): | ||
with open(output_path, 'w') as file: | ||
# Write the table header | ||
header = '| Question |' | ||
if json_data: | ||
for answer in json_data[0]['answers']: | ||
header += f" {answer['model']} |" | ||
if answer['llm_duration'] != -1: | ||
header += f" {answer['model']} LLM Duration |" | ||
if answer['rag_duration'] != -1: | ||
header += f" {answer['model']} RAG Duration |" | ||
file.write(header + '\n') | ||
|
||
# Write the separator line | ||
separator = '|' + '---|' * (header.count('|') - 1) + '\n' | ||
file.write(separator) | ||
|
||
# Write the data rows | ||
for item in json_data: | ||
row = f"| {escape_markdown(item['question'])} |" | ||
for answer in item['answers']: | ||
row += f" {escape_markdown(answer['answer'])} |" | ||
if answer['llm_duration'] != -1: | ||
row += f" {answer['llm_duration']} ms |" | ||
if answer['rag_duration'] != -1: | ||
row += f" {answer['rag_duration']} ms |" | ||
file.write(row + '\n') | ||
|
||
|
||
def escape_markdown(text): | ||
"""Escapes markdown special characters and formats for table cells.""" | ||
text = text.replace('|', '\\|') | ||
text = text.replace('\n', ' ') # Replace newlines with spaces | ||
text = text.replace('\r', '') # Remove carriage returns | ||
return text |
Oops, something went wrong.