-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathchat_with_doc.py
115 lines (95 loc) · 4.37 KB
/
chat_with_doc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import tempfile
from pathlib import Path

import streamlit as st
import tiktoken
from dotenv import load_dotenv, find_dotenv
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
def load_document(file):
    """Load a .pdf, .docx or .txt file with the matching LangChain loader.

    The extension is compared case-insensitively, so ``Report.PDF`` is
    treated the same as ``report.pdf``.

    Args:
        file: Path to the document on disk.

    Returns:
        The documents returned by the selected loader's ``load()``, or
        ``None`` (after printing a message) when the extension is not one of
        ``.pdf``, ``.docx`` or ``.txt``.
    """
    _, extension = os.path.splitext(file)
    # Normalise case so upper-case extensions are not rejected.
    extension = extension.lower()

    if extension not in ('.pdf', '.docx', '.txt'):
        print(f"Document format for {file} not supported")
        return None

    print(f'loading {file}')
    if extension == '.pdf':
        loader = PyPDFLoader(file)
    elif extension == '.docx':
        loader = Docx2txtLoader(file)
    else:  # '.txt'
        loader = TextLoader(file)
    return loader.load()
def chunk_data(data, chunk_size=256, chunk_overlap=20):
    """Split loaded documents into overlapping chunks.

    Args:
        data: Documents to split (as produced by ``load_document``).
        chunk_size: Maximum chunk length, in the splitter's length units
            (characters by default for ``RecursiveCharacterTextSplitter``).
        chunk_overlap: How much consecutive chunks overlap, in the same units.

    Returns:
        The list returned by ``RecursiveCharacterTextSplitter.split_documents``.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_documents(data)
def create_embeddings(chunks):
    """Embed ``chunks`` with OpenAI embeddings and index them in Chroma.

    No persistence directory is passed to ``Chroma.from_documents``, so the
    store is presumably kept only for this process. Confirm this against the
    installed LangChain/Chroma version.

    Returns:
        The newly built Chroma vector store.
    """
    embedder = OpenAIEmbeddings()
    return Chroma.from_documents(chunks, embedder)
def ask_and_get_answer(vector_store, q, k=3):
    """Answer question ``q`` from the ``k`` most similar chunks in the store.

    A fresh ``gpt-3.5-turbo`` chat model (temperature 1) and a ``RetrievalQA``
    chain of type ``"stuff"`` are built on every call.

    Args:
        vector_store: Store exposing ``as_retriever`` (e.g. the Chroma store
            from ``create_embeddings``).
        q: The user's question.
        k: Number of chunks retrieved by similarity search.

    Returns:
        The result of running the chain on ``q`` (the answer text).
    """
    chat_model = ChatOpenAI(model='gpt-3.5-turbo', temperature=1)
    similarity_retriever = vector_store.as_retriever(
        search_type='similarity',
        search_kwargs={'k': k},
    )
    qa_chain = RetrievalQA.from_chain_type(
        llm=chat_model,
        chain_type="stuff",
        retriever=similarity_retriever,
    )
    return qa_chain.run(q)
def calculate_embedding_cost(texts, price_per_1k_tokens=0.0004,
                             model='text-embedding-ada-002'):
    """Estimate how many tokens embedding ``texts`` uses and what it costs.

    Args:
        texts: Iterable of objects with a ``page_content`` string, such as
            the chunks produced by ``chunk_data``.
        price_per_1k_tokens: USD price per 1,000 tokens. The default of
            0.0004 is presumably the text-embedding-ada-002 rate at the time
            of writing. OpenAI pricing changes, so verify it against the
            current price list.
        model: Model name used to select the tiktoken encoding.

    Returns:
        ``(total_tokens, cost)``, where ``cost`` is
        ``total_tokens / 1000 * price_per_1k_tokens``.
    """
    encoding = tiktoken.encoding_for_model(model)
    # Generator expression: no throwaway list of per-chunk counts.
    total_tokens = sum(len(encoding.encode(page.page_content)) for page in texts)
    return total_tokens, total_tokens / 1000 * price_per_1k_tokens
def clear_history():
    """Remove the accumulated chat history from Streamlit session state.

    Registered as the ``on_change``/``on_click`` callback of the sidebar's
    chunk-size input, ``k`` input and "Add data" button. Does nothing when no
    history has been recorded yet.
    """
    if 'history' not in st.session_state:
        return
    del st.session_state['history']
if __name__ == "__main__":
    # Load variables from the nearest .env file; override=True lets them
    # replace values already present in the process environment.
    load_dotenv(find_dotenv(), override=True)

    st.image('static/img.png')
    st.subheader("LLM Question-Answering App")

    with st.sidebar:
        api_key = st.text_input('OpenAI API key:', type="password")
        if api_key:
            # Environment variable names are case-sensitive on Linux/macOS and
            # the OpenAI client reads OPENAI_API_KEY. A mixed-case name would
            # be silently ignored there.
            os.environ["OPENAI_API_KEY"] = api_key

        uploaded_file = st.file_uploader("upload a file",
                                         type=['pdf', 'docx', 'txt'])
        # Changing a retrieval parameter or adding data resets the history.
        chunk_size = st.number_input('chunk size', min_value=100,
                                     max_value=2048, value=512,
                                     on_change=clear_history)
        k = st.number_input('k', min_value=1, max_value=20, value=3,
                            on_change=clear_history)
        add_data = st.button('Add data', on_click=clear_history)

        if uploaded_file and add_data:
            with st.spinner('Reading, chunking and embedding file ...'):
                # Keep only the base name so a crafted upload name cannot
                # point outside the staging directory. Stage the file in a
                # temporary directory so uploads are not left behind in the
                # working directory.
                safe_name = Path(uploaded_file.name).name
                with tempfile.TemporaryDirectory() as tmp_dir:
                    file_name = os.path.join(tmp_dir, safe_name)
                    with open(file_name, 'wb') as f:
                        # getvalue() returns the whole upload regardless of
                        # the buffer's current read position.
                        f.write(uploaded_file.getvalue())
                    data = load_document(file_name)

                if data is None:
                    st.error(f"Document format for {safe_name} not supported")
                else:
                    chunks = chunk_data(data, chunk_size=chunk_size)
                    st.write(f"Chunk size: {chunk_size}, Chunks: {len(chunks)}")
                    _, embedding_cost = calculate_embedding_cost(chunks)
                    st.write(f'Embedding cost: ${embedding_cost:.4f}')
                    vector_store = create_embeddings(chunks)
                    # Keep the index across Streamlit reruns.
                    st.session_state.vs = vector_store
                    st.success('File uploaded, chunked and embedded successfully')

    q = st.text_input("Ask a question about the content of your file")
    if q:
        if 'vs' not in st.session_state:
            st.warning('Add a document in the sidebar before asking a question.')
        else:
            vector_store = st.session_state.vs
            st.write(f'k: {k}')
            answer = ask_and_get_answer(vector_store, q, k)
            st.text_area('LLM Answer: ', value=answer)
            st.divider()

            if 'history' not in st.session_state:
                st.session_state.history = ''
            # Newest exchange first, separated by a line of asterisks.
            value = f'Question: {q} \nAnswer: {answer}'
            st.session_state.history = f'{value} \n {"*" * 50} \n {st.session_state.history}'
            h = st.session_state.history
            # No key='history' here: binding the widget to the same
            # session-state key assigned just above makes Streamlit warn about
            # a widget whose value is also set via the Session State API.
            st.text_area(label='Chat History', value=h, height=400)