Skip to content

Commit

Permalink
adding rag v1
Browse files Browse the repository at this point in the history
  • Loading branch information
MisterXY89 committed Jan 14, 2024
1 parent d5c1996 commit e934e05
Show file tree
Hide file tree
Showing 8 changed files with 338 additions and 0 deletions.
Empty file added chat_doc/rag/__init__.py
Empty file.
30 changes: 30 additions & 0 deletions chat_doc/rag/document_processing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from pathlib import Path

from llama_index.node_parser.text import SentenceSplitter
from llama_index.schema import TextNode


class DocumentProcessor:
    """Load raw documents via a pluggable loader and split them into text nodes.

    Wraps a llama_index reader (``loader``) and a text splitter so a source
    file (e.g. a CSV) can be turned into ``TextNode`` objects ready for
    embedding.
    """

    def __init__(self, loader, text_parser=None):
        """
        Args:
            loader: Any object exposing ``load_data(file=Path)`` (a llama_index reader).
            text_parser: Splitter exposing ``split_text(str) -> list[str]``.
                Defaults to ``SentenceSplitter(chunk_size=1024)``, created
                lazily per instance. (The previous default was evaluated once
                at import time, so every DocumentProcessor silently shared one
                splitter instance.)
        """
        self.loader = loader
        # Lazy default: avoid a mutable default argument shared by all instances.
        self.text_parser = (
            text_parser if text_parser is not None else SentenceSplitter(chunk_size=1024)
        )

    def load_documents(self, file_path):
        """Load documents from *file_path* using the configured loader."""
        return self.loader.load_data(file=Path(file_path))

    def process_documents(self, documents):
        """Split *documents* into chunks and wrap each chunk in a TextNode.

        Each node gets a copy of its source document's metadata.

        Returns:
            list[TextNode]: one node per text chunk, in document order.
        """
        text_chunks = []
        doc_idxs = []  # maps each chunk back to the index of its source document
        for doc_idx, doc in enumerate(documents):
            cur_text_chunks = self.text_parser.split_text(doc.text)
            text_chunks.extend(cur_text_chunks)
            doc_idxs.extend([doc_idx] * len(cur_text_chunks))

        nodes = []
        for idx, text_chunk in enumerate(text_chunks):
            node = TextNode(text=text_chunk)
            src_doc = documents[doc_idxs[idx]]
            # Copy the metadata so sibling nodes do not alias one shared dict;
            # mutating one node's metadata must not affect the others.
            node.metadata = dict(src_doc.metadata)
            nodes.append(node)

        return nodes
12 changes: 12 additions & 0 deletions chat_doc/rag/embedding_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from llama_index.embeddings import HuggingFaceEmbedding


class EmbeddingModel:
    """Thin wrapper around a HuggingFace embedding model.

    Exposes the two embedding entry points the RAG pipeline needs: one for
    indexing document text and one for embedding user queries.
    """

    def __init__(self, model_name):
        """Instantiate the underlying HuggingFace embedding for *model_name*."""
        self.embed_model = HuggingFaceEmbedding(model_name=model_name)

    def get_text_embedding(self, text):
        """Return the embedding vector for a document *text* string."""
        model = self.embed_model
        return model.get_text_embedding(text)

    def get_query_embedding(self, query):
        """Return the embedding vector for a search *query* string."""
        model = self.embed_model
        return model.get_query_embedding(query)
27 changes: 27 additions & 0 deletions chat_doc/rag/llama_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from llama_index import ServiceContext
from llama_index.llms import LlamaCPP


class LlamaModel:
    """Wrapper around a local LlamaCPP model and its llama_index service context."""

    def __init__(
        self,
        model_url,
        model_path=None,
        temperature=0.1,
        max_new_tokens=256,
        context_window=3900,
        model_kwargs=None,
        verbose=True,
    ):
        """
        Args:
            model_url: URL of the GGUF model to download (used when *model_path* is None).
            model_path: Optional local path to an already-downloaded model file.
            temperature: Sampling temperature passed through to LlamaCPP.
            max_new_tokens: Maximum number of tokens generated per completion.
            context_window: Model context size in tokens.
            model_kwargs: Extra keyword args forwarded to llama.cpp. Defaults
                to ``{"n_gpu_layers": 1}``, built fresh on each call — the
                previous mutable default dict was shared across every instance
                and could be mutated globally.
            verbose: Whether LlamaCPP should log its progress.
        """
        if model_kwargs is None:
            model_kwargs = {"n_gpu_layers": 1}
        self.llm = LlamaCPP(
            model_url=model_url,
            model_path=model_path,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            context_window=context_window,
            model_kwargs=model_kwargs,
            verbose=verbose,
        )

    def create_service_context(self, embed_model):
        """Build a llama_index ServiceContext pairing this LLM with *embed_model*."""
        return ServiceContext.from_defaults(llm=self.llm, embed_model=embed_model)
