diff --git a/GAMER_workbook.ipynb b/GAMER_workbook.ipynb
index 08df786..1f3823d 100644
--- a/GAMER_workbook.ipynb
+++ b/GAMER_workbook.ipynb
@@ -37,14 +37,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "c:\\Users\\sreya.kumar\\Documents\\GitHub\\metadata-chatbot\\venv\\Lib\\site-packages\\metadata_chatbot\\agents\\workflow.py:111: LangChainDeprecationWarning: The method `BaseRetriever.get_relevant_documents` was deprecated in langchain-core 0.1.46 and will be removed in 1.0. Use :meth:`~invoke` instead.\n",
+      "c:\\Users\\sreya.kumar\\Documents\\GitHub\\metadata-chatbot\\venv\\Lib\\site-packages\\metadata_chatbot\\agents\\workflow.py:105: LangChainDeprecationWarning: The method `BaseRetriever.get_relevant_documents` was deprecated in langchain-core 0.1.46 and will be removed in 1.0. Use :meth:`~invoke` instead.\n",
       "  documents = retriever.get_relevant_documents(query = query, query_filter = filter)\n"
      ]
     },
@@ -52,27 +52,36 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Based on the provided context, here are the procedures performed on specimen 662616 along with their start and end dates:\n",
+      "Based on the provided context, here is a summary of the subject and acquisition information for SmartSPIM_662616_2023-03-06_17-47-13:\n",
       "\n",
-      "Subject procedures:\n",
-      "1. Surgery on 2023-01-25 with virus injections (SL1-hSyn-Cre, AAV1-CAG-H2B-mTurquoise2-WPRE, AAV-Syn-DIO-TVA66T-dTomato-CVS N2cG)\n",
-      "2. Surgery on 2023-01-25 with virus injection (EnvA CVS-N2C-histone-GFP)\n",
+      "Subject Information:\n",
+      "- Subject ID: 662616\n",
+      "- Sex: Female\n",
+      "- Date of Birth: 2022-11-29\n",
+      "- Genotype: wt/wt\n",
       "\n",
-      "Specimen procedures:\n",
-      "1. Fixation (SHIELD OFF) from 2023-02-10 to 2023-02-12\n",
-      "2. Fixation (SHIELD ON) from 2023-02-12 to 2023-02-13\n",
-      "3. Delipidation (24h Delipidation) from 2023-02-15 to 2023-02-16 \n",
-      "4. Delipidation (Active Delipidation) from 2023-02-16 to 2023-02-18\n",
-      "5. Refractive index matching (50% EasyIndex) from 2023-02-19 to 2023-02-20\n",
-      "6. Refractive index matching (100% EasyIndex) from 2023-02-20 to 2023-02-21\n",
+      "Acquisition Information:\n",
+      "- SmartSPIM imaging experiment conducted on 2023-03-06 at 17:47:13\n",
+      "- Multiple imaging tiles (73-83) with different channels (445nm, 488nm, 561nm wavelengths)\n",
+      "- Coordinate transformations (translations, scaling) applied to each tile\n",
+      "- File names/paths for each tile's image data provided\n",
+      "- Laser power settings for each channel provided\n",
+      "- Imaging angle of 0 degrees for all tiles\n",
+      "- Notes on laser power units (percentage of total, needs calibration)\n",
+      "\n",
+      "Instrument Details:\n",
+      "- SmartSPIM1-2 instrument from LifeCanvas manufacturer\n",
+      "- Temperature control enabled, no humidity control\n",
+      "- Optical table: MKS Newport VIS3648-PG4-325A, 36x48 inch, vibration control\n",
+      "- Objective: Thorlabs TL2X-SAP, 0.1 NA, 1.6x magnification, multi-immersion\n",
       "\n",
-      "The context does not provide the end dates for the subject procedures (virus injections).\n"
+      "The context also includes details about virus injections and specimen procedures performed on the subject prior to imaging.\n"
      ]
     }
    ],
    "source": [
    "from metadata_chatbot.agents.GAMER import GAMER",
-   "query = \"Can you list all the procedures performed on the specimen, including their start and end dates? in SmartSPIM_662616_2023-03-06_17-47-13\"\n",
+   "query = \"Can you summarize the subject and acquisition information in SmartSPIM_662616_2023-03-06_17-47-13\"\n",
    "\n",
    "model = GAMER()\n",
    "result = model.invoke(query)\n",
diff --git a/README.md b/README.md
index 418e9a4..b767f1d 100644
--- a/README.md
+++ b/README.md
@@ -52,7 +52,7 @@ print(result)
 
 ## High Level Overview
 
-The project's main goal is to developing a chat bot that is able to ingest, analyze and query metadata. Metadata is accumulated in lieu with experiments and consists of information about the data description, subject, equipment and session. To maintain reproducibility standards, it is important for metadata to be documented well. 
+The project's main goal is to develop a chatbot that can ingest, analyze and query metadata. Metadata is accumulated alongside experiments and consists of information about the data description, subject, equipment and session. To maintain reproducibility standards, it is important for metadata to be well documented. GAMER is designed to streamline the querying process for neuroscientists and other users.
 
 ## Model Overview
diff --git a/src/metadata_chatbot/agents/agentic_graph.py b/src/metadata_chatbot/agents/agentic_graph.py
index 3bb923b..49e2d93 100644
--- a/src/metadata_chatbot/agents/agentic_graph.py
+++ b/src/metadata_chatbot/agents/agentic_graph.py
@@ -79,7 +79,7 @@ class FilterGenerator(TypedDict):
     """MongoDB filter to be applied before vector retrieval"""
 
     filter_query: Annotated[dict, ..., "MongoDB filter"]
-    #top_k: int = Field(description="Number of documents to retrieve from the database")
+    top_k: Annotated[int, ..., "Number of documents to retrieve from the database"]
 
 filter_prompt = hub.pull("eden19/filtergeneration")
 filter_generator_llm = SONNET_3_LLM.with_structured_output(FilterGenerator)
diff --git a/src/metadata_chatbot/agents/async_workflow.py b/src/metadata_chatbot/agents/async_workflow.py
index 97a9cf4..c0e34a4 100644
--- a/src/metadata_chatbot/agents/async_workflow.py
+++ b/src/metadata_chatbot/agents/async_workflow.py
@@ -1,13 +1,12 @@
-import logging, asyncio, json
+import asyncio, json
 from typing import List, Optional
 from typing_extensions import TypedDict
 from langchain_core.documents import Document
 from langgraph.graph import END, StateGraph, START
 from metadata_chatbot.agents.docdb_retriever import DocDBRetriever
 
 from metadata_chatbot.agents.agentic_graph import datasource_router, query_retriever, filter_generation_chain, doc_grader, rag_chain, db_rag_chain
 
-logging.basicConfig(filename='async_workflow.log', level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', filemode="w")
 
 class GraphState(TypedDict):
     """
@@ -23,7 +22,7 @@ class GraphState(TypedDict):
     generation: str
     documents: List[str]
     filter: Optional[dict]
-    #top_k: Optional[int]
+    top_k: Optional[int]
 
 async def route_question_async(state):
     """
@@ -37,14 +36,13 @@ async def route_question_async(state):
     query = state["query"]
 
     source = await datasource_router.ainvoke({"query": query})
+
     if source['datasource'] == "direct_database":
-        logging.info("Entire database needs to be queried.")
         return "direct_database"
     elif source['datasource'] == "vectorstore":
-        logging.info("Querying against vector embeddings...")
         return "vectorstore"
 
-async def generate_for_whole_db_async(state):
+async def retrieve_DB_async(state):
     """
     Filter database
 
@@ -57,8 +55,6 @@ async def generate_for_whole_db_async(state):
 
     query = state["query"]
 
-    logging.info("Generating answer...")
-
     document_dict = dict()
     retrieved_dict = await query_retriever.ainvoke({'query': query, 'chat_history': [], 'agent_scratchpad' : []})
     document_dict['mongodb_query'] = retrieved_dict['intermediate_steps'][0][0].tool_input['agg_pipeline']
@@ -77,17 +73,15 @@ async def filter_generator_async(state):
     Returns:
         state (dict): New key may be added to state, filter, which contains the MongoDB query that will be applied before retrieval
     """
-    logging.info("Determining whether filter is required...")
-
     query = state["query"]
 
     result = await filter_generation_chain.ainvoke({"query": query})
     filter = result['filter_query']
+    top_k = result['top_k']
 
-    logging.info(f"Database will be filtered using: {filter}")
-
-    return {"filter": filter, "query": query}
+    return {"filter": filter, "top_k": top_k, "query": query}
 
-async def retrieve_async(state):
+async def retrieve_VI_async(state):
     """
     Retrieve documents
 
@@ -97,26 +91,22 @@ async def retrieve_async(state):
     Returns:
         state (dict): New key added to state, documents, that contains retrieved documents
     """
-    logging.info("Retrieving documents...")
     query = state["query"]
     filter = state["filter"]
+    top_k = state["top_k"]
 
-    # Retrieval
-
-    retriever = DocDBRetriever(k = 10)
+    retriever = DocDBRetriever(k = top_k)
     documents = await retriever.aget_relevant_documents(query = query, query_filter = filter)
     return {"documents": documents, "query": query}
 
 async def grade_doc_async(query, doc: Document):
     score = await doc_grader.ainvoke({"query": query, "document": doc.page_content})
     grade = score['binary_score']
-    logging.info(f"Retrieved document matched query: {grade}")
+
     if grade == "yes":
-        logging.info("Document is relevant to the query")
         relevant_context = score['relevant_context']
         return relevant_context
     else:
-        logging.info("Document is not relevant and will be removed")
         return None
 
@@ -130,8 +120,6 @@ async def grade_documents_async(state):
     Returns:
         state (dict): Updates documents key with only filtered relevant documents
     """
-
-    logging.info("Checking relevance of documents to question asked...")
     query = state["query"]
     documents = state["documents"]
 
@@ -139,7 +127,7 @@ async def grade_documents_async(state):
     filtered_docs = [doc for doc in filtered_docs if doc is not None]
     return {"documents": filtered_docs, "query": query}
 
-async def generate_db_async(state):
+async def generate_DB_async(state):
     """
     Generate answer
 
@@ -149,17 +137,14 @@ async def generate_db_async(state):
     Returns:
         state (dict): New key added to state, generation, that contains LLM generation
     """
-    logging.info("Generating answer...")
     query = state["query"]
     documents = state["documents"]
 
-    #doc_text = "\n\n".join(doc.page_content for doc in documents)
-
     # RAG generation
     generation = await db_rag_chain.ainvoke({"documents": documents, "query": query})
     return {"documents": documents, "query": query, "generation": generation, "filter": state.get("filter", None)}
 
-async def generate_vi_async(state):
+async def generate_VI_async(state):
     """
     Generate answer
 
@@ -169,7 +154,6 @@ async def generate_vi_async(state):
     Returns:
         state (dict): New key added to state, generation, that contains LLM generation
     """
-    logging.info("Generating answer...")
     query = state["query"]
     documents = state["documents"]
 
@@ -178,12 +162,12 @@ async def generate_vi_async(state):
     return {"documents": documents, "query": query, "generation": generation, "filter": state.get("filter", None)}
 
 async_workflow = StateGraph(GraphState)
-async_workflow.add_node("database_query", generate_for_whole_db_async)
+async_workflow.add_node("database_query", retrieve_DB_async)
 async_workflow.add_node("filter_generation", filter_generator_async)
-async_workflow.add_node("retrieve", retrieve_async)
+async_workflow.add_node("retrieve", retrieve_VI_async)
 async_workflow.add_node("document_grading", grade_documents_async)
-async_workflow.add_node("generate_db", generate_db_async)
-async_workflow.add_node("generate_vi", generate_vi_async)
+async_workflow.add_node("generate_db", generate_DB_async)
+async_workflow.add_node("generate_vi", generate_VI_async)
 
 async_workflow.add_conditional_edges(
     START,
@@ -202,11 +186,12 @@ async def generate_vi_async(state):
 
 async_app = async_workflow.compile()
 
-# async def main():
-#     query = "How many records are stored in the database?"
-#     inputs = {"query": query}
-#     answer = await async_app.ainvoke(inputs)
-#     return answer['generation']
+async def main():
+    query = "Can you list all the procedures performed on the specimen, including their start and end dates? in SmartSPIM_662616_2023-03-06_17-47-13"
+    inputs = {"query": query}
+    answer = await async_app.ainvoke(inputs)
+    return answer['generation']
 
-# #Run the async function
-# print(asyncio.run(main()))
+# Run the async test query only when executed directly, not on import
+if __name__ == "__main__":
+    print(asyncio.run(main()))
diff --git a/src/metadata_chatbot/agents/workflow.py b/src/metadata_chatbot/agents/workflow.py
index 33dbed5..1908d38 100644
--- a/src/metadata_chatbot/agents/workflow.py
+++ b/src/metadata_chatbot/agents/workflow.py
@@ -22,7 +22,7 @@ class GraphState(TypedDict):
     generation: str
     documents: List[str]
     filter: Optional[dict]
-    #top_k: Optional[int]
+    top_k: Optional[int]
 
 def route_question(state):
     """
@@ -80,10 +80,11 @@ def filter_generator(state):
 
     query = state["query"]
 
-    filter = filter_generation_chain.invoke({"query": query})['filter_query']
-    #top_k = filter_generation_chain.invoke({"query": query}).top_k
+    result = filter_generation_chain.invoke({"query": query})
+    filter = result['filter_query']
+    top_k = result['top_k']
 
     logging.info(f"Database will be filtered using: {filter}")
-    return {"filter": filter, "query": query}
+    return {"filter": filter, "top_k": top_k, "query": query}
 
 
 def retrieve_VI(state):
@@ -99,9 +100,9 @@ def retrieve_VI(state):
     logging.info("Retrieving documents...")
     query = state["query"]
     filter = state["filter"]
-    #top_k = state["top_k"]
+    top_k = state["top_k"]
 
-    retriever = DocDBRetriever(k = 5)
+    retriever = DocDBRetriever(k = top_k)
     documents = retriever.get_relevant_documents(query = query, query_filter = filter)
     return {"documents": documents, "query": query}
 
@@ -137,7 +138,7 @@ def grade_documents(state):
     #print(filtered_docs)
     return {"documents": filtered_docs, "query": query}
 
-def generate_db(state):
+def generate_DB(state):
     """
     Generate answer
 
@@ -155,7 +156,7 @@ def generate_db(state):
     generation = db_rag_chain.invoke({"documents": documents, "query": query})
     return {"documents": documents, "query": query, "generation": generation, "filter": state.get("filter", None)}
 
-def generate_vi(state):
+def generate_VI(state):
     """
     Generate answer
 
@@ -178,8 +179,8 @@ def generate_vi(state):
 workflow.add_node("filter_generation", filter_generator)
 workflow.add_node("retrieve", retrieve_VI)
 workflow.add_node("document_grading", grade_documents)
-workflow.add_node("generate_db", generate_db)
-workflow.add_node("generate_vi", generate_vi)
+workflow.add_node("generate_db", generate_DB)
+workflow.add_node("generate_vi", generate_VI)
 
 workflow.add_conditional_edges(
     START,
@@ -199,7 +200,7 @@ def generate_vi(state):
 
 app = workflow.compile()
 
-# query = "How many records are stored in the database?"
+# query = "What was the refractive index of the chamber immersion medium used in this experiment SmartSPIM_675387_2023-05-23_23-05-56?"
 # inputs = {"query": query}
 # answer = app.invoke(inputs)
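
Reviewer note (not part of the patch): the top_k wiring above assumes that
filter_generation_chain's structured output now returns both "filter_query" and
"top_k", and that retrieve_VI / retrieve_VI_async pass that value straight into
DocDBRetriever(k=top_k). A minimal, self-contained sketch of that contract is
below; the filter and count values are hypothetical stand-ins, not real model
output.

    from typing import Annotated
    from typing_extensions import TypedDict

    class FilterGenerator(TypedDict):
        """MongoDB filter to be applied before vector retrieval"""

        filter_query: Annotated[dict, ..., "MongoDB filter"]
        top_k: Annotated[int, ..., "Number of documents to retrieve from the database"]

    # Shaped like the dict SONNET_3_LLM.with_structured_output(FilterGenerator)
    # returns; the values below are hypothetical examples.
    result: FilterGenerator = {
        "filter_query": {"subject.subject_id": "662616"},  # hypothetical filter
        "top_k": 10,                                       # hypothetical count
    }

    # filter_generator / filter_generator_async unpack both keys, and the
    # retrieve step later constructs DocDBRetriever(k=top_k).
    filter = result["filter_query"]
    top_k = result["top_k"]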