chatbot_logic.py
import os
from dotenv import load_dotenv
from pinecone import Pinecone
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains import create_history_aware_retriever
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.runnables import RunnableLambda
from langchain_community.document_transformers import LongContextReorder
from langchain_pinecone import PineconeVectorStore
from operator import itemgetter
# Load environment variables from the .env file
load_dotenv()

# Read the required API keys from the environment
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# Pinecone setup
def initialize_pinecone():
    pc = Pinecone(api_key=PINECONE_API_KEY)
    index_name = "card-chatbot"
    # Get a handle to the existing index
    index = pc.Index(index_name)
    # Load the OpenAI embedding model
    embeddings = OpenAIEmbeddings(
        model="text-embedding-ada-002",  # OpenAI embedding model
        api_key=OPENAI_API_KEY
    )
    # Build a Pinecone-backed vector store over the index
    vectorstore = PineconeVectorStore(index=index, embedding=embeddings, text_key="page_content")
    return vectorstore

def load_model():
    model = ChatOpenAI(
        temperature=0.1,
        model_name="gpt-4o-mini",
        api_key=OPENAI_API_KEY,
        streaming=True
    )
    print("model loaded...")
    return model

def rag_chain(vectorstore):
    llm = load_model()

    # Retriever setup: fetch 30 candidates, then rerank down to the top 15
    reranker_model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3")
    compressor_15 = CrossEncoderReranker(model=reranker_model, top_n=15)
    vs_retriever30 = vectorstore.as_retriever(search_kwargs={"k": 30})
    retriever = ContextualCompressionRetriever(base_compressor=compressor_15, base_retriever=vs_retriever30)

    # Retriever pipeline: rewrite the latest question into a standalone query
    system_prompt = (
        "Given a chat history and the latest user question "
        "which might reference context in the chat history, "
        "formulate a standalone question which can be understood "
        "without the chat history. Do NOT answer the question, "
        "just reformulate it if needed and otherwise return it as is."
    )
    contextualize_prompt = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ])
    history_aware_retriever_modified = create_history_aware_retriever(
        llm,
        retriever,
        contextualize_prompt
    )

    # Reorder the retrieved documents so the most relevant ones sit at the
    # edges of the context window (mitigates the "lost in the middle" effect)
    reordering = LongContextReorder()
    my_retriever = (
        {
            "input": itemgetter("input"),
            "chat_history": itemgetter("chat_history"),
        }
        | history_aware_retriever_modified
        | RunnableLambda(lambda docs: reordering.transform_documents(docs))
    )

    # QA chain setup
    qa_system_prompt = """You are an assistant helping with question-answering tasks.
Use the retrieved information to answer the questions.
If the information includes details like card_name or specific benefits, make sure to include them in your answer.
If you do not know the answer, simply say you don't know.
Please provide the answers in Korean.

{context}"""
    qa_prompt = ChatPromptTemplate.from_messages([
        ("system", qa_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ])
    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)

    # Assemble the full RAG chain: history-aware retrieval + stuffed QA
    return create_retrieval_chain(my_retriever, question_answer_chain)

# Dictionary holding per-session chat histories
store = {}

# Fetch (or create) the chat history for a given session id
def get_session_history(session_ids):
    if session_ids not in store:  # first time this session id is seen
        store[session_ids] = ChatMessageHistory()
    return store[session_ids]  # return the history for this session id

def initialize_conversation(vectorstore):
    base_rag_chain = rag_chain(vectorstore)
    return RunnableWithMessageHistory(
        base_rag_chain,
        get_session_history,
        input_messages_key="input",           # key carrying the user's question
        history_messages_key="chat_history",  # key carrying prior messages
        output_messages_key="answer",         # key carrying the model's answer
    )
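
# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal, hedged example of driving the conversation chain above; the
# session id "demo-session" and the sample question are placeholders, and the
# real app presumably passes its own values.
if __name__ == "__main__":
    vectorstore = initialize_pinecone()
    conversation = initialize_conversation(vectorstore)
    result = conversation.invoke(
        # placeholder question (answers come back in Korean per the QA prompt)
        {"input": "Tell me about the card benefits"},
        config={"configurable": {"session_id": "demo-session"}},
    )
    print(result["answer"])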