126 changes: 126 additions & 0 deletions chat_doc/rag/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from llama_index import download_loader
from llama_index.query_engine import RetrieverQueryEngine
from tqdm import tqdm

from chat_doc.config import BASE_DIR
from chat_doc.rag.document_processing import DocumentProcessor
from chat_doc.rag.embedding_models import EmbeddingModel
from chat_doc.rag.llama_models import LlamaModel
from chat_doc.rag.retrieval import VectorDBRetriever
from chat_doc.rag.vector_store_management import VectorStoreManager


class RAGManager:
    """End-to-end RAG pipeline: embedding model, LLM, vector store, retriever, query engine.

    Construction is heavily side-effecting: ``init_rag()`` downloads models and
    recreates the local Postgres vector database (VectorStoreManager.setup_database
    drops and re-creates it).
    """

    def __init__(self):
        # Components are all built by init_rag(); declared here so the
        # attribute set of an instance is explicit.
        self.embed_model = None
        self.llama_model = None
        self.service_context = None
        self.vector_store_manager = None
        self.retriever = None
        self.query_engine = None
        self.init_rag()

    def init_rag(self):
        """Wire up every component of the pipeline (models, DB, retriever, engine)."""
        # Initialize embedding model
        self.embed_model = EmbeddingModel(model_name="BAAI/bge-small-en")

        # Initialize LLM model and service context
        # NOTE(review): downloads a ~4GB GGUF model on first run — presumably
        # cached by LlamaCPP; confirm before running in CI.
        self.llama_model = LlamaModel(
            model_url="https://huggingface.co/TheBloke/Llama-2-7B-chat-GGUF/resolve/main/llama-2-7b-chat.Q4_0.gguf"
        )
        self.service_context = self.llama_model.create_service_context(self.embed_model.embed_model)

        # Initialize document processor and vector store manager
        SimpleCSVReader = download_loader("SimpleCSVReader")
        loader = SimpleCSVReader(encoding="utf-8")
        document_processor = DocumentProcessor(loader)

        # NOTE(review): database credentials are hard-coded in source; move to
        # configuration / environment variables.
        self.vector_store_manager = VectorStoreManager(
            db_name="vector_db",
            host="localhost",
            user="docrag",
            password="rag-adl-llama",
            port="5432",
        )
        # WARNING: setup_database() drops and re-creates the database, so any
        # previously indexed embeddings are lost on every RAGManager construction.
        self.vector_store_manager.setup_database()
        self.vector_store_manager.create_vector_store(
            table_name="icd11",
            embed_dim=384,  # increase this to 768 if using a large model?
        )

        # if ADD_NODES (initialize vector store) is True, add nodes to vector store
        ADD_NODES = False
        if ADD_NODES:
            # TODO: Uncomment this to process documents for embeddings --> MAKE CLEANER (e.g. cli command)
            # Process documents for embeddings
            documents = document_processor.load_documents(BASE_DIR + "/data/icd11.csv")
            nodes = document_processor.process_documents(documents)
            self.vector_store_manager.add_nodes(nodes)

        # Initialize retriever and query engine
        self.retriever = VectorDBRetriever(
            self.vector_store_manager.vector_store,
            self.embed_model,
            # query_mode = "default"
            # query_mode = "sparse"
            # query_mode = "hybrid"
            query_mode="default",
            similarity_top_k=2,
        )
        self.query_engine = RetrieverQueryEngine.from_args(
            self.retriever, service_context=self.service_context
        )

    def retrieve(self, query_string, use_llm=False):
        """
        Retrieve information based on the query string.
        Args:
            query_string (str): The query string for retrieval.
            use_llm (bool): Determines whether to use LLM for augmented generation or simple vector retrieval.
        Returns:
            Response from the retrieval process.
        """
        if use_llm:
            # Use LLM for augmented generation
            return self.query_engine.query(query_string)
        # Use simple passage retrieval from vector database
        return self.retriever.query(query_string)


def _handle_response(response):
    """Print the retrieval result to stdout.

    A plain string (e.g. an error message from the retriever) is printed
    verbatim; otherwise *response* is treated as an iterable of scored nodes
    and each entry is printed with its similarity score and content.

    Args:
        response: The response from the retrieval process.
    """
    if not isinstance(response, str):
        for scored_node in response:
            score = scored_node.score
            content = scored_node.node.get_content()
            print(f"Score: {score}, Content: {content}")
    else:
        print(response)


