Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

upgraded to 20k collection #3

Merged
merged 1 commit
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ test_agentic_graph_VI.py
test.py
graph_viz.ipynb
multi-agent-workflow-11-01.jpeg
bedrock_model/
embeddings.py
umap_visualization.py

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
56 changes: 29 additions & 27 deletions GAMER_workbook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -44,44 +44,46 @@
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\sreya.kumar\\Documents\\GitHub\\metadata-chatbot\\venv\\Lib\\site-packages\\metadata_chatbot\\agents\\workflow.py:105: LangChainDeprecationWarning: The method `BaseRetriever.get_relevant_documents` was deprecated in langchain-core 0.1.46 and will be removed in 1.0. Use :meth:`~invoke` instead.\n",
"c:\\Users\\sreya.kumar\\Documents\\GitHub\\metadata-chatbot\\venv\\Lib\\site-packages\\metadata_chatbot\\agents\\workflow.py:106: LangChainDeprecationWarning: The method `BaseRetriever.get_relevant_documents` was deprecated in langchain-core 0.1.46 and will be removed in 1.0. Use :meth:`~invoke` instead.\n",
" documents = retriever.get_relevant_documents(query = query, query_filter = filter)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Based on the provided context, here is a summary of the subject and acquisition information for SmartSPIM_662616_2023-03-06_17-47-13:\n",
"The subject information provided is:\n",
"\n",
"Subject Information:\n",
"- Subject ID: 662616\n",
"- Sex: Female\n",
"- Date of Birth: 2022-11-29\n",
"- Genotype: wt/wt\n",
"- subject_id: 662616\n",
"- sex: Female\n",
"- date_of_birth: 2022-11-29\n",
"- genotype: wt/wt\n",
"\n",
"Acquisition Information:\n",
"- SmartSPIM imaging experiment conducted on 2023-03-06 at 17:47:13\n",
"- Multiple imaging tiles (73-83) with different channels (445nm, 488nm, 561nm wavelengths)\n",
"- Coordinate transformations (translations, scaling) applied to each tile\n",
"- File names/paths for each tile's image data provided\n",
"- Laser power settings for each channel provided\n",
"- Imaging angle of 0 degrees for all tiles\n",
"- Notes on laser power units (percentage of total, needs calibration)\n",
"The acquisition information includes:\n",
"\n",
"Instrument Details:\n",
"- SmartSPIM1-2 instrument from LifeCanvas manufacturer\n",
"- Temperature control enabled, no humidity control\n",
"- Optical table: MKS Newport VIS3648-PG4-325A, 36x48 inch, vibration control\n",
"- Objective: Thorlabs TL2X-SAP, 0.1 NA, 1.6x magnification, multi-immersion\n",
"- Tiles 73-83 with different channels (445nm, 488nm, 561nm) and coordinate transformations\n",
"- Laser powers in milliwatts for each channel \n",
"- File names with coordinates for each tile/channel\n",
"- Imaging angle of 0 degrees\n",
"- Notes about laser power needing calibration\n",
"\n",
"The context also includes details about virus injections and specimen procedures performed on the subject prior to imaging.\n"
"There were two procedures performed:\n",
"\n",
"Procedure 1 (injection):\n",
"- Injection materials included viruses SL1-hSyn-Cre, AAV1-CAG-H2B-mTurquoise2-WPRE, and AAV-Syn-DIO-TVA66T-dTomato-CVS N2cG\n",
"- Injection coordinates and volumes provided\n",
"\n",
"Procedure 2 (surgery): \n",
"- Injection material was virus EnvA CVS-N2C-histone-GFP\n",
"- Injection coordinates and volumes provided\n",
"\n",
"Additional information includes specimen procedures like fixation, delipidation, and refractive index matching.\n"
]
}
],
"source": [
"from metadata_chatbot.agents.GAMER import GAMER\n",
"query = \"Can you tell me summarize the subject and acquisition information in SmartSPIM_662616_2023-03-06_17-47-13\"\n",
"query = \"Can you summarize the subject and acquisition information in SmartSPIM_662616_2023-03-06_17-47-13\"\n",
"\n",
"model = GAMER()\n",
"result = model.invoke(query)\n",
Expand All @@ -98,22 +100,22 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\sreya.kumar\\Documents\\GitHub\\metadata-chatbot\\venv\\Lib\\site-packages\\metadata_chatbot\\agents\\async_workflow.py:106: LangChainDeprecationWarning: The method `BaseRetriever.aget_relevant_documents` was deprecated in langchain-core 0.1.46 and will be removed in 1.0. Use :meth:`~ainvoke` instead.\n",
"c:\\Users\\sreya.kumar\\Documents\\GitHub\\metadata-chatbot\\venv\\Lib\\site-packages\\metadata_chatbot\\agents\\async_workflow.py:98: LangChainDeprecationWarning: The method `BaseRetriever.aget_relevant_documents` was deprecated in langchain-core 0.1.46 and will be removed in 1.0. Use :meth:`~ainvoke` instead.\n",
" documents = await retriever.aget_relevant_documents(query = query, query_filter = filter)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Based on the provided context, the procedures performed on specimen 662616 and their start and end dates are:\n",
"Based on the provided context, here are the procedures performed on the specimen SmartSPIM_662616_2023-03-06_17-47-13 and their start and end dates:\n",
"\n",
"Subject procedures:\n",
"1. Surgery on 2023-01-25 with virus injections (SL1-hSyn-Cre, AAV1-CAG-H2B-mTurquoise2-WPRE, AAV-Syn-DIO-TVA66T-dTomato-CVS N2cG)\n",
Expand All @@ -122,11 +124,11 @@
"Specimen procedures:\n",
"1. Fixation (SHIELD OFF) from 2023-02-10 to 2023-02-12\n",
"2. Fixation (SHIELD ON) from 2023-02-12 to 2023-02-13\n",
"3. Delipidation (24h Delipidation) from 2023-02-15 to 2023-02-16\n",
"3. Delipidation (24h Delipidation) from 2023-02-15 to 2023-02-16 \n",
"4. Delipidation (Active Delipidation) from 2023-02-16 to 2023-02-18\n",
"5. Refractive index matching (50% EasyIndex) from 2023-02-19 to 2023-02-20\n",
"6. Refractive index matching (100% EasyIndex) from 2023-02-20 to 2023-02-21\n",
"7. Imaging acquisition session from 2023-03-06T17:47:13 to 2023-03-06T22:59:16\n"
"7. Imaging on SmartSPIM1-1 from 2023-03-06T17:47:13 to 2023-03-06T22:59:16\n"
]
}
],
Expand Down
13 changes: 9 additions & 4 deletions src/metadata_chatbot/agents/agentic_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from langchain.agents import AgentExecutor, create_tool_calling_agent
from aind_data_access_api.document_db import MetadataDbClient
from typing_extensions import Annotated, TypedDict
from langgraph.prebuilt import create_react_agent

