From ea667ae7e479b41b24b5562a0b1a15b9f90f62da Mon Sep 17 00:00:00 2001 From: sangwonYoon Date: Mon, 20 May 2024 14:08:18 +0900 Subject: [PATCH] fix: fix bug in RAG --- backend/app/retrieval/rag.py | 21 ++++++++++++--------- backend/app/retrieval/rag_prototype.py | 2 +- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/backend/app/retrieval/rag.py b/backend/app/retrieval/rag.py index 4ec273456e..e320a12fe4 100644 --- a/backend/app/retrieval/rag.py +++ b/backend/app/retrieval/rag.py @@ -19,9 +19,12 @@ def retrieve_doc(question, pc_index, embedding, llm): retriever=vector_db.as_retriever(), llm=llm ) + logging.basicConfig() + logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.INFO) + retrieved_docs = retriever_from_llm.get_relevant_documents(query=question) - doc_ids = [doc.metadata["doc-id"] for doc in retrieved_docs] + doc_ids = list(set([doc.metadata["doc-id"] for doc in retrieved_docs])) docs = [] for doc_id in doc_ids: @@ -31,14 +34,14 @@ def retrieve_doc(question, pc_index, embedding, llm): "doc-id": {"$eq": doc_id}, "section_title": {"$eq": "통사론"}, }, - top_k=1, + top_k=3, include_metadata=True ) - if result_syntax["matches"]: + for idx in range(len(result_syntax["matches"])): document = { - "title": result_syntax["matches"][0]["metadata"]["title"], - "content": result_syntax["matches"][0]["metadata"]["content"], + "title": result_syntax["matches"][idx]["metadata"]["title"], + "content": result_syntax["matches"][idx]["metadata"]["content"], } docs.append(document) @@ -48,14 +51,14 @@ def retrieve_doc(question, pc_index, embedding, llm): "doc-id": {"$eq": doc_id}, "section_title": {"$eq": "속성"}, }, - top_k=1, + top_k=3, include_metadata=True ) - if result_prop["matches"]: + for idx in range(len(result_prop["matches"])): document = { - "title": result_prop["matches"][0]["metadata"]["title"], - "content": result_prop["matches"][0]["metadata"]["content"], + "title": result_prop["matches"][idx]["metadata"]["title"], + "content": result_prop["matches"][idx]["metadata"]["content"], } docs.append(document) diff --git a/backend/app/retrieval/rag_prototype.py b/backend/app/retrieval/rag_prototype.py index bbfa62287d..c9810ddeea 100644 --- a/backend/app/retrieval/rag_prototype.py +++ b/backend/app/retrieval/rag_prototype.py @@ -21,4 +21,4 @@ question = st.text_input('질문을 입력하세요:') if question: - st.write(retrieve_doc(question, pc_index, embedding, llm=llm, top_k=5)) + st.write(retrieve_doc(question, pc_index, embedding, llm=llm))