Commit e934e05 (parent d5c1996)
Showing 8 changed files with 338 additions and 0 deletions.
Empty file.
@@ -0,0 +1,30 @@
from pathlib import Path

from llama_index.node_parser.text import SentenceSplitter
from llama_index.schema import TextNode


class DocumentProcessor:
    def __init__(self, loader, text_parser=None):
        self.loader = loader
        # Avoid a mutable default argument: build the splitter per instance.
        self.text_parser = text_parser or SentenceSplitter(chunk_size=1024)

    def load_documents(self, file_path):
        return self.loader.load_data(file=Path(file_path))

    def process_documents(self, documents):
        # Split every document into chunks, remembering which document
        # each chunk came from.
        text_chunks = []
        doc_idxs = []
        for doc_idx, doc in enumerate(documents):
            cur_text_chunks = self.text_parser.split_text(doc.text)
            text_chunks.extend(cur_text_chunks)
            doc_idxs.extend([doc_idx] * len(cur_text_chunks))

        # Wrap each chunk in a TextNode that carries its source metadata.
        nodes = []
        for idx, text_chunk in enumerate(text_chunks):
            node = TextNode(text=text_chunk)
            src_doc = documents[doc_idxs[idx]]
            node.metadata = src_doc.metadata
            nodes.append(node)

        return nodes
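
For context, a minimal usage sketch of DocumentProcessor (the SimpleCSVReader loader and the icd11.csv path are taken from the RAGManager wiring later in this commit, not from this file):

from llama_index import download_loader

from chat_doc.rag.document_processing import DocumentProcessor

# Loader and path mirror the RAGManager setup elsewhere in this commit.
SimpleCSVReader = download_loader("SimpleCSVReader")
processor = DocumentProcessor(loader=SimpleCSVReader(encoding="utf-8"))

documents = processor.load_documents("chat_doc/data/icd11.csv")
nodes = processor.process_documents(documents)
print(f"{len(documents)} documents -> {len(nodes)} TextNode chunks")

Each node keeps its source document's metadata, which is what lets retrieved chunks be traced back to their origin later.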
@@ -0,0 +1,12 @@
from llama_index.embeddings import HuggingFaceEmbedding


class EmbeddingModel:
    def __init__(self, model_name):
        self.embed_model = HuggingFaceEmbedding(model_name=model_name)

    def get_text_embedding(self, text):
        return self.embed_model.get_text_embedding(text)

    def get_query_embedding(self, query):
        return self.embed_model.get_query_embedding(query)
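
A quick sketch of how this wrapper is used (the model name matches the one RAGManager passes in below; the 384 in the comment reflects BAAI/bge-small-en's output dimension):

from chat_doc.rag.embedding_models import EmbeddingModel

embed_model = EmbeddingModel(model_name="BAAI/bge-small-en")
vector = embed_model.get_query_embedding("What are the symptoms of a migraine?")
print(len(vector))  # 384 for bge-small-en, matching the vector store's embed_dim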
@@ -0,0 +1,27 @@
from llama_index import ServiceContext
from llama_index.llms import LlamaCPP


class LlamaModel:
    def __init__(
        self,
        model_url,
        model_path=None,
        temperature=0.1,
        max_new_tokens=256,
        context_window=3900,
        model_kwargs=None,
        verbose=True,
    ):
        # Avoid a mutable default argument for model_kwargs.
        if model_kwargs is None:
            model_kwargs = {"n_gpu_layers": 1}
        self.llm = LlamaCPP(
            model_url=model_url,
            model_path=model_path,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            context_window=context_window,
            model_kwargs=model_kwargs,
            verbose=verbose,
        )

    def create_service_context(self, embed_model):
        return ServiceContext.from_defaults(llm=self.llm, embed_model=embed_model)
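
Usage sketch, with the same GGUF model URL RAGManager uses below (note the first call downloads the model, so a cold start is slow):

from chat_doc.rag.embedding_models import EmbeddingModel
from chat_doc.rag.llama_models import LlamaModel

llama_model = LlamaModel(
    model_url="https://huggingface.co/TheBloke/Llama-2-7B-chat-GGUF/resolve/main/llama-2-7b-chat.Q4_0.gguf"
)
embed_model = EmbeddingModel(model_name="BAAI/bge-small-en")
service_context = llama_model.create_service_context(embed_model.embed_model)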
@@ -0,0 +1,126 @@
from llama_index import download_loader
from llama_index.query_engine import RetrieverQueryEngine

from chat_doc.config import BASE_DIR
from chat_doc.rag.document_processing import DocumentProcessor
from chat_doc.rag.embedding_models import EmbeddingModel
from chat_doc.rag.llama_models import LlamaModel
from chat_doc.rag.retrieval import VectorDBRetriever
from chat_doc.rag.vector_store_management import VectorStoreManager


class RAGManager:
    def __init__(self):
        self.embed_model = None
        self.llama_model = None
        self.service_context = None
        self.vector_store_manager = None
        self.retriever = None
        self.query_engine = None
        self.init_rag()

    def init_rag(self):
        # Initialize embedding model
        self.embed_model = EmbeddingModel(model_name="BAAI/bge-small-en")

        # Initialize LLM and service context
        self.llama_model = LlamaModel(
            model_url="https://huggingface.co/TheBloke/Llama-2-7B-chat-GGUF/resolve/main/llama-2-7b-chat.Q4_0.gguf"
        )
        self.service_context = self.llama_model.create_service_context(
            self.embed_model.embed_model
        )

        # Initialize document processor and vector store manager
        SimpleCSVReader = download_loader("SimpleCSVReader")
        loader = SimpleCSVReader(encoding="utf-8")
        document_processor = DocumentProcessor(loader)

        self.vector_store_manager = VectorStoreManager(
            db_name="vector_db",
            host="localhost",
            user="docrag",
            password="rag-adl-llama",
            port="5432",
        )
        # NOTE: setup_database drops and recreates the database, so previously
        # stored embeddings are lost unless ADD_NODES repopulates them.
        self.vector_store_manager.setup_database()
        self.vector_store_manager.create_vector_store(
            table_name="icd11",
            embed_dim=384,  # must match the embedding model: 384 for bge-small-en, 768 for bge-base-en
        )

        # If ADD_NODES is True, (re)populate the vector store with embeddings.
        # TODO: make this cleaner, e.g. a CLI command.
        ADD_NODES = False
        if ADD_NODES:
            documents = document_processor.load_documents(BASE_DIR + "/data/icd11.csv")
            nodes = document_processor.process_documents(documents)
            self.vector_store_manager.add_nodes(nodes)

        # Initialize retriever and query engine
        self.retriever = VectorDBRetriever(
            self.vector_store_manager.vector_store,
            self.embed_model,
            query_mode="default",  # alternatives: "sparse", "hybrid"
            similarity_top_k=2,
        )
        self.query_engine = RetrieverQueryEngine.from_args(
            self.retriever, service_context=self.service_context
        )

    def retrieve(self, query_string, use_llm=False):
        """
        Retrieve information based on the query string.

        Args:
            query_string (str): The query string for retrieval.
            use_llm (bool): Whether to use the LLM for augmented generation
                instead of plain vector retrieval.

        Returns:
            Response from the retrieval process.
        """
        if use_llm:
            # Use the LLM for retrieval-augmented generation
            return self.query_engine.query(query_string)
        # Use simple passage retrieval from the vector database
        return self.retriever.query(query_string)


def _handle_response(response):
    """
    Format the response from the retrieval process.

    Args:
        response: The response from the retrieval process.

    Returns:
        str: A printable representation of the response.
    """
    if isinstance(response, str):
        return response
    if isinstance(response, list):
        # A list of NodeWithScore objects from plain vector retrieval.
        return "\n".join(
            f"Score: {n.score}, Content: {n.node.get_content()}" for n in response
        )
    # llama_index Response objects (from the query engine) render via str().
    return str(response)


def retrieve(query_string, use_llm=False):
    rag_manager = RAGManager()
    return _handle_response(rag_manager.retrieve(query_string, use_llm))


if __name__ == "__main__":
    # Example usage: reuse one RAGManager rather than rebuilding it per query.
    rag_manager = RAGManager()
    query_str = "What are the symptoms of a migraine?"
    print("Retrieving information based on the query string:")
    print(query_str + "\n")

    # Example usage with LLM
    print("Using LLM for augmented generation:")
    print(_handle_response(rag_manager.retrieve(query_str, use_llm=True)))

    # Example usage with simple vector database retrieval
    print("\nUsing simple vector database retrieval:")
    print(_handle_response(rag_manager.retrieve(query_str)))
@@ -0,0 +1,25 @@
import pandas as pd


def parse_data():
    _df = pd.read_json("chat_doc/data/pinglab-ICD11-data.json")
    _df = _df.query("definition != 'Key Not found'")
    _df.reset_index(inplace=True)
    # Select columns with a list (df["name", "definition"] would raise a KeyError).
    _df = _df[["name", "definition"]]
    _df["text"] = "Name:" + _df["name"] + "\nDefinition: " + _df["definition"]

    return _df


def store_data(df):
    # store data in db or csv?
    df.to_csv("chat_doc/data/icd11.csv", index=False)


def main():
    df = parse_data()
    store_data(df)


if __name__ == "__main__":
    main()
@@ -0,0 +1,74 @@
from typing import Any, List, Optional

from llama_index import QueryBundle
from llama_index.retrievers import BaseRetriever
from llama_index.schema import NodeWithScore
from llama_index.vector_stores import VectorStoreQuery


class VectorDBRetriever(BaseRetriever):
    def __init__(self, vector_store, embed_model, query_mode="default", similarity_top_k=2):
        self._vector_store = vector_store
        self._embed_model = embed_model
        self._query_mode = query_mode
        self._similarity_top_k = similarity_top_k
        super().__init__()

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        # Embed the query and run a similarity search against the vector store.
        query_embedding = self._embed_model.get_query_embedding(query_bundle.query_str)
        vector_store_query = VectorStoreQuery(
            query_embedding=query_embedding,
            similarity_top_k=self._similarity_top_k,
            mode=self._query_mode,
        )
        query_result = self._vector_store.query(vector_store_query)

        nodes_with_scores = []
        for index, node in enumerate(query_result.nodes):
            score: Optional[float] = None
            if query_result.similarities is not None:
                score = query_result.similarities[index]
            nodes_with_scores.append(NodeWithScore(node=node, score=score))

        return nodes_with_scores

    def query(self, query_str: str) -> Any:
        # Check if embeddings exist for the nodes in the vector store.
        embeddings_exist = self._check_embeddings_existence()
        if not embeddings_exist:
            # Prompt the user to decide whether to process documents for embeddings.
            process_documents = self._ask_user_for_processing()
            if process_documents:
                self._process_documents()
            else:
                return "Embeddings are missing. Please process documents for embeddings."

        return self._retrieve(query_bundle=QueryBundle(query_str=query_str))

    def _check_embeddings_existence(self):
        """
        Check if embeddings exist for the nodes in the vector store.

        Returns:
            bool: True if embeddings exist, False otherwise.
        """
        # Implement the logic to check if embeddings exist.
        # This could be a database query or a check in the vector store.
        # For example: return self._vector_store.check_embeddings_existence()
        return True  # Placeholder return

    def _process_documents(self):
        """
        Process documents to generate embeddings.

        Placeholder: wire the document-processing pipeline in here.
        """
        raise NotImplementedError("Document processing is not implemented yet.")

    def _ask_user_for_processing(self):
        """
        Prompt the user to decide whether to process documents for embeddings.

        Returns:
            bool: True if user agrees to process, False otherwise.
        """
        user_input = (
            input(
                "Embeddings are missing. Would you like to start processing documents "
                "to generate embeddings? (y/n): "
            )
            .strip()
            .lower()
        )
        return user_input == "y"
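
One way the _check_embeddings_existence placeholder could be filled in is a direct row count against Postgres. A minimal sketch, assuming the same connection parameters as VectorStoreManager and assuming llama_index's PGVectorStore stores rows in a table named data_<table_name> (verify both against your setup):

import psycopg2


def check_embeddings_existence(db_name, host, user, password, port, table_name):
    """Return True if the pgvector table exists and holds at least one row.

    Assumes PGVectorStore's "data_" table-name prefix; adjust if your
    llama_index version names the table differently.
    """
    conn = psycopg2.connect(
        dbname=db_name, host=host, user=user, password=password, port=port
    )
    try:
        with conn.cursor() as cur:
            # to_regclass returns NULL when the table does not exist.
            cur.execute("SELECT to_regclass(%s)", (f"public.data_{table_name}",))
            if cur.fetchone()[0] is None:
                return False
            cur.execute(f'SELECT COUNT(*) FROM "data_{table_name}"')
            return cur.fetchone()[0] > 0
    finally:
        conn.close()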
@@ -0,0 +1,44 @@
import psycopg2
from llama_index.vector_stores import PGVectorStore


class VectorStoreManager:
    def __init__(self, db_name, host, user, password, port):
        self.db_name = db_name
        self.host = host
        self.user = user
        self.password = password
        self.port = port
        self.conn = None
        self.vector_store = None

    def setup_database(self):
        # Connect to the default "postgres" database to (re)create the target database.
        self.conn = psycopg2.connect(
            dbname="postgres",
            host=self.host,
            password=self.password,
            port=self.port,
            user=self.user,
        )
        self.conn.autocommit = True
        # NOTE: this drops any existing database, including stored embeddings.
        with self.conn.cursor() as c:
            c.execute(f"DROP DATABASE IF EXISTS {self.db_name}")
            c.execute(f"CREATE DATABASE {self.db_name}")
        self.conn.close()

    def create_vector_store(self, table_name, embed_dim):
        self.vector_store = PGVectorStore.from_params(
            database=self.db_name,
            host=self.host,
            password=self.password,
            port=self.port,
            user=self.user,
            table_name=table_name,
            embed_dim=embed_dim,
        )

    def add_nodes(self, nodes):
        if self.vector_store:
            self.vector_store.add(nodes)
        else:
            raise Exception("Vector store not initialized.")
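
End-to-end sketch of the manager (connection parameters mirror those in RAGManager; a running Postgres instance with the pgvector extension is assumed):

from chat_doc.rag.vector_store_management import VectorStoreManager

manager = VectorStoreManager(
    db_name="vector_db",
    host="localhost",
    user="docrag",
    password="rag-adl-llama",
    port="5432",
)
manager.setup_database()  # caution: drops any existing "vector_db"
manager.create_vector_store(table_name="icd11", embed_dim=384)
# manager.add_nodes(nodes)  # nodes from DocumentProcessor.process_documents(...)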