implemented top_k
sreyakumar committed Nov 4, 2024
1 parent 8b90e94 commit b9b9c47
Showing 5 changed files with 63 additions and 69 deletions.
39 changes: 24 additions & 15 deletions GAMER_workbook.ipynb
@@ -37,42 +37,51 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\sreya.kumar\\Documents\\GitHub\\metadata-chatbot\\venv\\Lib\\site-packages\\metadata_chatbot\\agents\\workflow.py:111: LangChainDeprecationWarning: The method `BaseRetriever.get_relevant_documents` was deprecated in langchain-core 0.1.46 and will be removed in 1.0. Use :meth:`~invoke` instead.\n",
"c:\\Users\\sreya.kumar\\Documents\\GitHub\\metadata-chatbot\\venv\\Lib\\site-packages\\metadata_chatbot\\agents\\workflow.py:105: LangChainDeprecationWarning: The method `BaseRetriever.get_relevant_documents` was deprecated in langchain-core 0.1.46 and will be removed in 1.0. Use :meth:`~invoke` instead.\n",
" documents = retriever.get_relevant_documents(query = query, query_filter = filter)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Based on the provided context, here are the procedures performed on specimen 662616 along with their start and end dates:\n",
"Based on the provided context, here is a summary of the subject and acquisition information for SmartSPIM_662616_2023-03-06_17-47-13:\n",
"\n",
"Subject procedures:\n",
"1. Surgery on 2023-01-25 with virus injections (SL1-hSyn-Cre, AAV1-CAG-H2B-mTurquoise2-WPRE, AAV-Syn-DIO-TVA66T-dTomato-CVS N2cG)\n",
"2. Surgery on 2023-01-25 with virus injection (EnvA CVS-N2C-histone-GFP)\n",
"Subject Information:\n",
"- Subject ID: 662616\n",
"- Sex: Female\n",
"- Date of Birth: 2022-11-29\n",
"- Genotype: wt/wt\n",
"\n",
"Specimen procedures:\n",
"1. Fixation (SHIELD OFF) from 2023-02-10 to 2023-02-12\n",
"2. Fixation (SHIELD ON) from 2023-02-12 to 2023-02-13\n",
"3. Delipidation (24h Delipidation) from 2023-02-15 to 2023-02-16 \n",
"4. Delipidation (Active Delipidation) from 2023-02-16 to 2023-02-18\n",
"5. Refractive index matching (50% EasyIndex) from 2023-02-19 to 2023-02-20\n",
"6. Refractive index matching (100% EasyIndex) from 2023-02-20 to 2023-02-21\n",
"Acquisition Information:\n",
"- SmartSPIM imaging experiment conducted on 2023-03-06 at 17:47:13\n",
"- Multiple imaging tiles (73-83) with different channels (445nm, 488nm, 561nm wavelengths)\n",
"- Coordinate transformations (translations, scaling) applied to each tile\n",
"- File names/paths for each tile's image data provided\n",
"- Laser power settings for each channel provided\n",
"- Imaging angle of 0 degrees for all tiles\n",
"- Notes on laser power units (percentage of total, needs calibration)\n",
"\n",
"Instrument Details:\n",
"- SmartSPIM1-2 instrument from LifeCanvas manufacturer\n",
"- Temperature control enabled, no humidity control\n",
"- Optical table: MKS Newport VIS3648-PG4-325A, 36x48 inch, vibration control\n",
"- Objective: Thorlabs TL2X-SAP, 0.1 NA, 1.6x magnification, multi-immersion\n",
"\n",
"The context does not provide the end dates for the subject procedures (virus injections).\n"
"The context also includes details about virus injections and specimen procedures performed on the subject prior to imaging.\n"
]
}
],
"source": [
"from metadata_chatbot.agents.GAMER import GAMER\n",
"query = \"Can you list all the procedures performed on the specimen, including their start and end dates? in SmartSPIM_662616_2023-03-06_17-47-13\"\n",
"query = \"Can you tell me summarize the subject and acquisition information in SmartSPIM_662616_2023-03-06_17-47-13\"\n",
"\n",
"model = GAMER()\n",
"result = model.invoke(query)\n",
2 changes: 1 addition & 1 deletion README.md
@@ -52,7 +52,7 @@ print(result)

## High Level Overview

The project's main goal is to develop a chatbot that can ingest, analyze, and query metadata. Metadata accumulates alongside experiments and captures information about the data description, subject, equipment, and session. To maintain reproducibility standards, it is important for metadata to be well documented.
The project's main goal is to develop a chatbot that can ingest, analyze, and query metadata. Metadata accumulates alongside experiments and captures information about the data description, subject, equipment, and session. To maintain reproducibility standards, it is important for metadata to be well documented. GAMER is designed to streamline the querying process for neuroscientists and other users.

## Model Overview

2 changes: 1 addition & 1 deletion src/metadata_chatbot/agents/agentic_graph.py
@@ -79,7 +79,7 @@ class FilterGenerator(TypedDict):
"""MongoDB filter to be applied before vector retrieval"""

filter_query: Annotated[dict, ..., "MongoDB filter"]
#top_k: int = Field(description="Number of documents to retrieve from the database")
top_k: int = Annotated[dict, ..., "MongoDB filter"]

filter_prompt = hub.pull("eden19/filtergeneration")
filter_generator_llm = SONNET_3_LLM.with_structured_output(FilterGenerator)
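Note: the new top_k field above reuses the annotation from filter_query (Annotated[dict, ..., "MongoDB filter"]), which reads like a copy-paste slip: top_k should presumably be an int, with the description carried over from the commented-out Field(...) line it replaces. A minimal sketch of the presumably intended schema follows; the exact description wording is an assumption taken from the removed comment.

from typing import Annotated
from typing_extensions import TypedDict

class FilterGenerator(TypedDict):
    """MongoDB filter to be applied before vector retrieval"""

    # Filter applied to the document store before vector retrieval
    filter_query: Annotated[dict, ..., "MongoDB filter"]
    # Presumed intent: an int count, not a dict
    top_k: Annotated[int, ..., "Number of documents to retrieve from the database"]

With structured output, filter_generation_chain.invoke({"query": query}) then returns a dict exposing both keys, which matches how filter_generator in workflow.py reads result['filter_query'] and result['top_k'] below.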
66 changes: 25 additions & 41 deletions src/metadata_chatbot/agents/async_workflow.py
@@ -1,13 +1,12 @@
import logging, asyncio, json
import asyncio, json
from typing import List, Optional
from typing_extensions import TypedDict
from langchain_core.documents import Document
from langgraph.graph import END, StateGraph, START
from metadata_chatbot.agents.docdb_retriever import DocDBRetriever
from agentic_graph import datasource_router, query_retriever, filter_generation_chain, doc_grader, rag_chain, db_rag_chain

from metadata_chatbot.agents.agentic_graph import datasource_router, query_retriever, filter_generation_chain, doc_grader, rag_chain, db_rag_chain

logging.basicConfig(filename='async_workflow.log', level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', filemode="w")
#from metadata_chatbot.agents.agentic_graph import datasource_router, query_retriever, filter_generation_chain, doc_grader, rag_chain, db_rag_chain

class GraphState(TypedDict):
"""
@@ -23,7 +22,7 @@ class GraphState(TypedDict):
generation: str
documents: List[str]
filter: Optional[dict]
#top_k: Optional[int]
top_k: Optional[int]

async def route_question_async(state):
"""
@@ -37,14 +36,13 @@ async def route_question_async(state):
query = state["query"]

source = await datasource_router.ainvoke({"query": query})

if source['datasource'] == "direct_database":
logging.info("Entire database needs to be queried.")
return "direct_database"
elif source['datasource'] == "vectorstore":
logging.info("Querying against vector embeddings...")
return "vectorstore"

async def generate_for_whole_db_async(state):
async def retrieve_DB_async(state):
"""
Filter database
@@ -57,8 +55,6 @@ async def generate_for_whole_db_async(state):

query = state["query"]

logging.info("Generating answer...")

document_dict = dict()
retrieved_dict = await query_retriever.ainvoke({'query': query, 'chat_history': [], 'agent_scratchpad' : []})
document_dict['mongodb_query'] = retrieved_dict['intermediate_steps'][0][0].tool_input['agg_pipeline']
@@ -77,17 +73,15 @@ async def filter_generator_async(state):
Returns:
state (dict): New key may be added to state, filter, which contains the MongoDB query that will be applied before retrieval
"""
logging.info("Determining whether filter is required...")

query = state["query"]

result = await filter_generation_chain.ainvoke({"query": query})
filter = result['filter_query']
top_k = result['top_k']

logging.info(f"Database will be filtered using: {filter}")
return {"filter": filter, "query": query}
return {"filter": filter, "top_k": top_k, "query": query}

async def retrieve_async(state):
async def retrieve_VI_async(state):
"""
Retrieve documents
@@ -97,26 +91,22 @@ async def retrieve_async(state):
Returns:
state (dict): New key added to state, documents, that contains retrieved documents
"""
logging.info("Retrieving documents...")
query = state["query"]
filter = state["filter"]
top_k = state["top_k"]

# Retrieval

retriever = DocDBRetriever(k = 10)
retriever = DocDBRetriever(k = top_k)
documents = await retriever.aget_relevant_documents(query = query, query_filter = filter)
return {"documents": documents, "query": query}

async def grade_doc_async(query, doc: Document):
score = await doc_grader.ainvoke({"query": query, "document": doc.page_content})
grade = score['binary_score']
logging.info(f"Retrieved document matched query: {grade}")

if grade == "yes":
logging.info("Document is relevant to the query")
relevant_context = score['relevant_context']
return relevant_context
else:
logging.info("Document is not relevant and will be removed")
return None


@@ -130,16 +120,14 @@ async def grade_documents_async(state):
Returns:
state (dict): Updates documents key with only filtered relevant documents
"""

logging.info("Checking relevance of documents to question asked...")
query = state["query"]
documents = state["documents"]

filtered_docs = await asyncio.gather(*[grade_doc_async(query, doc) for doc in documents])
filtered_docs = [doc for doc in filtered_docs if doc is not None]
return {"documents": filtered_docs, "query": query}

async def generate_db_async(state):
async def generate_DB_async(state):
"""
Generate answer
@@ -149,17 +137,14 @@ async def generate_db_async(state):
Returns:
state (dict): New key added to state, generation, that contains LLM generation
"""
logging.info("Generating answer...")
query = state["query"]
documents = state["documents"]

#doc_text = "\n\n".join(doc.page_content for doc in documents)

# RAG generation
generation = await db_rag_chain.ainvoke({"documents": documents, "query": query})
return {"documents": documents, "query": query, "generation": generation, "filter": state.get("filter", None)}

async def generate_vi_async(state):
async def generate_VI_async(state):
"""
Generate answer
@@ -169,7 +154,6 @@ async def generate_vi_async(state):
Returns:
state (dict): New key added to state, generation, that contains LLM generation
"""
logging.info("Generating answer...")
query = state["query"]
documents = state["documents"]

Expand All @@ -178,12 +162,12 @@ async def generate_vi_async(state):
return {"documents": documents, "query": query, "generation": generation, "filter": state.get("filter", None)}

async_workflow = StateGraph(GraphState)
async_workflow.add_node("database_query", generate_for_whole_db_async)
async_workflow.add_node("database_query", retrieve_DB_async)
async_workflow.add_node("filter_generation", filter_generator_async)
async_workflow.add_node("retrieve", retrieve_async)
async_workflow.add_node("retrieve", retrieve_VI_async)
async_workflow.add_node("document_grading", grade_documents_async)
async_workflow.add_node("generate_db", generate_db_async)
async_workflow.add_node("generate_vi", generate_vi_async)
async_workflow.add_node("generate_db", generate_DB_async)
async_workflow.add_node("generate_vi", generate_VI_async)

async_workflow.add_conditional_edges(
START,
@@ -202,11 +186,11 @@ async def generate_vi_async(state):

async_app = async_workflow.compile()

# async def main():
# query = "How many records are stored in the database?"
# inputs = {"query": query}
# answer = await async_app.ainvoke(inputs)
# return answer['generation']
async def main():
query = "Can you list all the procedures performed on the specimen, including their start and end dates? in SmartSPIM_662616_2023-03-06_17-47-13"
inputs = {"query": query}
answer = await async_app.ainvoke(inputs)
return answer['generation']

# #Run the async function
# print(asyncio.run(main()))
#Run the async function
print(asyncio.run(main()))
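The edge wiring between the nodes registered above is collapsed in this diff (the @@ -202,11 +186,11 @@ hunk). Based only on the node names and the two values route_question_async returns, a plausible reconstruction looks like the sketch below; the exact edges are an assumption, not something shown in this commit.

from langgraph.graph import END, START

# Assumed wiring, inferred from the node names and the router's
# return values; the diff collapses this section.
async_workflow.add_conditional_edges(
    START,
    route_question_async,
    {
        "direct_database": "database_query",  # query the whole database
        "vectorstore": "filter_generation",   # filter, retrieve, grade
    },
)
async_workflow.add_edge("database_query", "generate_db")
async_workflow.add_edge("filter_generation", "retrieve")
async_workflow.add_edge("retrieve", "document_grading")
async_workflow.add_edge("document_grading", "generate_vi")
async_workflow.add_edge("generate_db", END)
async_workflow.add_edge("generate_vi", END)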
23 changes: 12 additions & 11 deletions src/metadata_chatbot/agents/workflow.py
@@ -22,7 +22,7 @@ class GraphState(TypedDict):
generation: str
documents: List[str]
filter: Optional[dict]
#top_k: Optional[int]
top_k: Optional[int]

def route_question(state):
"""
@@ -80,10 +80,11 @@ def filter_generator(state):

query = state["query"]

filter = filter_generation_chain.invoke({"query": query})['filter_query']
#top_k = filter_generation_chain.invoke({"query": query}).top_k
result = filter_generation_chain.invoke({"query": query})
filter = result['filter_query']
top_k = result['top_k']
logging.info(f"Database will be filtered using: {filter}")
return {"filter": filter, "query": query}
return {"filter": filter, "top_k": top_k, "query": query}


def retrieve_VI(state):
@@ -99,9 +100,9 @@ def retrieve_VI(state):
logging.info("Retrieving documents...")
query = state["query"]
filter = state["filter"]
#top_k = state["top_k"]
top_k = state["top_k"]

retriever = DocDBRetriever(k = 5)
retriever = DocDBRetriever(k = top_k)
documents = retriever.get_relevant_documents(query = query, query_filter = filter)
return {"documents": documents, "query": query}

@@ -137,7 +138,7 @@ def grade_documents(state):
#print(filtered_docs)
return {"documents": filtered_docs, "query": query}

def generate_db(state):
def generate_DB(state):
"""
Generate answer
@@ -155,7 +156,7 @@ def generate_db(state):
generation = db_rag_chain.invoke({"documents": documents, "query": query})
return {"documents": documents, "query": query, "generation": generation, "filter": state.get("filter", None)}

def generate_vi(state):
def generate_VI(state):
"""
Generate answer
@@ -178,8 +179,8 @@ def generate_vi(state):
workflow.add_node("filter_generation", filter_generator)
workflow.add_node("retrieve", retrieve_VI)
workflow.add_node("document_grading", grade_documents)
workflow.add_node("generate_db", generate_db)
workflow.add_node("generate_vi", generate_vi)
workflow.add_node("generate_db", generate_DB)
workflow.add_node("generate_vi", generate_VI)

workflow.add_conditional_edges(
START,
@@ -199,7 +200,7 @@ def generate_vi(state):

app = workflow.compile()

# query = "How many records are stored in the database?"
# query = "What was the refractive index of the chamber immersion medium used in this experiment SmartSPIM_675387_2023-05-23_23-05-56?"

# inputs = {"query": query}
# answer = app.invoke(inputs)
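For reference, the commented-out driver above expands to the following minimal invocation, mirroring main() in async_workflow.py (assuming the compiled app and the same state keys):

query = "What was the refractive index of the chamber immersion medium used in this experiment SmartSPIM_675387_2023-05-23_23-05-56?"

inputs = {"query": query}
answer = app.invoke(inputs)
print(answer["generation"])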
