Skip to content

Commit

Permalink
added query generation
Browse files (browse the repository at this point in the history)
  • Loading branch information
sreyakumar committed Oct 25, 2024
1 parent e20958d commit 76b25f1
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number | Diff line number | Diff line change
Expand Up @@ -3,8 +3,8 @@ chatbot.ipynb
*.pkl
*.png
queries.txt
virtualenv/
test.py
graph_viz.ipynb

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
16 changes: 6 additions & 10 deletions GAMER_workbook.ipynb
Original file line number | Diff line number | Diff line change
Expand Up @@ -174,22 +174,18 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\sreya.kumar\\Documents\\GitHub\\metadata-chatbot\\venv\\Lib\\site-packages\\metadata_chatbot\\agents\\workflow.py:113: LangChainDeprecationWarning: The method `BaseRetriever.get_relevant_documents` was deprecated in langchain-core 0.1.46 and will be removed in 1.0. Use :meth:`~invoke` instead.\n",
" documents = retriever.get_relevant_documents(query = query, query_filter = filter)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Based on the retrieved information, the genotype of subject 675387 is wt/wt (wild-type).\n"
"Based on the provided context, the assets using mouse 675387 are:\n",
"\n",
"Selective plane illumination microscopy: 2\n",
"\n",
"The context states that there are two data assets for subject 675387, both from selective plane illumination microscopy (SmartSPIM) experiments.\n"
]
}
],
Expand Down
7 changes: 5 additions & 2 deletions src/metadata_chatbot/agents/agentic_graph.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -7,9 +7,12 @@
from langchain_core.tools import tool
from langchain.agents import AgentExecutor, create_tool_calling_agent
from aind_data_access_api.document_db import MetadataDbClient
from pprint import pprint

logging.basicConfig(filename='agentic_graph.log', level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', filemode="w")

GLOBAL_VRAIABLE = 90

MODEL_ID = "anthropic.claude-3-sonnet-20240229-v1:0"
LLM = ChatBedrock(
model_id= MODEL_ID,
Expand Down Expand Up @@ -74,8 +77,8 @@ def aggregation_retrieval(agg_pipeline: list) -> list:
#llm_with_tools = LLM.bind_tools(tools)

db_surveyor_agent = create_tool_calling_agent(LLM, tools, db_prompt)
db_surveyor = AgentExecutor(agent=db_surveyor_agent, tools=tools, verbose=False)
#print(db_surveyor.invoke({'chat_history': [],"query": "What are the data asset names using mouse 675387"}))
query_retriever = AgentExecutor(agent=db_surveyor_agent, tools=tools, return_intermediate_steps = True, verbose=False)
# pprint(query_retriever.invoke({'chat_history': [],"query": "What are the data asset names using mouse 675387", "agent_scratchpad":[]}))
# class retrieve_aggregation(BaseModel):
# """List of results retrieved from mongodb database after applying pipeline generated by the model"""

Expand Down
23 changes: 14 additions & 9 deletions src/metadata_chatbot/agents/workflow.py
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
import logging, sys, os
import logging, sys, os, json
from typing import List, Optional
from typing_extensions import TypedDict
from langgraph.graph import END, StateGraph, START
Expand All @@ -9,7 +9,7 @@
# from metadata_chatbot.utils import ResourceManager

from metadata_chatbot.agents.docdb_retriever import DocDBRetriever
from metadata_chatbot.agents.agentic_graph import datasource_router, db_surveyor, query_grader, filter_generation_chain, doc_grader, rag_chain
from metadata_chatbot.agents.agentic_graph import datasource_router, query_retriever, query_grader, filter_generation_chain, doc_grader, rag_chain

logging.basicConfig(filename='async_workflow.log', level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', filemode="w")

Expand Down Expand Up @@ -60,12 +60,17 @@ def generate_for_whole_db(state):
"""

query = state['query']
chat_history = []

logging.info("Generating answer...")

documents_dict = db_surveyor.invoke({'query': query, 'chat_history': chat_history, 'agent_scratchpad': []})
documents = documents_dict['output'][0]['text']
document_dict = dict()
retrieved_dict = query_retriever.invoke({'query': query, 'chat_history': [], 'agent_scratchpad' : []})
document_dict['mongodb_query'] = retrieved_dict['intermediate_steps'][0][0].tool_input['agg_pipeline']
document_dict['retrieved_output'] = retrieved_dict['intermediate_steps'][0][1]

print(document_dict)

documents = json.dumps(document_dict)
return {"query": query, "documents": documents}

def filter_generator(state):
Expand Down Expand Up @@ -190,8 +195,8 @@ def generate(state):

app = workflow.compile()

# query = "What are all the assets using mouse 675387"
query = "What is the mongodb query to find all the assets using mouse 675387"

# inputs = {"query" : query}
# answer = app.invoke(inputs)
# print(answer['generation'])
inputs = {"query" : query}
answer = app.invoke(inputs)
print(answer['generation'])

0 comments on commit 76b25f1

Please sign in to comment.