From 584320de3deeeed7abd624e26e51ac8fb96fd22e Mon Sep 17 00:00:00 2001 From: sreyakumar <121137643+sreyakumar@users.noreply.github.com> Date: Thu, 17 Oct 2024 09:43:03 -0700 Subject: [PATCH] changed credentials to rest API --- src/metadata_chatbot/agents/async_workflow.py | 49 ++++++++++++------- .../agents/docdb_retriever.py | 20 ++++++-- src/metadata_chatbot/agents/workflow.py | 21 ++++---- 3 files changed, 59 insertions(+), 31 deletions(-) diff --git a/src/metadata_chatbot/agents/async_workflow.py b/src/metadata_chatbot/agents/async_workflow.py index 4f1bd49..0250c2c 100644 --- a/src/metadata_chatbot/agents/async_workflow.py +++ b/src/metadata_chatbot/agents/async_workflow.py @@ -8,12 +8,23 @@ sys.path.append(os.path.abspath("C:/Users/sreya.kumar/Documents/GitHub/metadata-chatbot")) from metadata_chatbot.utils import ResourceManager +from aind_data_access_api.document_db import MetadataDbClient from metadata_chatbot.agents.docdb_retriever import DocDBRetriever from metadata_chatbot.agents.agentic_graph import datasource_router, db_surveyor, query_grader, filter_generation_chain, doc_grader, rag_chain logging.basicConfig(filename='async_workflow.log', level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', filemode="w") +API_GATEWAY_HOST = "api.allenneuraldynamics-test.org" +DATABASE = "metadata_vector_index" +COLLECTION = "bigger_LANGCHAIN_curated_chunks" + +docdb_api_client = MetadataDbClient( + host=API_GATEWAY_HOST, + database=DATABASE, + collection=COLLECTION, +) + class GraphState(TypedDict): """ Represents the state of our graph. @@ -109,11 +120,15 @@ async def retrieve_async(state): filter = state["filter"] # Retrieval - with ResourceManager() as RM: - db = RM.async_client.get_database('metadata_vector_index') - collection = db.get_collection('bigger_LANGCHAIN_curated_chunks') - retriever = DocDBRetriever(collection = collection, k = 10) - documents = await retriever.aget_relevant_documents(query = query, query_filter = filter) + + retriever = DocDBRetriever(k = 10) + documents = await retriever.aget_relevant_documents(query = query, query_filter = filter) + + # with ResourceManager() as RM: + # db = RM.async_client.get_database('metadata_vector_index') + # collection = db.get_collection('bigger_LANGCHAIN_curated_chunks') + # retriever = DocDBRetriever(collection = collection, k = 10) + # documents = await retriever.aget_relevant_documents(query = query, query_filter = filter) return {"documents": documents, "query": query} async def grade_doc_async(query, doc: Document): @@ -190,18 +205,18 @@ async def generate_async(state): async_app = async_workflow.compile() -# async def main(): -# query = "Can you give me a timeline of events for subject 675387?" -# inputs = {"query": query} -# result = async_app.astream(inputs) +async def main(): + query = "Can you give me a timeline of events for subject 675387?" + inputs = {"query": query} + result = async_app.astream(inputs) -# value = None -# async for output in result: -# for key, value in output.items(): -# logging.info(f"Currently on node '{key}':") + value = None + async for output in result: + for key, value in output.items(): + logging.info(f"Currently on node '{key}':") -# if value: -# print(value['generation']) + if value: + print(value['generation']) -# #Run the async function -# asyncio.run(main()) +#Run the async function +asyncio.run(main()) diff --git a/src/metadata_chatbot/agents/docdb_retriever.py b/src/metadata_chatbot/agents/docdb_retriever.py index c329386..f5380a2 100644 --- a/src/metadata_chatbot/agents/docdb_retriever.py +++ b/src/metadata_chatbot/agents/docdb_retriever.py @@ -8,15 +8,25 @@ from bson import json_util from langsmith import trace as langsmith_trace from pydantic import Field +from aind_data_access_api.document_db import MetadataDbClient sys.path.append(os.path.abspath("C:/Users/sreya.kumar/Documents/GitHub/metadata-chatbot")) from metadata_chatbot.utils import BEDROCK_EMBEDDINGS +API_GATEWAY_HOST = "api.allenneuraldynamics-test.org" +DATABASE = "metadata_vector_index" +COLLECTION = "bigger_LANGCHAIN_curated_chunks" + +docdb_api_client = MetadataDbClient( + host=API_GATEWAY_HOST, + database=DATABASE, + collection=COLLECTION, +) class DocDBRetriever(BaseRetriever): """A retriever that contains the top k documents, retrieved from the DocDB index, aligned with the user's query.""" - collection: Any = Field(description="DocDB collection to retrieve from") + #collection: Any = Field(description="DocDB collection to retrieve from") k: int = Field(default=10, description="Number of documents to retrieve") def _get_relevant_documents( @@ -46,14 +56,14 @@ def _get_relevant_documents( if query_filter: pipeline.insert(0, query_filter) - cursor = self.collection.aggregate(pipeline) + result = docdb_api_client.aggregate_docdb_records(pipeline=pipeline) page_content_field = 'textContent' results = [] #Transform retrieved docs to langchain Documents - for document in cursor: + for document in result: values_to_metadata = dict() json_doc = json.loads(json_util.dumps(document)) @@ -97,7 +107,7 @@ async def _aget_relevant_documents( if query_filter: pipeline.insert(0, query_filter) - cursor = self.collection.aggregate(pipeline, allowDiskUse=True) + result = docdb_api_client.aggregate_docdb_records(pipeline=pipeline) #results = await cursor.to_list(length=1000) page_content_field = 'textContent' @@ -105,7 +115,7 @@ async def _aget_relevant_documents( results = [] #Transform retrieved docs to langchain Documents - async for document in cursor: + for document in result: values_to_metadata = dict() json_doc = json.loads(json_util.dumps(document)) diff --git a/src/metadata_chatbot/agents/workflow.py b/src/metadata_chatbot/agents/workflow.py index 5912d00..b4f5a6a 100644 --- a/src/metadata_chatbot/agents/workflow.py +++ b/src/metadata_chatbot/agents/workflow.py @@ -107,11 +107,14 @@ def retrieve(state): filter = state["filter"] top_k = state["top_k"] - # Retrieval - with ResourceManager() as RM: - collection = RM.client['metadata_vector_index']['LANGCHAIN_ALL_curated_assets'] - retriever = DocDBRetriever(collection = collection, k = top_k) - documents = retriever.get_relevant_documents(query = query, query_filter = filter) + retriever = DocDBRetriever(k = top_k) + documents = retriever.get_relevant_documents(query = query, query_filter = filter) + + # # Retrieval + # with ResourceManager() as RM: + # collection = RM.client['metadata_vector_index']['LANGCHAIN_ALL_curated_assets'] + # retriever = DocDBRetriever(collection = collection, k = top_k) + # documents = retriever.get_relevant_documents(query = query, query_filter = filter) return {"documents": documents, "query": query} def grade_documents(state): @@ -186,8 +189,8 @@ def generate(state): app = workflow.compile() -# query = "Give me the names of 5 assets have injections and are smartspim?" +query = "Can you give me a timeline of events for subject 675387?" -# inputs = {"query" : query} -# answer = app.invoke(inputs) -# print(answer['generation']) \ No newline at end of file +inputs = {"query" : query} +answer = app.invoke(inputs) +print(answer['generation']) \ No newline at end of file