diff --git a/GAMER_workbook.ipynb b/GAMER_workbook.ipynb index b796d0c..e89629a 100644 --- a/GAMER_workbook.ipynb +++ b/GAMER_workbook.ipynb @@ -174,26 +174,20 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 5, "metadata": {}, "outputs": [ { - "ename": "ImportError", - "evalue": "cannot import name 'db_surveyor' from 'metadata_chatbot.agents.agentic_graph' (c:\\Users\\sreya.kumar\\Documents\\GitHub\\metadata-chatbot\\venv\\Lib\\site-packages\\metadata_chatbot\\agents\\agentic_graph.py)", - "output_type": "error", - "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mImportError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[1;32mIn[1], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmetadata_chatbot\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01magents\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mGAMER\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m GAMER\n\u001b[0;32m 2\u001b[0m query \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWhat are all the assets using mouse 675387\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 4\u001b[0m model \u001b[38;5;241m=\u001b[39m GAMER()\n", - "File \u001b[1;32mc:\\Users\\sreya.kumar\\Documents\\GitHub\\metadata-chatbot\\venv\\Lib\\site-packages\\metadata_chatbot\\agents\\GAMER.py:9\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mlangchain_core\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01moutputs\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m GenerationChunk\n\u001b[0;32m 7\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mlogging\u001b[39;00m\u001b[38;5;241m,\u001b[39m \u001b[38;5;21;01masyncio\u001b[39;00m\n\u001b[1;32m----> 9\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m 
\u001b[38;5;21;01mmetadata_chatbot\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01magents\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01masync_workflow\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m async_app\n\u001b[0;32m 10\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmetadata_chatbot\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01magents\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mworkflow\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m app\n\u001b[0;32m 13\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mGAMER\u001b[39;00m(LLM):\n", - "File \u001b[1;32mc:\\Users\\sreya.kumar\\Documents\\GitHub\\metadata-chatbot\\venv\\Lib\\site-packages\\metadata_chatbot\\agents\\async_workflow.py:14\u001b[0m\n\u001b[0;32m 11\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01maind_data_access_api\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdocument_db\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m MetadataDbClient\n\u001b[0;32m 13\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmetadata_chatbot\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01magents\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdocdb_retriever\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m DocDBRetriever\n\u001b[1;32m---> 14\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmetadata_chatbot\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01magents\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01magentic_graph\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m datasource_router, db_surveyor, query_grader, filter_generation_chain, doc_grader, rag_chain\n\u001b[0;32m 16\u001b[0m logging\u001b[38;5;241m.\u001b[39mbasicConfig(filename\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124masync_workflow.log\u001b[39m\u001b[38;5;124m'\u001b[39m, 
level\u001b[38;5;241m=\u001b[39mlogging\u001b[38;5;241m.\u001b[39mINFO, \u001b[38;5;28mformat\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m%(asctime)s\u001b[39;00m\u001b[38;5;124m - \u001b[39m\u001b[38;5;132;01m%(name)s\u001b[39;00m\u001b[38;5;124m - \u001b[39m\u001b[38;5;132;01m%(levelname)s\u001b[39;00m\u001b[38;5;124m - \u001b[39m\u001b[38;5;132;01m%(message)s\u001b[39;00m\u001b[38;5;124m'\u001b[39m, filemode\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 18\u001b[0m API_GATEWAY_HOST \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mapi.allenneuraldynamics-test.org\u001b[39m\u001b[38;5;124m\"\u001b[39m\n", - "\u001b[1;31mImportError\u001b[0m: cannot import name 'db_surveyor' from 'metadata_chatbot.agents.agentic_graph' (c:\\Users\\sreya.kumar\\Documents\\GitHub\\metadata-chatbot\\venv\\Lib\\site-packages\\metadata_chatbot\\agents\\agentic_graph.py)" + "name": "stdout", + "output_type": "stream", + "text": [ + "Based on the retrieved context, there is no information about the genotype for subject 675387. 
The context describes details about the imaging acquisition parameters and tile locations, but does not mention the subject's genotype.\n" ] } ], "source": [ "from metadata_chatbot.agents.GAMER import GAMER\n", - "query = \"What are all the assets using mouse 675387\"\n", + "query = \"What is the genotype for subject 675387?\"\n", "\n", "model = GAMER()\n", "result = model.invoke(query)\n", diff --git a/src/metadata_chatbot/agents/agentic_graph.py b/src/metadata_chatbot/agents/agentic_graph.py index de4a26c..366893e 100644 --- a/src/metadata_chatbot/agents/agentic_graph.py +++ b/src/metadata_chatbot/agents/agentic_graph.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, Field -from langchain_aws import ChatBedrock +from langchain_aws.chat_models.bedrock import ChatBedrock from langchain import hub import logging from typing import Literal @@ -11,7 +11,7 @@ logging.basicConfig(filename='agentic_graph.log', level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', filemode="w") -GLOBAL_VRAIABLE = 90 +GLOBAL_VARIABLE = 90 MODEL_ID = "anthropic.claude-3-sonnet-20240229-v1:0" LLM = ChatBedrock( @@ -25,6 +25,10 @@ class RouteQuery(BaseModel): """Route a user query to the most relevant datasource.""" + reasoning: str = Field( + description="Give a justification for the chosen method", + ) + datasource: Literal["vectorstore", "direct_database"] = Field( description="Given a user question choose to route it to the direct database or its vectorstore.", ) @@ -125,10 +129,19 @@ class FilterGenerator(BaseModel): # Check if retrieved documents answer question class RetrievalGrader(BaseModel): """Binary score to check whether retrieved documents are relevant to the question""" + + reasoning: str = Field( + description="Give a reasoning as to what makes the document relevant for the chosen method", + ) + binary_score: str = Field( description="Retrieved documents are relevant to the query, 'yes' or 'no'" ) + relevant_context: str = Field( + 
description="Relevant pieces of context in document" + ) + retrieval_grader = LLM.with_structured_output(RetrievalGrader) retrieval_grade_prompt = hub.pull("eden19/retrievalgrader") doc_grader = retrieval_grade_prompt | retrieval_grader @@ -138,5 +151,8 @@ class RetrievalGrader(BaseModel): # Generating response to documents answer_generation_prompt = hub.pull("eden19/answergeneration") rag_chain = answer_generation_prompt | LLM | StrOutputParser() + +db_answer_generation_prompt = hub.pull("eden19/db_answergeneration") +db_rag_chain = db_answer_generation_prompt | LLM | StrOutputParser() # generation = rag_chain.invoke({"documents": doc, "query": question}) # logging.info(f"Final answer: {generation}") \ No newline at end of file diff --git a/src/metadata_chatbot/agents/docdb_retriever.py b/src/metadata_chatbot/agents/docdb_retriever.py index f5380a2..c2632c0 100644 --- a/src/metadata_chatbot/agents/docdb_retriever.py +++ b/src/metadata_chatbot/agents/docdb_retriever.py @@ -27,7 +27,7 @@ class DocDBRetriever(BaseRetriever): """A retriever that contains the top k documents, retrieved from the DocDB index, aligned with the user's query.""" #collection: Any = Field(description="DocDB collection to retrieve from") - k: int = Field(default=10, description="Number of documents to retrieve") + k: int = Field(default=5, description="Number of documents to retrieve") def _get_relevant_documents( self, diff --git a/src/metadata_chatbot/agents/workflow.py b/src/metadata_chatbot/agents/workflow.py index 6822b8d..c46dbf3 100644 --- a/src/metadata_chatbot/agents/workflow.py +++ b/src/metadata_chatbot/agents/workflow.py @@ -3,13 +3,11 @@ from typing_extensions import TypedDict from langgraph.graph import END, StateGraph, START from langgraph.checkpoint.memory import MemorySaver - - -# sys.path.append(os.path.abspath("C:/Users/sreya.kumar/Documents/GitHub/metadata-chatbot")) +from metadata_chatbot.agents.agentic_graph import datasource_router, query_retriever, query_grader, 
filter_generation_chain, doc_grader, rag_chain, db_rag_chain # from metadata_chatbot.utils import ResourceManager from metadata_chatbot.agents.docdb_retriever import DocDBRetriever -from metadata_chatbot.agents.agentic_graph import datasource_router, query_retriever, query_grader, filter_generation_chain, doc_grader, rag_chain +from metadata_chatbot.agents.agentic_graph import datasource_router, query_retriever, query_grader, filter_generation_chain, doc_grader, rag_chain, db_rag_chain logging.basicConfig(filename='async_workflow.log', level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', filemode="w") @@ -111,7 +109,7 @@ def retrieve(state): filter = state["filter"] #top_k = state["top_k"] - retriever = DocDBRetriever(k = 10) + retriever = DocDBRetriever(k = 5) documents = retriever.get_relevant_documents(query = query, query_filter = filter) # # Retrieval @@ -144,14 +142,33 @@ def grade_documents(state): logging.info(f"Retrieved document matched query: {grade}") if grade == "yes": logging.info("Document is relevant to the query") - filtered_docs.append(doc) + relevant_context = score.relevant_context + filtered_docs.append(relevant_context) else: logging.info("Document is not relevant and will be removed") continue - doc_text = "\n\n".join(doc.page_content for doc in filtered_docs) + #doc_text = "\n\n".join(doc.page_content for doc in filtered_docs) return {"documents": filtered_docs, "query": query} -def generate(state): +def generate_db(state): + """ + Generate answer + + Args: + state (dict): The current graph state + + Returns: + state (dict): New key added to state, generation, that contains LLM generation + """ + logging.info("Generating answer...") + query = state["query"] + documents = state["documents"] + + # RAG generation + generation = db_rag_chain.invoke({"documents": documents, "query": query}) + return {"documents": documents, "query": query, "generation": generation, "filter": state.get("filter", None)} + +def 
generate_vi(state): """ Generate answer @@ -174,7 +191,8 @@ def generate(state): workflow.add_node("filter_generation", filter_generator) workflow.add_node("retrieve", retrieve) workflow.add_node("document_grading", grade_documents) -workflow.add_node("generate", generate) +workflow.add_node("generate_db", generate_db) +workflow.add_node("generate_vi", generate_vi) workflow.add_conditional_edges( START, @@ -184,16 +202,17 @@ def generate(state): "vectorstore": "filter_generation", }, ) -workflow.add_edge("database_query", "generate") +workflow.add_edge("database_query", "generate_db") +workflow.add_edge("generate_db", END) workflow.add_edge("filter_generation", "retrieve") workflow.add_edge("retrieve", "document_grading") -workflow.add_edge("document_grading","generate") -workflow.add_edge("generate", END) +workflow.add_edge("document_grading","generate_vi") +workflow.add_edge("generate_vi", END) app = workflow.compile() -# query = "What is the mongodb query to find all the assets using mouse 675387" +# query = "Write a MongoDB query to find the genotype of SmartSPIM_675387_2023-05-23_23-05-56" # inputs = {"query" : query} # answer = app.invoke(inputs)