def retrieve(query_string, use_llm=False):
    """Run a retrieval (optionally LLM-augmented) and print the result.

    The RAGManager is expensive to construct (model downloads, and a database
    setup that DROPS and re-creates the vector store), so it is built once on
    first use and cached on the function object instead of being rebuilt —
    and the index wiped — on every call.

    Args:
        query_string (str): The query string for retrieval.
        use_llm (bool): Use LLM augmented generation if True, else plain
            vector retrieval.

    Returns:
        None: results are printed by ``_handle_response``.
    """
    if not hasattr(retrieve, "_manager"):
        # Cache the heavyweight pipeline across calls.
        retrieve._manager = RAGManager()

    return _handle_response(retrieve._manager.retrieve(query_string, use_llm))


if __name__ == "__main__":
    # Example usage. retrieve() prints its results itself and returns None,
    # so it is called directly rather than wrapped in print() — the previous
    # print(retrieve(...)) emitted a spurious trailing "None". The unused
    # local RAGManager was also removed: constructing it here re-ran the
    # (destructive) database setup without ever being used.
    query_str = "What are the symptoms of a migraine?"
    print("Retrieving information based on the query string:")
    print(query_str + "\n")

    # Example usage with LLM
    print("Using LLM for augmented generation:")
    retrieve(query_str, use_llm=True)

    # Example usage with simple vector database retrieval
    print("\nUsing simple vector database retrieval:")
    retrieve(query_str)
25 changes: 25 additions & 0 deletions chat_doc/rag/parse_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import pandas as pd


def parse_data(path="chat_doc/data/pinglab-ICD11-data.json"):
    """Parse the raw ICD-11 JSON export into a dataframe ready for embedding.

    Rows whose ``definition`` is the placeholder ``'Key Not found'`` are
    dropped, and a combined ``text`` column is built from name + definition.

    Args:
        path: Path of the JSON file to read. Defaults to the project's
            pinglab ICD-11 export (parameterized so tests and alternate
            datasets can reuse the function).

    Returns:
        pandas.DataFrame: columns ``name``, ``definition`` and ``text``.
    """
    _df = pd.read_json(path)
    _df = _df.query("definition != 'Key Not found'")
    _df.reset_index(inplace=True)
    # BUG FIX: `_df["name", "definition"]` looks up a single column keyed by
    # the tuple ("name", "definition") and raises KeyError; selecting two
    # columns requires a list.
    _df = _df[["name", "definition"]]
    _df["text"] = "Name:" + _df["name"] + "\nDefinition: " + _df["definition"]

    return _df


def store_data(df, path="chat_doc/data/icd11.csv"):
    """Persist the parsed ICD-11 dataframe as CSV (without the index column).

    Args:
        df: Dataframe to write (as produced by ``parse_data``).
        path: Destination CSV path; defaults to the project data file
            (parameterized for testability and alternate outputs).
    """
    # TODO: decide whether to store the data in a DB instead of a CSV.
    df.to_csv(path, index=False)


def main():
    """Parse the raw ICD-11 export and persist it as the project CSV."""
    store_data(parse_data())


if __name__ == "__main__":
    main()
74 changes: 74 additions & 0 deletions chat_doc/rag/retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from typing import Any, List, Optional

from llama_index import QueryBundle
from llama_index.retrievers import BaseRetriever
from llama_index.schema import NodeWithScore
from llama_index.vector_stores import VectorStoreQuery


