-
Notifications
You must be signed in to change notification settings - Fork 51
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
193 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# MariaDB KB Chat and Vector Store Generator | ||
|
||
This script scrapes web pages from the MariaDB Knowledge Base, cleans and processes the content, and then generates a FAISS index using the OpenAI embeddings for each document. The vector store is saved as a pickle file and then used by a chatbot to answer questions about the MariaDB server. | ||
|
||
## Requirements | ||
|
||
Install the required packages with the following command: | ||
|
||
pip install argparse bs4 dotenv faiss-cpu openai requests numpy streamlit | ||
|
||
## Setup | ||
|
||
1. Download the MariaDB KB CSV file from https://github.com/Icerath/mariadb_kb_server/blob/main/kb_urls.csv | ||
2. Create a `.env` file in the same directory as the script. | ||
3. Add your OpenAI API key to the `.env` file as follows: | ||
|
||
OPENAI_API_KEY=your_api_key_here | ||
|
||
## Preprocessing | ||
|
||
Run the script with the following command: | ||
|
||
python create_vectorestore.py --csv-file kb_urls.csv --tmp-dir tmp --md-dir md --vectorstore-path vectorstore.pkl --chunk-size 4000 --chunk-overlap 200 | ||
|
||
This will create a file `vectorestore.pkl` which is used to answer questions | ||
|
||
## Run chat | ||
|
||
streamlit run chat.py | ||
|
||
Now, you will have a self hosted version of the chat over the MariaDB KB. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import streamlit as st | ||
import pickle | ||
from langchain.vectorstores import FAISS | ||
from dotenv import load_dotenv | ||
import openai | ||
import os | ||
|
||
load_dotenv() | ||
|
||
openai.api_key = os.getenv("OPENAI_API_KEY") | ||
|
||
def gen_prompts(content, question): | ||
system_msg_content = "You are a questioning answering expert about MariaDB. You only respond based on the facts that are given to you and ignore your prior knowledge." | ||
user_msg_content = f"{content}\n---\n\nGiven the above content about MariaDB along with the URL of the content, respond to this question {question} and mention the URL as a source. If the question is not about MariaDB and you cannot answer it based on the provided content, politely decline to answer. Simply state that you couldn't find any relevant information instead of going into details. Do not say the phrase 'in the provided content'. If the information I provide contains the word obsolete, emphasize that the response is obsolete. Also, suggest newer MariaDB versions if the question is about versions older than 10.3 and say that the others are no longer maintained. Do not add the URL as a source if you cannot answer based on the provided content. If there are exceptions for particular MariaDB version, specify the exceptions that apply. Also, if the provided score is lower than 0.2 decline to answer and say you found no relevant information. If the source URL repeats, only use it once." | ||
system_msg = {"role": "system", "content": system_msg_content} | ||
user_msg = {"role": "user", "content": user_msg_content} | ||
|
||
return system_msg, user_msg | ||
|
||
def process_doc(content, question, model_type="gpt-4", max_tokens=30000): | ||
if len(content) > max_tokens: | ||
print('Trimmed') | ||
content = content[:max_tokens] | ||
system_msg, user_msg = gen_prompts(content, question) | ||
|
||
try: | ||
response = openai.ChatCompletion.create( | ||
model=model_type, | ||
messages=[system_msg, user_msg], | ||
) | ||
except Exception as e: | ||
return "Sorry, there was an error. Please try again!" | ||
|
||
result = response.choices[0].message['content'] | ||
return result | ||
|
||
with open("vectorstore.pkl", "rb") as f: | ||
faiss_index = pickle.load(f) | ||
|
||
def search_similar_docs(question, k=4): | ||
docs = faiss_index.similarity_search_with_score(question, k=k) | ||
docs_with_url = [] | ||
for doc in docs: | ||
url = doc[0].metadata["source"] | ||
doc[0].page_content = f"URL: {url}\n{doc[0].page_content}\nSCORE:{doc[1]}\n" | ||
docs_with_url.append(doc[0]) | ||
print(docs) | ||
return docs_with_url | ||
|
||
def main(): | ||
st.title("MariaDB KB Chatbot") | ||
|
||
if 'chat_history' not in st.session_state: | ||
st.session_state.chat_history = [] | ||
|
||
user_input = st.text_input("Ask a question:", "") | ||
if st.button("Send"): | ||
st.session_state.chat_history.append(("User", user_input)) | ||
results = process_doc(search_similar_docs(user_input), user_input) | ||
|
||
st.session_state.chat_history.append(("Bot", results)) | ||
|
||
for role, message in st.session_state.chat_history: | ||
if role == "User": | ||
st.markdown(f"> **{role}**: {message}") | ||
else: | ||
st.markdown(f"**{role}**: {message}") | ||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import argparse | ||
import pickle | ||
import os | ||
import csv | ||
import openai | ||
import re | ||
import requests | ||
|
||
from langchain.document_loaders import BSHTMLLoader | ||
from langchain.vectorstores import FAISS | ||
from langchain.text_splitter import CharacterTextSplitter | ||
from langchain.embeddings.openai import OpenAIEmbeddings | ||
from dotenv import load_dotenv | ||
|
||
load_dotenv() | ||
openai.api_key = os.getenv("OPENAI_API_KEY") | ||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser(description='MariaDB KB Vector Store Generator') | ||
parser.add_argument('--csv-file', type=str, default='kb_urls.csv', help='Path to the input CSV file containing the URLs') | ||
parser.add_argument('--tmp-dir', type=str, default='tmp', help='Directory where the temporary HTML files will be stored') | ||
parser.add_argument('--md-dir', type=str, default='md', help='Directory where the output Markdown files will be stored') | ||
parser.add_argument('--vectorstore-path', type=str, default='vectorstore.pkl', help='Path to save the generated FAISS vector store pickle file') | ||
parser.add_argument('--chunk-size', type=int, default=4000, help='Chunk size for splitting the documents') | ||
parser.add_argument('--chunk-overlap', type=int, default=200, help='Overlap size between chunks when splitting documents') | ||
return parser.parse_args() | ||
|
||
def download_web_page(url): | ||
response = requests.get(url) | ||
|
||
if response.status_code == 200: | ||
content = response.text | ||
filename = url.replace('://', '_').replace('/', '_') + '.html' | ||
|
||
with open('./tmp/' + filename, 'w', encoding='utf-8') as file: | ||
file.write(content) | ||
else: | ||
print(f"Error: Unable to fetch the web page. Status code: {response.status_code}") | ||
|
||
def read_csv(csv_file): | ||
urls = [] | ||
|
||
with open(csv_file, newline='', encoding='utf-8') as csvfile: | ||
csv_reader = csv.reader(csvfile) | ||
for row in csv_reader: | ||
if row[0].strip(): | ||
urls.append(row[0]) | ||
|
||
return urls[1:] | ||
|
||
def main(): | ||
args = parse_args() | ||
|
||
urls = read_csv(args.csv_file) | ||
all_docs = [] | ||
idx = 0 | ||
for url in urls: | ||
filename = url.replace('://', '_').replace('/', '_').strip() + '.html' | ||
doc_path = args.tmp_dir + '/' + filename | ||
if not os.path.exists(doc_path): | ||
download_web_page(url) | ||
loader = BSHTMLLoader(doc_path) | ||
doc = loader.load()[0] | ||
|
||
content = re.sub(r'\s+', ' ', doc.page_content) | ||
doc.page_content = content | ||
doc.metadata["source"] = url | ||
|
||
md_filename = os.path.join(args.md_dir, f'{filename}.md') | ||
|
||
with open(md_filename, 'w', encoding='utf-8') as md_file: | ||
md_file.write(doc.page_content) | ||
|
||
all_docs.append(doc) | ||
|
||
text_splitter = CharacterTextSplitter( | ||
separator = " ", | ||
chunk_size = args.chunk_size, | ||
chunk_overlap = args.chunk_overlap, | ||
length_function = len, | ||
) | ||
print("Loaded {} documents".format(len(all_docs))) | ||
all_docs = text_splitter.split_documents(all_docs) | ||
print("After split: {} documents".format(len(all_docs))) | ||
|
||
faiss_index = FAISS.from_documents(all_docs, OpenAIEmbeddings()) | ||
|
||
with open(args.vectorstore_path, "wb") as f: | ||
pickle.dump(faiss_index, f) | ||
|
||
if __name__ == "__main__": | ||
main() |