
Commit

changed credentials to REST API
sreyakumar committed Oct 17, 2024
1 parent 1983032 · commit 584320d
Showing 3 changed files with 59 additions and 31 deletions.
49 changes: 32 additions & 17 deletions src/metadata_chatbot/agents/async_workflow.py
@@ -8,12 +8,23 @@

sys.path.append(os.path.abspath("C:/Users/sreya.kumar/Documents/GitHub/metadata-chatbot"))
from metadata_chatbot.utils import ResourceManager
from aind_data_access_api.document_db import MetadataDbClient

from metadata_chatbot.agents.docdb_retriever import DocDBRetriever
from metadata_chatbot.agents.agentic_graph import datasource_router, db_surveyor, query_grader, filter_generation_chain, doc_grader, rag_chain

logging.basicConfig(filename='async_workflow.log', level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', filemode="w")

API_GATEWAY_HOST = "api.allenneuraldynamics-test.org"
DATABASE = "metadata_vector_index"
COLLECTION = "bigger_LANGCHAIN_curated_chunks"

docdb_api_client = MetadataDbClient(
host=API_GATEWAY_HOST,
database=DATABASE,
collection=COLLECTION,
)

class GraphState(TypedDict):
"""
Represents the state of our graph.
@@ -109,11 +120,15 @@ async def retrieve_async(state):
filter = state["filter"]

# Retrieval
with ResourceManager() as RM:
db = RM.async_client.get_database('metadata_vector_index')
collection = db.get_collection('bigger_LANGCHAIN_curated_chunks')
retriever = DocDBRetriever(collection = collection, k = 10)
documents = await retriever.aget_relevant_documents(query = query, query_filter = filter)

retriever = DocDBRetriever(k = 10)
documents = await retriever.aget_relevant_documents(query = query, query_filter = filter)

# with ResourceManager() as RM:
# db = RM.async_client.get_database('metadata_vector_index')
# collection = db.get_collection('bigger_LANGCHAIN_curated_chunks')
# retriever = DocDBRetriever(collection = collection, k = 10)
# documents = await retriever.aget_relevant_documents(query = query, query_filter = filter)
return {"documents": documents, "query": query}

async def grade_doc_async(query, doc: Document):
@@ -190,18 +205,18 @@ async def generate_async(state):

async_app = async_workflow.compile()

# async def main():
# query = "Can you give me a timeline of events for subject 675387?"
# inputs = {"query": query}
# result = async_app.astream(inputs)
async def main():
query = "Can you give me a timeline of events for subject 675387?"
inputs = {"query": query}
result = async_app.astream(inputs)

# value = None
# async for output in result:
# for key, value in output.items():
# logging.info(f"Currently on node '{key}':")
value = None
async for output in result:
for key, value in output.items():
logging.info(f"Currently on node '{key}':")

# if value:
# print(value['generation'])
if value:
print(value['generation'])

# #Run the async function
# asyncio.run(main())
#Run the async function
asyncio.run(main())
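
A minimal sketch of the new REST-backed retrieval path that retrieve_async now uses, assuming the aind-data-access-api package is installed and the test API gateway is reachable; the query string is hypothetical, and the retriever relies on the module-level MetadataDbClient defined in docdb_retriever.py.

import asyncio

from metadata_chatbot.agents.docdb_retriever import DocDBRetriever

async def demo_retrieval():
    # No Mongo collection handle is passed anymore; the retriever talks to the
    # Document DB REST API through docdb_retriever's module-level client.
    retriever = DocDBRetriever(k=10)
    documents = await retriever.aget_relevant_documents(
        query="What instruments were used for subject 675387?",  # hypothetical query
        query_filter=None,  # optionally a MongoDB aggregation stage, e.g. {"$match": {...}}
    )
    return documents

# documents = asyncio.run(demo_retrieval())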
20 changes: 15 additions & 5 deletions src/metadata_chatbot/agents/docdb_retriever.py
@@ -8,15 +8,25 @@
from bson import json_util
from langsmith import trace as langsmith_trace
from pydantic import Field
from aind_data_access_api.document_db import MetadataDbClient

sys.path.append(os.path.abspath("C:/Users/sreya.kumar/Documents/GitHub/metadata-chatbot"))
from metadata_chatbot.utils import BEDROCK_EMBEDDINGS

API_GATEWAY_HOST = "api.allenneuraldynamics-test.org"
DATABASE = "metadata_vector_index"
COLLECTION = "bigger_LANGCHAIN_curated_chunks"

docdb_api_client = MetadataDbClient(
host=API_GATEWAY_HOST,
database=DATABASE,
collection=COLLECTION,
)


class DocDBRetriever(BaseRetriever):
"""A retriever that contains the top k documents, retrieved from the DocDB index, aligned with the user's query."""
collection: Any = Field(description="DocDB collection to retrieve from")
#collection: Any = Field(description="DocDB collection to retrieve from")
k: int = Field(default=10, description="Number of documents to retrieve")

def _get_relevant_documents(
@@ -46,14 +56,14 @@ def _get_relevant_documents(
if query_filter:
pipeline.insert(0, query_filter)

cursor = self.collection.aggregate(pipeline)
result = docdb_api_client.aggregate_docdb_records(pipeline=pipeline)

page_content_field = 'textContent'

results = []

#Transform retrieved docs to langchain Documents
for document in cursor:
for document in result:
values_to_metadata = dict()

json_doc = json.loads(json_util.dumps(document))
@@ -97,15 +107,15 @@ async def _aget_relevant_documents(
if query_filter:
pipeline.insert(0, query_filter)

cursor = self.collection.aggregate(pipeline, allowDiskUse=True)
result = docdb_api_client.aggregate_docdb_records(pipeline=pipeline)
#results = await cursor.to_list(length=1000)

page_content_field = 'textContent'

results = []

#Transform retrieved docs to langchain Documents
async for document in cursor:
for document in result:
values_to_metadata = dict()

json_doc = json.loads(json_util.dumps(document))
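The core of this change: the pymongo cursor from self.collection.aggregate(pipeline), previously consumed with async for, is replaced by docdb_api_client.aggregate_docdb_records(pipeline=pipeline), which returns a plain list of documents fetched over the REST gateway. A rough sketch of that call pattern follows, using a hypothetical $match/$limit pipeline; the field path subject.subject_id is an assumption.

from aind_data_access_api.document_db import MetadataDbClient

docdb_api_client = MetadataDbClient(
    host="api.allenneuraldynamics-test.org",
    database="metadata_vector_index",
    collection="bigger_LANGCHAIN_curated_chunks",
)

# Hypothetical pipeline: pull a few curated chunks for one subject.
pipeline = [
    {"$match": {"subject.subject_id": "675387"}},  # field path is an assumption
    {"$limit": 5},
]

records = docdb_api_client.aggregate_docdb_records(pipeline=pipeline)
for record in records:  # plain list of dicts, so no async iteration is needed
    print(record.get("textContent", "")[:100])  # textContent is the page-content field used above
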
21 changes: 12 additions & 9 deletions src/metadata_chatbot/agents/workflow.py
Expand Up @@ -107,11 +107,14 @@ def retrieve(state):
filter = state["filter"]
top_k = state["top_k"]

# Retrieval
with ResourceManager() as RM:
collection = RM.client['metadata_vector_index']['LANGCHAIN_ALL_curated_assets']
retriever = DocDBRetriever(collection = collection, k = top_k)
documents = retriever.get_relevant_documents(query = query, query_filter = filter)
retriever = DocDBRetriever(k = top_k)
documents = retriever.get_relevant_documents(query = query, query_filter = filter)

# # Retrieval
# with ResourceManager() as RM:
# collection = RM.client['metadata_vector_index']['LANGCHAIN_ALL_curated_assets']
# retriever = DocDBRetriever(collection = collection, k = top_k)
# documents = retriever.get_relevant_documents(query = query, query_filter = filter)
return {"documents": documents, "query": query}

def grade_documents(state):
@@ -186,8 +189,8 @@ def generate(state):

app = workflow.compile()

# query = "Give me the names of 5 assets have injections and are smartspim?"
query = "Can you give me a timeline of events for subject 675387?"

# inputs = {"query" : query}
# answer = app.invoke(inputs)
# print(answer['generation'])
inputs = {"query" : query}
answer = app.invoke(inputs)
print(answer['generation'])
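
A quick way to smoke-test the new credentials path on its own, before running the compiled workflow, is to call the REST client directly with a trivial aggregation. This is a sketch under the same configuration used above; the $count stage and the shape of its result are assumptions.

from aind_data_access_api.document_db import MetadataDbClient

client = MetadataDbClient(
    host="api.allenneuraldynamics-test.org",
    database="metadata_vector_index",
    collection="bigger_LANGCHAIN_curated_chunks",
)

# Hypothetical connectivity check: count the curated chunks visible through
# the API gateway.
result = client.aggregate_docdb_records(pipeline=[{"$count": "num_chunks"}])
print(result)  # expected shape (assumption): [{"num_chunks": <int>}]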