class VectorDBRetriever(BaseRetriever):
    """llama_index retriever backed by the project's PGVector store.

    Embeds the query with *embed_model*, runs a similarity search against
    *vector_store* and returns the top matches as scored nodes.
    """

    def __init__(self, vector_store, embed_model, query_mode="default", similarity_top_k=2):
        """
        Args:
            vector_store: Store exposing ``query(VectorStoreQuery)``.
            embed_model: Object exposing ``get_query_embedding(str)``.
            query_mode: Vector store query mode ("default", "sparse", "hybrid").
            similarity_top_k: Number of nearest neighbours to return.
        """
        self._vector_store = vector_store
        self._embed_model = embed_model
        self._query_mode = query_mode
        self._similarity_top_k = similarity_top_k
        super().__init__()

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Embed the query and return the top-k nodes with their similarity scores."""
        query_embedding = self._embed_model.get_query_embedding(query_bundle.query_str)
        vector_store_query = VectorStoreQuery(
            query_embedding=query_embedding,
            similarity_top_k=self._similarity_top_k,
            mode=self._query_mode,
        )
        query_result = self._vector_store.query(vector_store_query)

        nodes_with_scores = []
        for index, node in enumerate(query_result.nodes):
            # Similarities can be absent depending on the store / query mode.
            score: Optional[float] = None
            if query_result.similarities is not None:
                score = query_result.similarities[index]
            nodes_with_scores.append(NodeWithScore(node=node, score=score))

        return nodes_with_scores

    def query(self, query_str: str) -> Any:
        """Retrieve passages for *query_str*, guarding against a missing index.

        Returns either a list of NodeWithScore or an explanatory string when
        embeddings are missing and the user declines processing.
        """
        # Check if embeddings exist for the nodes in the vector store.
        embeddings_exist = self._check_embeddings_existence()
        if not embeddings_exist:
            # Prompt the user to decide whether to process documents for embeddings.
            process_documents = self._ask_user_for_processing()
            if process_documents:
                self._process_documents()
            else:
                return "Embeddings are missing. Please process documents for embeddings."

        return self._retrieve(query_bundle=QueryBundle(query_str=query_str))

    def _check_embeddings_existence(self):
        """
        Check if embeddings exist for the nodes in the vector store.
        Returns:
            bool: True if embeddings exist, False otherwise.
        """
        # Implement the logic to check if embeddings exist.
        # This could be a database query or a check in the vector store.
        # For example: return self._vector_store.check_embeddings_existence()
        return True  # Placeholder return

    def _ask_user_for_processing(self):
        """
        Prompt the user to decide whether to process documents for embeddings.
        Returns:
            bool: True if user agrees to process, False otherwise.
        """
        user_input = (
            input(
                "Embeddings are missing. Would you like to start processing documents to generate embeddings? (y/n): "
            )
            .strip()
            .lower()
        )
        return user_input == "y"

    def _process_documents(self):
        """Process and index documents so embeddings exist.

        BUG FIX: ``query`` referenced this method but it was never defined, so
        answering "y" to the processing prompt raised AttributeError. Document
        processing lives in DocumentProcessor / VectorStoreManager, so fail
        loudly with a clear message until this is wired up.
        """
        raise NotImplementedError(
            "Document processing is not implemented on VectorDBRetriever; "
            "use DocumentProcessor and VectorStoreManager.add_nodes instead."
        )
44 changes: 44 additions & 0 deletions chat_doc/rag/vector_store_management.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import psycopg2
from llama_index.vector_stores import PGVectorStore


class VectorStoreManager:
    """Owns the Postgres bootstrap and the PGVector store used for the RAG index."""

    def __init__(self, db_name, host, user, password, port):
        """Store connection parameters; no connection is opened until setup_database()."""
        self.db_name = db_name
        self.host = host
        self.user = user
        self.password = password
        self.port = port
        self.conn = None        # maintenance connection, only live during setup_database()
        self.vector_store = None  # set by create_vector_store()

    def setup_database(self):
        """(Re)create the vector database.

        WARNING: this DROPS ``self.db_name`` if it already exists — all
        previously stored embeddings are lost. Only call when (re)building
        the index.
        """
        self.conn = psycopg2.connect(
            dbname="postgres",  # connect to the maintenance DB to manage self.db_name
            host=self.host,
            password=self.password,
            port=self.port,
            user=self.user,
        )
        # CREATE/DROP DATABASE cannot run inside a transaction block.
        self.conn.autocommit = True
        try:
            with self.conn.cursor() as c:
                # NOTE: the database name is interpolated into SQL; it must
                # come from trusted configuration, never from user input.
                c.execute(f"DROP DATABASE IF EXISTS {self.db_name}")
                c.execute(f"CREATE DATABASE {self.db_name}")
        finally:
            # Close even when the DDL fails so the connection is not leaked.
            self.conn.close()

    def create_vector_store(self, table_name, embed_dim):
        """Build the PGVectorStore for *table_name* with *embed_dim*-dimensional vectors."""
        self.vector_store = PGVectorStore.from_params(
            database=self.db_name,
            host=self.host,
            password=self.password,
            port=self.port,
            user=self.user,
            table_name=table_name,
            embed_dim=embed_dim,
        )

    def add_nodes(self, nodes):
        """Add embedded nodes to the vector store.

        Raises:
            RuntimeError: if create_vector_store() has not been called yet
                (more specific than the previous bare ``Exception``; still
                caught by callers handling ``Exception``).
        """
        if not self.vector_store:
            raise RuntimeError("Vector store not initialized.")
        self.vector_store.add(nodes)

0 comments on commit e934e05

Please sign in to comment.