Skip to content

Commit

Permalink
added reasoning checks
Browse files Browse the repository at this point in the history
  • Loading branch information
sreyakumar committed Oct 28, 2024
1 parent 82f1771 commit 18e3212
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 28 deletions.
18 changes: 6 additions & 12 deletions GAMER_workbook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -174,26 +174,20 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"ename": "ImportError",
"evalue": "cannot import name 'db_surveyor' from 'metadata_chatbot.agents.agentic_graph' (c:\\Users\\sreya.kumar\\Documents\\GitHub\\metadata-chatbot\\venv\\Lib\\site-packages\\metadata_chatbot\\agents\\agentic_graph.py)",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mImportError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[1], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmetadata_chatbot\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01magents\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mGAMER\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m GAMER\n\u001b[0;32m 2\u001b[0m query \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWhat are all the assets using mouse 675387\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 4\u001b[0m model \u001b[38;5;241m=\u001b[39m GAMER()\n",
"File \u001b[1;32mc:\\Users\\sreya.kumar\\Documents\\GitHub\\metadata-chatbot\\venv\\Lib\\site-packages\\metadata_chatbot\\agents\\GAMER.py:9\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mlangchain_core\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01moutputs\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m GenerationChunk\n\u001b[0;32m 7\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mlogging\u001b[39;00m\u001b[38;5;241m,\u001b[39m \u001b[38;5;21;01masyncio\u001b[39;00m\n\u001b[1;32m----> 9\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmetadata_chatbot\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01magents\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01masync_workflow\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m async_app\n\u001b[0;32m 10\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmetadata_chatbot\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01magents\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mworkflow\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m app\n\u001b[0;32m 13\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mGAMER\u001b[39;00m(LLM):\n",
"File \u001b[1;32mc:\\Users\\sreya.kumar\\Documents\\GitHub\\metadata-chatbot\\venv\\Lib\\site-packages\\metadata_chatbot\\agents\\async_workflow.py:14\u001b[0m\n\u001b[0;32m 11\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01maind_data_access_api\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdocument_db\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m MetadataDbClient\n\u001b[0;32m 13\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmetadata_chatbot\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01magents\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdocdb_retriever\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m DocDBRetriever\n\u001b[1;32m---> 14\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmetadata_chatbot\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01magents\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01magentic_graph\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m datasource_router, db_surveyor, query_grader, filter_generation_chain, doc_grader, rag_chain\n\u001b[0;32m 16\u001b[0m logging\u001b[38;5;241m.\u001b[39mbasicConfig(filename\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124masync_workflow.log\u001b[39m\u001b[38;5;124m'\u001b[39m, level\u001b[38;5;241m=\u001b[39mlogging\u001b[38;5;241m.\u001b[39mINFO, \u001b[38;5;28mformat\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m%(asctime)s\u001b[39;00m\u001b[38;5;124m - \u001b[39m\u001b[38;5;132;01m%(name)s\u001b[39;00m\u001b[38;5;124m - \u001b[39m\u001b[38;5;132;01m%(levelname)s\u001b[39;00m\u001b[38;5;124m - \u001b[39m\u001b[38;5;132;01m%(message)s\u001b[39;00m\u001b[38;5;124m'\u001b[39m, filemode\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 18\u001b[0m API_GATEWAY_HOST \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mapi.allenneuraldynamics-test.org\u001b[39m\u001b[38;5;124m\"\u001b[39m\n",
"\u001b[1;31mImportError\u001b[0m: cannot import name 'db_surveyor' from 'metadata_chatbot.agents.agentic_graph' (c:\\Users\\sreya.kumar\\Documents\\GitHub\\metadata-chatbot\\venv\\Lib\\site-packages\\metadata_chatbot\\agents\\agentic_graph.py)"
"name": "stdout",
"output_type": "stream",
"text": [
"Based on the retrieved context, there is no information about the genotype for subject 675387. The context describes details about the imaging acquisition parameters and tile locations, but does not mention the subject's genotype.\n"
]
}
],
"source": [
"from metadata_chatbot.agents.GAMER import GAMER\n",
"query = \"What are all the assets using mouse 675387\"\n",
"query = \"What is the genotype for subject 675387?\"\n",
"\n",
"model = GAMER()\n",
"result = model.invoke(query)\n",
Expand Down
20 changes: 18 additions & 2 deletions src/metadata_chatbot/agents/agentic_graph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pydantic import BaseModel, Field
from langchain_aws import ChatBedrock
from langchain_aws.chat_models.bedrock import ChatBedrock
from langchain import hub
import logging
from typing import Literal
Expand All @@ -11,7 +11,7 @@

logging.basicConfig(filename='agentic_graph.log', level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', filemode="w")

GLOBAL_VRAIABLE = 90
GLOBAL_VARIABLE = 90

MODEL_ID = "anthropic.claude-3-sonnet-20240229-v1:0"
LLM = ChatBedrock(
Expand All @@ -25,6 +25,10 @@
class RouteQuery(BaseModel):
"""Route a user query to the most relevant datasource."""

reasoning: str = Field(
description="Give a justification for the chosen method",
)

datasource: Literal["vectorstore", "direct_database"] = Field(
description="Given a user question choose to route it to the direct database or its vectorstore.",
)
Expand Down Expand Up @@ -125,10 +129,19 @@ class FilterGenerator(BaseModel):
# Check if retrieved documents answer question
class RetrievalGrader(BaseModel):
    """Structured output schema for grading one retrieved document.

    Bound to the LLM via ``with_structured_output`` (see ``retrieval_grader``
    below) so the model must return its reasoning, a binary relevance
    verdict, and the excerpt(s) it judged relevant.
    """

    # Free-text justification the model produces before scoring; placing it
    # first encourages reason-then-answer ordering in the structured output.
    reasoning: str = Field(
        description="Give a reasoning as to what makes the document relevant for the chosen method",
    )

    # Relevance verdict as the literal strings 'yes'/'no' (str, not bool,
    # to match how downstream code compares the grade).
    binary_score: str = Field(
        description="Retrieved documents are relevant to the query, 'yes' or 'no'"
    )

    # The portion(s) of the document that actually support answering the
    # query; downstream grading keeps this instead of the full document.
    relevant_context: str = Field(
        description="Relevant pieces of context in document"
    )

# Wire the document-grading chain: prompt (pulled from LangChain Hub) piped
# into the LLM constrained to emit a RetrievalGrader structured object.
retrieval_grader = LLM.with_structured_output(RetrievalGrader)
retrieval_grade_prompt = hub.pull("eden19/retrievalgrader")
doc_grader = retrieval_grade_prompt | retrieval_grader
Expand All @@ -138,5 +151,8 @@ class RetrievalGrader(BaseModel):
# Generating response to documents
# Chain that answers the user's query from vectorstore-retrieved documents.
answer_generation_prompt = hub.pull("eden19/answergeneration")
rag_chain = answer_generation_prompt | LLM | StrOutputParser()

# Chain that answers from direct-database query results.
# BUG FIX: the original piped answer_generation_prompt into db_rag_chain,
# leaving db_answer_generation_prompt pulled but unused; use the
# db-specific prompt here.
db_answer_generation_prompt = hub.pull("eden19/db_answergeneration")
db_rag_chain = db_answer_generation_prompt | LLM | StrOutputParser()
# generation = rag_chain.invoke({"documents": doc, "query": question})
# logging.info(f"Final answer: {generation}")
2 changes: 1 addition & 1 deletion src/metadata_chatbot/agents/docdb_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
class DocDBRetriever(BaseRetriever):
"""A retriever that contains the top k documents, retrieved from the DocDB index, aligned with the user's query."""
#collection: Any = Field(description="DocDB collection to retrieve from")
k: int = Field(default=10, description="Number of documents to retrieve")
k: int = Field(default=5, description="Number of documents to retrieve")

def _get_relevant_documents(
self,
Expand Down
45 changes: 32 additions & 13 deletions src/metadata_chatbot/agents/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@
from typing_extensions import TypedDict
from langgraph.graph import END, StateGraph, START
from langgraph.checkpoint.memory import MemorySaver


# sys.path.append(os.path.abspath("C:/Users/sreya.kumar/Documents/GitHub/metadata-chatbot"))
from metadata_chatbot.agents.agentic_graph import datasource_router, query_retriever, query_grader, filter_generation_chain, doc_grader, rag_chain, db_rag_chain
# from metadata_chatbot.utils import ResourceManager

from metadata_chatbot.agents.docdb_retriever import DocDBRetriever
from metadata_chatbot.agents.agentic_graph import datasource_router, query_retriever, query_grader, filter_generation_chain, doc_grader, rag_chain
from metadata_chatbot.agents.agentic_graph import datasource_router, query_retriever, query_grader, filter_generation_chain, doc_grader, rag_chain, db_rag_chain

logging.basicConfig(filename='async_workflow.log', level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', filemode="w")

Expand Down Expand Up @@ -111,7 +109,7 @@ def retrieve(state):
filter = state["filter"]
#top_k = state["top_k"]

retriever = DocDBRetriever(k = 10)
retriever = DocDBRetriever(k = 5)
documents = retriever.get_relevant_documents(query = query, query_filter = filter)

# # Retrieval
Expand Down Expand Up @@ -144,14 +142,33 @@ def grade_documents(state):
logging.info(f"Retrieved document matched query: {grade}")
if grade == "yes":
logging.info("Document is relevant to the query")
filtered_docs.append(doc)
relevant_context = score.relevant_context
filtered_docs.append(relevant_context)
else:
logging.info("Document is not relevant and will be removed")
continue
doc_text = "\n\n".join(doc.page_content for doc in filtered_docs)
#doc_text = "\n\n".join(doc.page_content for doc in filtered_docs)
return {"documents": filtered_docs, "query": query}

def generate(state):
def generate_db(state):
    """
    Generate an answer for documents obtained via a direct database query.

    Args:
        state (dict): Current graph state; must contain "query" and
            "documents", and may contain "filter".

    Returns:
        dict: The state with a new "generation" key holding the LLM answer;
        "documents", "query", and "filter" (None when absent) are carried
        through unchanged.
    """
    logging.info("Generating answer...")
    user_query = state["query"]
    retrieved_docs = state["documents"]

    # RAG generation using the database-specific answer chain.
    answer = db_rag_chain.invoke({"documents": retrieved_docs, "query": user_query})

    return {
        "documents": retrieved_docs,
        "query": user_query,
        "generation": answer,
        "filter": state.get("filter", None),
    }

def generate_vi(state):
"""
Generate answer
Expand All @@ -174,7 +191,8 @@ def generate(state):
workflow.add_node("filter_generation", filter_generator)
workflow.add_node("retrieve", retrieve)
workflow.add_node("document_grading", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("generate_db", generate_db)
workflow.add_node("generate_vi", generate_vi)

workflow.add_conditional_edges(
START,
Expand All @@ -184,16 +202,17 @@ def generate(state):
"vectorstore": "filter_generation",
},
)
workflow.add_edge("database_query", "generate")
workflow.add_edge("database_query", "generate_db")
workflow.add_edge("generate_db", END)
workflow.add_edge("filter_generation", "retrieve")
workflow.add_edge("retrieve", "document_grading")
workflow.add_edge("document_grading","generate")
workflow.add_edge("generate", END)
workflow.add_edge("document_grading","generate_vi")
workflow.add_edge("generate_vi", END)


app = workflow.compile()

# query = "What is the mongodb query to find all the assets using mouse 675387"
# query = "Write a MongoDB query to find the genotype of SmartSPIM_675387_2023-05-23_23-05-56"

# inputs = {"query" : query}
# answer = app.invoke(inputs)
Expand Down

0 comments on commit 18e3212

Please sign in to comment.