Skip to content

Commit

Permalink
edited utils
Browse files Browse the repository at this point in the history
  • Loading branch information
sreyakumar committed Oct 17, 2024
1 parent 1ae7f9e commit 43d6e47
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 24 deletions.
26 changes: 18 additions & 8 deletions src/metadata_chatbot/agents/agentic_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from langchain_core.tools import tool
from aind_data_access_api.document_db_ssh import DocumentDbSSHClient, DocumentDbSSHCredentials
from langchain.agents import AgentExecutor, create_tool_calling_agent
from aind_data_access_api.document_db import MetadataDbClient

logging.basicConfig(filename='agentic_graph.log', level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', filemode="w")

Expand All @@ -32,9 +33,19 @@ class RouteQuery(BaseModel):


# Queries that require surveying the entire database (like count based questions)
credentials = DocumentDbSSHCredentials()
credentials.database = "metadata_vector_index"
credentials.collection = "curated_assets"
# credentials = DocumentDbSSHCredentials()
# credentials.database = "metadata_vector_index"
# credentials.collection = "curated_assets"
API_GATEWAY_HOST = "api.allenneuraldynamics-test.org"
DATABASE = "metadata_vector_index"
COLLECTION = "bigger_LANGCHAIN_curated_chunks"

docdb_api_client = MetadataDbClient(
host=API_GATEWAY_HOST,
database=DATABASE,
collection=COLLECTION,
)

@tool
def aggregation_retrieval(agg_pipeline: list) -> list:
"""Given a MongoDB query and list of projections, this function retrieves and returns the
Expand All @@ -52,12 +63,11 @@ def aggregation_retrieval(agg_pipeline: list) -> list:
list
List of retrieved documents
"""
with DocumentDbSSHClient(credentials=credentials) as doc_db_client:

result = list(doc_db_client.collection.aggregate(
pipeline=agg_pipeline
))
return result
result = docdb_api_client.aggregate_docdb_records(
pipeline=agg_pipeline
)
return result

tools = [aggregation_retrieval]
prompt = hub.pull("eden19/entire_db_retrieval")
Expand Down
28 changes: 14 additions & 14 deletions src/metadata_chatbot/agents/async_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from langgraph.checkpoint.memory import MemorySaver


sys.path.append(os.path.abspath("C:/Users/sreya.kumar/Documents/GitHub/metadata-chatbot"))
from metadata_chatbot.utils import ResourceManager
# sys.path.append(os.path.abspath("C:/Users/sreya.kumar/Documents/GitHub/metadata-chatbot"))
# from metadata_chatbot.utils import ResourceManager
from aind_data_access_api.document_db import MetadataDbClient

from metadata_chatbot.agents.docdb_retriever import DocDBRetriever
Expand Down Expand Up @@ -205,18 +205,18 @@ async def generate_async(state):

async_app = async_workflow.compile()

async def main():
query = "Can you give me a timeline of events for subject 675387?"
inputs = {"query": query}
result = async_app.astream(inputs)
# async def main():
# query = "Can you give me a timeline of events for subject 675387?"
# inputs = {"query": query}
# result = async_app.astream(inputs)

value = None
async for output in result:
for key, value in output.items():
logging.info(f"Currently on node '{key}':")
# value = None
# async for output in result:
# for key, value in output.items():
# logging.info(f"Currently on node '{key}':")

if value:
print(value['generation'])
# if value:
# print(value['generation'])

#Run the async function
asyncio.run(main())
# #Run the async function
# asyncio.run(main())
4 changes: 2 additions & 2 deletions src/metadata_chatbot/agents/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from langgraph.checkpoint.memory import MemorySaver


sys.path.append(os.path.abspath("C:/Users/sreya.kumar/Documents/GitHub/metadata-chatbot"))
from metadata_chatbot.utils import ResourceManager
# sys.path.append(os.path.abspath("C:/Users/sreya.kumar/Documents/GitHub/metadata-chatbot"))
# from metadata_chatbot.utils import ResourceManager

from metadata_chatbot.agents.docdb_retriever import DocDBRetriever
from metadata_chatbot.agents.agentic_graph import datasource_router, db_surveyor, query_grader, filter_generation_chain, doc_grader, rag_chain
Expand Down

0 comments on commit 43d6e47

Please sign in to comment.