MODEL_ID_SONNET_3 = "anthropic.claude-3-sonnet-20240229-v1:0"
MODEL_ID_SONNET_3_5 = "anthropic.claude-3-sonnet-20240229-v1:0"
Expand All @@ -27,17 +28,17 @@
class RouteQuery(TypedDict):
"""Route a user query to the most relevant datasource."""

reasoning: Annotated[str, ..., "Give a one sentence justification for the chosen method"]
#reasoning: Annotated[str, ..., "Give a one sentence justification for the chosen method"]
datasource: Annotated[Literal["vectorstore", "direct_database"], ..., "Given a user question choose to route it to the direct database or its vectorstore."]

structured_llm_router = SONNET_3_LLM.with_structured_output(RouteQuery)
router_prompt = hub.pull("eden19/query_rerouter")
datasource_router = router_prompt | structured_llm_router

# Tool to survey entire database
API_GATEWAY_HOST = "api.allenneuraldynamics-test.org"
DATABASE = "metadata_vector_index"
COLLECTION = "curated_assets"
API_GATEWAY_HOST = "api.allenneuraldynamics.org"
DATABASE = "metadata_index"
COLLECTION = "data_assets"

docdb_api_client = MetadataDbClient(
host=API_GATEWAY_HOST,
Expand Down Expand Up @@ -69,7 +70,11 @@ def aggregation_retrieval(agg_pipeline: list) -> list:
return result

tools = [aggregation_retrieval]
tool_model = SONNET_3_5_LLM.bind_tools(tools)

db_prompt = hub.pull("eden19/entire_db_retrieval")
langgraph_agent_executor = create_react_agent(SONNET_3_LLM, tools=tools, state_modifier= db_prompt)

db_surveyor_agent = create_tool_calling_agent(SONNET_3_LLM, tools, db_prompt)
query_retriever = AgentExecutor(agent=db_surveyor_agent, tools=tools, return_intermediate_steps = True, verbose=False)

Expand Down
73 changes: 50 additions & 23 deletions src/metadata_chatbot/agents/async_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from typing_extensions import TypedDict
from langchain_core.documents import Document
from langgraph.graph import END, StateGraph, START
from docdb_retriever import DocDBRetriever
from metadata_chatbot.agents.docdb_retriever import DocDBRetriever

from metadata_chatbot.agents.agentic_graph import datasource_router, query_retriever, filter_generation_chain, doc_grader, rag_chain, db_rag_chain
from react_agent import react_agent
from langchain_core.messages.ai import AIMessage
from metadata_chatbot.agents.agentic_graph import datasource_router, filter_generation_chain, doc_grader, rag_chain, db_rag_chain

class GraphState(TypedDict):
"""
Expand All @@ -19,7 +21,7 @@ class GraphState(TypedDict):

query: str
generation: str
documents: List[str]
documents: Optional[List[str]]
filter: Optional[dict]
top_k: Optional[int]

Expand All @@ -41,9 +43,24 @@ async def route_question_async(state):
elif source['datasource'] == "vectorstore":
return "vectorstore"

def print_stream(stream):
    """Pretty-print each state emitted by a LangGraph stream and return the
    final AI answer.

    Args:
        stream: Iterable of graph states. Each state is a dict with a
            "messages" list of LangChain message objects or (role, text)
            tuples. Assumes every state has a non-empty "messages" list —
            TODO confirm against the graph's output contract.

    Returns:
        The ``content`` of the last ``AIMessage`` found in the final state,
        or ``None`` if the stream was empty or the final state contained no
        ``AIMessage``.
    """
    # Only the final state is needed afterwards; don't keep the whole history.
    last_state = None
    for state in stream:
        last_state = state
        message = state["messages"][-1]
        if isinstance(message, tuple):
            print(message)
        else:
            # LangChain message objects know how to render themselves.
            message.pretty_print()

    # Initialize so an empty stream / missing AIMessage returns None instead
    # of raising NameError or IndexError (bugs in the original version).
    final_answer = None
    if last_state is not None:
        for message in last_state["messages"]:
            if isinstance(message, AIMessage):
                # Keep overwriting so the *last* AI reply wins.
                final_answer = message.content
    return final_answer

async def retrieve_DB_async(state):
"""
Filter database
Retrieves from data asset collection in prod DB after constructing a MongoDB query

Args:
state (dict): The current graph state
Expand All @@ -53,14 +70,27 @@ async def retrieve_DB_async(state):
"""

query = state["query"]
inputs = {"messages": [("user", query)]}

generation = print_stream(react_agent.stream(inputs, stream_mode="values"))

# generation = react_agent.invoke(inputs)
# AIMessage_list = []
# for message in generation['messages']:
# if isinstance(message, AIMessage):
# AIMessage_list.append(message)

# final_answer = AIMessage_list[-1].content

# document_dict = dict()
# retrieved_dict = await query_retriever.ainvoke({'query': query, 'chat_history': [], 'agent_scratchpad' : []})
# document_dict['mongodb_query'] = retrieved_dict['intermediate_steps'][0][0].tool_input['agg_pipeline']
# document_dict['retrieved_output'] = retrieved_dict['intermediate_steps'][0][1]
# documents = await asyncio.to_thread(json.dumps, document_dict)


document_dict = dict()
retrieved_dict = await query_retriever.ainvoke({'query': query, 'chat_history': [], 'agent_scratchpad' : []})
document_dict['mongodb_query'] = retrieved_dict['intermediate_steps'][0][0].tool_input['agg_pipeline']
document_dict['retrieved_output'] = retrieved_dict['intermediate_steps'][0][1]
documents = await asyncio.to_thread(json.dumps, document_dict)

return {"query": query, "documents": documents}
return {"query": query, "generation": ''}

async def filter_generator_async(state):
"""
Expand Down Expand Up @@ -103,10 +133,7 @@ async def grade_doc_async(query, doc: Document):
grade = score['binary_score']

if grade == "yes":
relevant_context = score['relevant_context']
return relevant_context
else:
return None
return doc.page_content


async def grade_documents_async(state):
Expand Down Expand Up @@ -165,7 +192,7 @@ async def generate_VI_async(state):
async_workflow.add_node("filter_generation", filter_generator_async)
async_workflow.add_node("retrieve", retrieve_VI_async)
async_workflow.add_node("document_grading", grade_documents_async)
async_workflow.add_node("generate_db", generate_DB_async)
#async_workflow.add_node("generate_db", generate_DB_async)
async_workflow.add_node("generate_vi", generate_VI_async)

async_workflow.add_conditional_edges(
Expand All @@ -176,20 +203,20 @@ async def generate_VI_async(state):
"vectorstore": "filter_generation",
},
)
async_workflow.add_edge("database_query", "generate_db")
async_workflow.add_edge("generate_db", END)
async_workflow.add_edge("database_query", END)
#async_workflow.add_edge("generate_db", END)
async_workflow.add_edge("filter_generation", "retrieve")
async_workflow.add_edge("retrieve", "document_grading")
async_workflow.add_edge("document_grading","generate_vi")
async_workflow.add_edge("generate_vi", END)

async_app = async_workflow.compile()

# async def main():
# query = "Can you list all the procedures performed on the specimen, including their start and end dates? in SmartSPIM_662616_2023-03-06_17-47-13"
# inputs = {"query": query}
# answer = await async_app.ainvoke(inputs)
# return answer['generation']
async def main():
    """Run the async workflow once with a sample query and return the
    generated answer string."""
    question = "How many records are in the dataset?"
    final_state = await async_app.ainvoke({"query": question})
    return final_state['generation']

# #Run the async function
#Run the async function
# print(asyncio.run(main()))
Loading
Loading