From e934e05ba6a78fe71386df0059b90cfe6da848bf Mon Sep 17 00:00:00 2001 From: Tilman Kerl Date: Mon, 15 Jan 2024 00:27:08 +0100 Subject: [PATCH] adding rag v1 --- chat_doc/rag/__init__.py | 0 chat_doc/rag/document_processing.py | 30 ++++++ chat_doc/rag/embedding_models.py | 12 +++ chat_doc/rag/llama_models.py | 27 +++++ chat_doc/rag/main.py | 126 ++++++++++++++++++++++++ chat_doc/rag/parse_data.py | 25 +++++ chat_doc/rag/retrieval.py | 74 ++++++++++++++ chat_doc/rag/vector_store_management.py | 44 +++++++++ 8 files changed, 338 insertions(+) create mode 100644 chat_doc/rag/__init__.py create mode 100644 chat_doc/rag/document_processing.py create mode 100644 chat_doc/rag/embedding_models.py create mode 100644 chat_doc/rag/llama_models.py create mode 100644 chat_doc/rag/main.py create mode 100644 chat_doc/rag/parse_data.py create mode 100644 chat_doc/rag/retrieval.py create mode 100644 chat_doc/rag/vector_store_management.py diff --git a/chat_doc/rag/__init__.py b/chat_doc/rag/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chat_doc/rag/document_processing.py b/chat_doc/rag/document_processing.py new file mode 100644 index 0000000..1d38e63 --- /dev/null +++ b/chat_doc/rag/document_processing.py @@ -0,0 +1,30 @@ +from pathlib import Path + +from llama_index.node_parser.text import SentenceSplitter +from llama_index.schema import TextNode + + +class DocumentProcessor: + def __init__(self, loader, text_parser=SentenceSplitter(chunk_size=1024)): + self.loader = loader + self.text_parser = text_parser + + def load_documents(self, file_path): + return self.loader.load_data(file=Path(file_path)) + + def process_documents(self, documents): + text_chunks = [] + doc_idxs = [] + for doc_idx, doc in enumerate(documents): + cur_text_chunks = self.text_parser.split_text(doc.text) + text_chunks.extend(cur_text_chunks) + doc_idxs.extend([doc_idx] * len(cur_text_chunks)) + + nodes = [] + for idx, text_chunk in enumerate(text_chunks): + node = TextNode(text=text_chunk) + src_doc = documents[doc_idxs[idx]] + node.metadata = src_doc.metadata + nodes.append(node) + + return nodes diff --git a/chat_doc/rag/embedding_models.py b/chat_doc/rag/embedding_models.py new file mode 100644 index 0000000..fd4d10a --- /dev/null +++ b/chat_doc/rag/embedding_models.py @@ -0,0 +1,12 @@ +from llama_index.embeddings import HuggingFaceEmbedding + + +class EmbeddingModel: + def __init__(self, model_name): + self.embed_model = HuggingFaceEmbedding(model_name=model_name) + + def get_text_embedding(self, text): + return self.embed_model.get_text_embedding(text) + + def get_query_embedding(self, query): + return self.embed_model.get_query_embedding(query) diff --git a/chat_doc/rag/llama_models.py b/chat_doc/rag/llama_models.py new file mode 100644 index 0000000..ce41fa3 --- /dev/null +++ b/chat_doc/rag/llama_models.py @@ -0,0 +1,27 @@ +from llama_index import ServiceContext +from llama_index.llms import LlamaCPP + + +class LlamaModel: + def __init__( + self, + model_url, + model_path=None, + temperature=0.1, + max_new_tokens=256, + context_window=3900, + model_kwargs={"n_gpu_layers": 1}, + verbose=True, + ): + self.llm = LlamaCPP( + model_url=model_url, + model_path=model_path, + temperature=temperature, + max_new_tokens=max_new_tokens, + context_window=context_window, + model_kwargs=model_kwargs, + verbose=verbose, + ) + + def create_service_context(self, embed_model): + return ServiceContext.from_defaults(llm=self.llm, embed_model=embed_model) diff --git a/chat_doc/rag/main.py b/chat_doc/rag/main.py new file mode 100644 index 0000000..440cded --- /dev/null +++ b/chat_doc/rag/main.py @@ -0,0 +1,126 @@ +from llama_index import download_loader +from llama_index.query_engine import RetrieverQueryEngine +from tqdm import tqdm + +from chat_doc.config import BASE_DIR +from chat_doc.rag.document_processing import DocumentProcessor +from chat_doc.rag.embedding_models import EmbeddingModel +from chat_doc.rag.llama_models import LlamaModel +from chat_doc.rag.retrieval import VectorDBRetriever +from chat_doc.rag.vector_store_management import VectorStoreManager + + +class RAGManager: + def __init__(self): + self.embed_model = None + self.llama_model = None + self.service_context = None + self.vector_store_manager = None + self.retriever = None + self.query_engine = None + self.init_rag() + + def init_rag(self): + # Initialize embedding model + self.embed_model = EmbeddingModel(model_name="BAAI/bge-small-en") + + # Initialize LLM model and service context + self.llama_model = LlamaModel( + model_url="https://huggingface.co/TheBloke/Llama-2-7B-chat-GGUF/resolve/main/llama-2-7b-chat.Q4_0.gguf" + ) + self.service_context = self.llama_model.create_service_context(self.embed_model.embed_model) + + # Initialize document processor and vector store manager + SimpleCSVReader = download_loader("SimpleCSVReader") + loader = SimpleCSVReader(encoding="utf-8") + document_processor = DocumentProcessor(loader) + + self.vector_store_manager = VectorStoreManager( + db_name="vector_db", + host="localhost", + user="docrag", + password="rag-adl-llama", + port="5432", + ) + self.vector_store_manager.setup_database() + self.vector_store_manager.create_vector_store( + table_name="icd11", + embed_dim=384, # increase this to 768 if using a large model? + ) + + # if ADD_NODES (initialize vector store) is True, add nodes to vector store + ADD_NODES = False + if ADD_NODES: + # TODO: Uncomment this to process documents for embeddings --> MAKE CLEANER (e.g. cli command) + # Process documents for embeddings + documents = document_processor.load_documents(BASE_DIR + "/data/icd11.csv") + nodes = document_processor.process_documents(documents) + self.vector_store_manager.add_nodes(nodes) + + # Initialize retriever and query engine + self.retriever = VectorDBRetriever( + self.vector_store_manager.vector_store, + self.embed_model, + # query_mode = "default" + # query_mode = "sparse" + # query_mode = "hybrid" + query_mode="default", + similarity_top_k=2, + ) + self.query_engine = RetrieverQueryEngine.from_args( + self.retriever, service_context=self.service_context + ) + + def retrieve(self, query_string, use_llm=False): + """ + Retrieve information based on the query string. + + Args: + query_string (str): The query string for retrieval. + use_llm (bool): Determines whether to use LLM for augmented generation or simple vector retrieval. + + Returns: + Response from the retrieval process. + """ + if use_llm: + # Use LLM for augmented generation + return self.query_engine.query(query_string) + # Use simple passage retrieval from vector database + return self.retriever.query(query_string) + + +def _handle_response(response): + """ + Handle the response from the retrieval process. + + Args: + response: The response from the retrieval process. + """ + if isinstance(response, str): + print(response) + else: + for node_with_score in response: + print(f"Score: {node_with_score.score}, Content: {node_with_score.node.get_content()}") + + +def retrieve(query_string, use_llm=False): + rag_manager = RAGManager() + + return _handle_response(rag_manager.retrieve(query_string, use_llm)) + + +if __name__ == "__main__": + # Example usage + + rag_manager = RAGManager() + query_str = "What are the symptoms of a migraine?" + print("Retrieving information based on the query string:") + print(query_str + "\n") + + # Example usage with LLM + print("Using LLM for augmented generation:") + print(retrieve(query_str, use_llm=True)) + + # Example usage with simple vector database retrieval + print("\nUsing simple vector database retrieval:") + print(retrieve(query_str)) diff --git a/chat_doc/rag/parse_data.py b/chat_doc/rag/parse_data.py new file mode 100644 index 0000000..ff94f23 --- /dev/null +++ b/chat_doc/rag/parse_data.py @@ -0,0 +1,25 @@ +import pandas as pd + + +def parse_data(): + _df = pd.read_json("chat_doc/data/pinglab-ICD11-data.json") + _df = _df.query("definition != 'Key Not found'") + _df.reset_index(inplace=True) + _df = _df["name", "definition"] + _df["text"] = "Name:" + _df["name"] + "\nDefinition: " + _df["definition"] + + return _df + + +def store_data(df): + # store data in db or csv? + df.to_csv("chat_doc/data/icd11.csv", index=False) + + +def main(): + df = parse_data() + store_data(df) + + +if __name__ == "__main__": + main() diff --git a/chat_doc/rag/retrieval.py b/chat_doc/rag/retrieval.py new file mode 100644 index 0000000..247dab7 --- /dev/null +++ b/chat_doc/rag/retrieval.py @@ -0,0 +1,74 @@ +from typing import Any, List, Optional + +from llama_index import QueryBundle +from llama_index.retrievers import BaseRetriever +from llama_index.schema import NodeWithScore +from llama_index.vector_stores import VectorStoreQuery + + +class VectorDBRetriever(BaseRetriever): + def __init__(self, vector_store, embed_model, query_mode="default", similarity_top_k=2): + self._vector_store = vector_store + self._embed_model = embed_model + self._query_mode = query_mode + self._similarity_top_k = similarity_top_k + super().__init__() + + def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + query_embedding = self._embed_model.get_query_embedding(query_bundle.query_str) + vector_store_query = VectorStoreQuery( + query_embedding=query_embedding, + similarity_top_k=self._similarity_top_k, + mode=self._query_mode, + ) + query_result = self._vector_store.query(vector_store_query) + + nodes_with_scores = [] + for index, node in enumerate(query_result.nodes): + score: Optional[float] = None + if query_result.similarities is not None: + score = query_result.similarities[index] + nodes_with_scores.append(NodeWithScore(node=node, score=score)) + + return nodes_with_scores + + def query(self, query_str: str) -> Any: + # Check if embeddings exist for the nodes in the vector store. + embeddings_exist = self._check_embeddings_existence() + if not embeddings_exist: + # Prompt the user to decide whether to process documents for embeddings. + process_documents = self._ask_user_for_processing() + if process_documents: + self._process_documents() + else: + return "Embeddings are missing. Please process documents for embeddings." + + return self._retrieve(query_bundle=QueryBundle(query_str=query_str)) + + def _check_embeddings_existence(self): + """ + Check if embeddings exist for the nodes in the vector store. + + Returns: + bool: True if embeddings exist, False otherwise. + """ + # Implement the logic to check if embeddings exist. + # This could be a database query or a check in the vector store. + # For example: return self._vector_store.check_embeddings_existence() + return True # Placeholder return + + def _ask_user_for_processing(self): + """ + Prompt the user to decide whether to process documents for embeddings. + + Returns: + bool: True if user agrees to process, False otherwise. + """ + user_input = ( + input( + "Embeddings are missing. Would you like to start processing documents to generate embeddings? (y/n): " + ) + .strip() + .lower() + ) + return user_input == "y" diff --git a/chat_doc/rag/vector_store_management.py b/chat_doc/rag/vector_store_management.py new file mode 100644 index 0000000..0e818ea --- /dev/null +++ b/chat_doc/rag/vector_store_management.py @@ -0,0 +1,44 @@ +import psycopg2 +from llama_index.vector_stores import PGVectorStore + + +class VectorStoreManager: + def __init__(self, db_name, host, user, password, port): + self.db_name = db_name + self.host = host + self.user = user + self.password = password + self.port = port + self.conn = None + self.vector_store = None + + def setup_database(self): + self.conn = psycopg2.connect( + dbname="postgres", + host=self.host, + password=self.password, + port=self.port, + user=self.user, + ) + self.conn.autocommit = True + with self.conn.cursor() as c: + c.execute(f"DROP DATABASE IF EXISTS {self.db_name}") + c.execute(f"CREATE DATABASE {self.db_name}") + self.conn.close() + + def create_vector_store(self, table_name, embed_dim): + self.vector_store = PGVectorStore.from_params( + database=self.db_name, + host=self.host, + password=self.password, + port=self.port, + user=self.user, + table_name=table_name, + embed_dim=embed_dim, + ) + + def add_nodes(self, nodes): + if self.vector_store: + self.vector_store.add(nodes) + else: + raise Exception("Vector store not initialized.")