-
Notifications
You must be signed in to change notification settings - Fork 119
/
Copy pathmain.py
117 lines (97 loc) · 3.57 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# %% [markdown]
# # Code documentation Q&A bot example with LangChain
#
# This Q&A bot will allow you to query your own documentation easily using questions. We'll also demonstrate the use of LangChain and LanceDB using the OpenAI API.
#
# In this example we'll use Pandas 2.0 documentation, but, this could be replaced for your own docs as well
import os
import openai
import argparse
import lancedb
import re
import pickle
import requests
import zipfile
from pathlib import Path
from langchain.document_loaders import BSHTMLLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import LanceDB
from langchain.llms import OpenAI
from langchain.chains import RetrievalQA
def get_document_title(document):
m = str(document.metadata["source"])
title = re.findall("pandas.documentation(.*).html", m)
if title[0] is not None:
return title[0]
return ""
def arg_parse():
default_query = "What are the major differences in pandas 2.0?"
parser = argparse.ArgumentParser(description="Code Documentation QA Bot")
parser.add_argument(
"--query", type=str, default=default_query, help="query to search"
)
parser.add_argument("--openai-key", type=str, help="OpenAI API Key")
args = parser.parse_args()
if not args.openai_key:
if "OPENAI_API_KEY" not in os.environ:
raise ValueError(
"OPENAI_API_KEY environment variable not set. Please set it or pass --openai_key"
)
else:
openai.api_key = args.openai_key
return args
if __name__ == "__main__":
args = arg_parse()
docs_path = Path("docs.pkl")
docs = []
pandas_docs = requests.get(
"https://eto-public.s3.us-west-2.amazonaws.com/datasets/pandas_docs/pandas.documentation.zip"
)
with open("/tmp/pandas.documentation.zip", "wb") as f:
f.write(pandas_docs.content)
file = zipfile.ZipFile("/tmp/pandas.documentation.zip")
file.extractall(path="/tmp/pandas_docs")
if not docs_path.exists():
for p in Path("/tmp/pandas_docs/pandas.documentation").rglob("*.html"):
print(p)
if p.is_dir():
continue
loader = BSHTMLLoader(p, open_encoding="utf8")
raw_document = loader.load()
m = {}
m["title"] = get_document_title(raw_document[0])
m["version"] = "2.0rc0"
raw_document[0].metadata = raw_document[0].metadata | m
raw_document[0].metadata["source"] = str(raw_document[0].metadata["source"])
docs = docs + raw_document
with docs_path.open("wb") as fh:
pickle.dump(docs, fh)
else:
with docs_path.open("rb") as fh:
docs = pickle.load(fh)
print("Loaded {} documents".format(len(docs)))
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=200,
)
documents = text_splitter.split_documents(docs)
embeddings = OpenAIEmbeddings()
db = lancedb.connect("/tmp/lancedb")
table = db.create_table(
"pandas_docs",
data=[
{
"vector": embeddings.embed_query("Hello World"),
"text": "Hello World",
"id": "1",
}
],
mode="overwrite",
)
docsearch = LanceDB.from_documents(documents, embeddings, connection=table)
qa = RetrievalQA.from_chain_type(
llm=OpenAI(), chain_type="stuff", retriever=docsearch.as_retriever()
)
result = qa.run(args.query)
print(result)