Merge pull request #5 from AllenNeuralDynamics/persistence
Persistence/Memory/Efficiency
sreyakumar authored Dec 6, 2024
2 parents e6ed30a + a998250 commit 8e3844f
Showing 9 changed files with 266 additions and 147 deletions.
Binary file added GAMER_workflow.jpeg
47 changes: 0 additions & 47 deletions app.py

This file was deleted.

44 changes: 30 additions & 14 deletions src/metadata_chatbot/agents/GAMER.py
@@ -6,16 +6,18 @@

import logging, asyncio, uuid

from metadata_chatbot.agents.async_workflow import async_app
#from metadata_chatbot.agents.async_workflow import async_app
from metadata_chatbot.agents.workflow import app
from async_workflow import async_app

from langchain_core.messages import AIMessage, HumanMessage
from streamlit.runtime.scriptrunner import add_script_run_ctx

from typing import Optional, List, Any, AsyncIterator
from langchain.callbacks.manager import AsyncCallbackManager, CallbackManagerForLLMRun
import streamlit as st



class GAMER(LLM):

def _call(
@@ -52,6 +54,7 @@ async def _acall(
"""
Asynchronous call.
"""

async def main(query):

unique_id = str(uuid.uuid4())
@@ -101,7 +104,7 @@ async def streamlit_astream(
"""
Asynchronous call.
"""
async def main(query:str):
async def main(query:str, unique_id : str):
config = {"configurable":{"thread_id": unique_id}}
inputs = {
"messages": [HumanMessage(query)],
@@ -111,17 +114,29 @@ async def main(query:str):
if key != "database_query":
yield value['messages'][0].content
else:
for response in value['messages']:
print(response.content)
yield value['generation']


curr = None
prev = None
generation = None
async for result in main(query):
if curr != None:
st.write(curr)
curr = generation
generation = result
async for result in main(query, unique_id):
if prev != None:
print(prev)
prev = result
generation = prev
return generation

# curr = None
# generation = None
# async for result in main(query):
# if curr != None:
# st.write(curr)
# if "messages" in st.session_state:
# st.session_state.messages.append({"role": "assistant", "content": curr})
# curr = generation
# generation = result
# return generation



@@ -137,17 +152,18 @@ def _llm_type(self) -> str:
"""Get the type of language model used by this chat model. Used for logging purposes only."""
return "Claude 3 Sonnet"

# llm = GAMER()
llm = GAMER()

# async def main():
# query = "Can you list all the procedures performed on the specimen, including their start and end dates? in SmartSPIM_662616_2023-03-06_17-47-13"
# result = await llm.ainvoke(query)
# query = "How many records are in the database?"
# result = await llm.streamlit_astream(query, unique_id = "1")
# print(result)


# asyncio.run(main())

# async def main():
# result = await llm.ainvoke("Can you give me a timeline of events for subject 675387?")
# result = await llm.ainvoke("How many records are in the database?")
# print(result)

# asyncio.run(main())
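The rewritten streamlit_astream above uses a one-step look-ahead over the streamed workflow output: each chunk is held back until the next one arrives, so every chunk except the last gets printed, and the final chunk is returned as the answer instead of streamed. A minimal, self-contained sketch of that pattern (the chunk values are illustrative, not taken from the actual workflow):

import asyncio

async def lookahead(source):
    """Print every streamed chunk except the last; return the last one."""
    prev = None
    generation = None
    async for result in source:
        if prev is not None:
            print(prev)        # intermediate chunk: show it immediately
        prev = result
        generation = prev      # most recent chunk is the candidate answer
    return generation

async def demo():
    async def chunks():
        for c in ("Routing query...", "Querying DocDB...", "Here is the final answer."):
            yield c
    answer = await lookahead(chunks())
    print(f"Returned to caller: {answer}")

asyncio.run(demo())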
2 changes: 0 additions & 2 deletions src/metadata_chatbot/agents/__init__.py
@@ -1,3 +1 @@
"""Init package"""
__version__ = "0.0.12"

38 changes: 34 additions & 4 deletions src/metadata_chatbot/agents/agentic_graph.py
@@ -7,9 +7,12 @@
from aind_data_access_api.document_db import MetadataDbClient
from typing_extensions import Annotated, TypedDict
from langgraph.prebuilt import create_react_agent
from langchain_core.prompts import ChatPromptTemplate

MODEL_ID_SONNET_3 = "anthropic.claude-3-sonnet-20240229-v1:0"
MODEL_ID_SONNET_3_5 = "anthropic.claude-3-5-sonnet-20240620-v1:0"
MODEL_ID_HAIKU_3_5 = "anthropic.claude-3-5-haiku-20241022-v1:0"

SONNET_3_LLM = ChatBedrock(
model_id= MODEL_ID_SONNET_3,
model_kwargs= {
@@ -26,17 +29,39 @@
streaming = True
)

HAIKU_3_5_LLM = ChatBedrock(
model_id= MODEL_ID_HAIKU_3_5,
model_kwargs= {
"temperature": 0
},
streaming = True
)
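The new Haiku model is wired in below for routing, query rewriting, and filter generation; as a quick sanity check it can also be invoked directly (assumes AWS Bedrock credentials are configured):

# Hypothetical smoke test of the model defined above.
print(HAIKU_3_5_LLM.invoke("Reply with the single word OK.").content)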

# Determining if entire database needs to be surveyed
class RouteQuery(TypedDict):
"""Route a user query to the most relevant datasource."""

#reasoning: Annotated[str, ..., "Give a one sentence justification for the chosen method"]
datasource: Annotated[Literal["vectorstore", "direct_database"], ..., "Given a user question choose to route it to the direct database or its vectorstore."]
datasource: Annotated[Literal["vectorstore", "direct_database", "claude"],
...,
"Given a user question choose to route it to the direct database or its vectorstore. If a question can be answered without retrieval, route to claude"]

structured_llm_router = SONNET_3_LLM.with_structured_output(RouteQuery)
structured_llm_router = HAIKU_3_5_LLM.with_structured_output(RouteQuery)
router_prompt = hub.pull("eden19/query_rerouter")
datasource_router = router_prompt | structured_llm_router
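A hedged sketch of how this router could be called downstream; the input key "query" is an assumption, since the hub prompt "eden19/query_rerouter" is not shown in this diff:

# Hypothetical usage of datasource_router defined above.
route = datasource_router.invoke({"query": "How many records are in the database?"})
if route["datasource"] == "direct_database":
    pass  # aggregate directly against DocDB
elif route["datasource"] == "vectorstore":
    pass  # filtered vector retrieval
else:
    pass  # "claude": answer from the model and chat history alone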

# Check if retrieved documents answer question
class QueryRewriter(TypedDict):
"""Rewrite ambiguous queries"""

#relevant_context:Annotated[str, ..., "Relevant context extracted from document that helps directly answer the question"]
binary_score: Annotated[Literal["yes", "no"], ..., "Query is ambiguous, 'yes' or 'no'"]
rewritten_query: Annotated[str, ..., "user's query, rewritten to be more specific"]

query_rewriter = HAIKU_3_5_LLM.with_structured_output(QueryRewriter)
query_rewriter_prompt = hub.pull("eden19/query_rewriter")
query_rewriter_chain = query_rewriter_prompt | query_rewriter
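A sketch of the intended rewrite flow, under the same caveat (the input key "query" is assumed, as the hub prompt "eden19/query_rewriter" is not shown here):

# Hypothetical usage of query_rewriter_chain defined above.
result = query_rewriter_chain.invoke({"query": "What about the other subject?"})
if result["binary_score"] == "yes":       # query judged ambiguous
    query = result["rewritten_query"]     # continue with the specific rewrite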

# Generating appropriate filter
class FilterGenerator(TypedDict):
"""MongoDB filter to be applied before vector retrieval"""
@@ -45,7 +70,7 @@ class FilterGenerator(TypedDict):
top_k: int = Annotated[dict, ..., "MongoDB filter"]

filter_prompt = hub.pull("eden19/filtergeneration")
filter_generator_llm = SONNET_3_LLM.with_structured_output(FilterGenerator)
filter_generator_llm = HAIKU_3_5_LLM.with_structured_output(FilterGenerator)
filter_generation_chain = filter_prompt | filter_generator_llm
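For illustration only, a possible invocation of the filter chain; the input key is assumed, and part of the FilterGenerator definition is collapsed in this diff, so fields beyond top_k are not shown:

# Hypothetical usage of filter_generation_chain defined above.
result = filter_generation_chain.invoke(
    {"query": "procedures for SmartSPIM_662616_2023-03-06_17-47-13"}
)
top_k = result["top_k"]  # only this field is visible in the diff above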


@@ -66,5 +91,10 @@ class RetrievalGrader(TypedDict):

# Generating response to documents retrieved from the database
db_answer_generation_prompt = hub.pull("eden19/db_answergeneration")
db_rag_chain = db_answer_generation_prompt | SONNET_3_5_LLM | StrOutputParser()

# Generating response from previous context
prompt = ChatPromptTemplate.from_template("Answer {query} based on the following texts: {chat_history}")
prev_context_chain = prompt | HAIKU_3_5_LLM | StrOutputParser()
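Both input variables of this chain come from the template above, so a usage example is straightforward; the history string is a stand-in for the serialized message list kept by the app:

# Example invocation of prev_context_chain defined above.
answer = prev_context_chain.invoke({
    "query": "And when did that procedure end?",
    "chat_history": "user: list the procedures for 662616 / assistant: ...",
})
print(answer)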


95 changes: 95 additions & 0 deletions src/metadata_chatbot/agents/app.py
@@ -0,0 +1,95 @@
# Import the Streamlit library
import streamlit as st
import asyncio
import uuid

# import sys
# import os
# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from async_workflow import async_workflow
from react_agent import astream_input

from langchain_core.messages import HumanMessage, AIMessage

import warnings
warnings.filterwarnings('ignore')

#run on terminal with streamlit run c:/Users/sreya.kumar/Documents/GitHub/metadata-chatbot/src/metadata_chatbot/agents/app.py [ARGUMENTS]

unique_id = str(uuid.uuid4())

async def main():
st.title("GAMER: Generative Analysis of Metadata Retrieval")

message = st.chat_message("assistant")
message.write("Hello! How can I help you?")

query = st.chat_input("Ask a question about the AIND Metadata!")

if "messages" not in st.session_state:
st.session_state.messages = []

model = async_workflow.compile()

for message in st.session_state.messages:
if isinstance(message, HumanMessage):
with st.chat_message("user"):
st.markdown(message.content)
else:
with st.chat_message("assistant"):
st.markdown(message.content)

if query is not None and query != '':
st.session_state.messages.append(HumanMessage(query))

with st.chat_message("user"):
st.markdown(query)

with st.chat_message("assistant"):
async def main(query: str):
chat_history = st.session_state.messages
#config = {"configurable":{"thread_id": st.session_state.unique_id}}
inputs = {
"messages": chat_history,
}
async for output in model.astream(inputs):
for key, value in output.items():
if key != "database_query":
yield value['messages'][0].content
else:
try:
query = str(chat_history) + query
async for result in astream_input(query = query):
response = result['type']
if response == 'intermediate_steps':
yield result['content']
if response == 'agg_pipeline':
yield f"The MongoDB pipeline used to on the database is: {result['content']}"
if response == 'tool_response':
yield f"Retrieved output from MongoDB: {result['content']}"
if response == 'final_answer':
yield result['content']

except Exception as e:
yield f"An error has occured with the retrieval from DocDB: {e}. Try structuring your query another way."
# for response in value['messages']:
# yield response.content
# yield value['generation']

prev = None
generation = None
async for result in main(query):
if prev != None:
st.markdown(prev)
prev = result
generation = prev
st.markdown(generation)
st.session_state.messages.append(AIMessage(generation))
# response = await llm.streamlit_astream(query, unique_id = unique_id)
# st.markdown(response)



if __name__ == "__main__":
asyncio.run(main())
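The app depends on Streamlit's rerun model: the whole script re-executes on every interaction, so conversation memory has to live in st.session_state rather than in local variables. A minimal, self-contained sketch of that persistence pattern, separate from the GAMER workflow (the placeholder reply stands in for the model call):

import streamlit as st
from langchain_core.messages import HumanMessage, AIMessage

if "messages" not in st.session_state:
    st.session_state.messages = []           # survives reruns

for message in st.session_state.messages:    # replay the stored history
    role = "user" if isinstance(message, HumanMessage) else "assistant"
    with st.chat_message(role):
        st.markdown(message.content)

query = st.chat_input("Ask a question about the AIND Metadata!")
if query:
    st.session_state.messages.append(HumanMessage(query))
    with st.chat_message("user"):
        st.markdown(query)
    reply = "..."                            # placeholder for the model call
    with st.chat_message("assistant"):
        st.markdown(reply)
    st.session_state.messages.append(AIMessage(reply))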