Improve multi-turn capability for agent (#1248)
* first code for multi-turn

Signed-off-by: minmin-intel <[email protected]>

* test RedisPersistence

Signed-off-by: minmin-intel <[email protected]>

* integrate persistent store in react llama

Signed-off-by: minmin-intel <[email protected]>

* test multi-turn

Signed-off-by: minmin-intel <[email protected]>

* multi-turn for assistants api and chat completion api

Signed-off-by: minmin-intel <[email protected]>

* update readme and ut script

Signed-off-by: minmin-intel <[email protected]>

* update readme and ut scripts

Signed-off-by: minmin-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix bug

Signed-off-by: minmin-intel <[email protected]>

* change memory type naming

Signed-off-by: minmin-intel <[email protected]>

* fix with_memory as str

Signed-off-by: minmin-intel <[email protected]>

---------

Signed-off-by: minmin-intel <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
minmin-intel and pre-commit-ci[bot] authored Feb 14, 2025
1 parent 4a90692 commit 0e3f8ab
Showing 19 changed files with 813 additions and 335 deletions.
137 changes: 89 additions & 48 deletions comps/agent/src/README.md

Large diffs are not rendered by default.

67 changes: 41 additions & 26 deletions comps/agent/src/agent.py
@@ -18,7 +18,7 @@
from comps.agent.src.integrations.agent import instantiate_agent
from comps.agent.src.integrations.global_var import assistants_global_kv, threads_global_kv
from comps.agent.src.integrations.thread import instantiate_thread_memory, thread_completion_callback
-from comps.agent.src.integrations.utils import assemble_store_messages, get_args
+from comps.agent.src.integrations.utils import assemble_store_messages, get_args, get_latest_human_message_from_store
from comps.cores.proto.api_protocol import (
AssistantsObject,
ChatCompletionRequest,
@@ -40,7 +40,7 @@

logger.info("========initiating agent============")
logger.info(f"args: {args}")
-agent_inst = instantiate_agent(args, args.strategy, with_memory=args.with_memory)
+agent_inst = instantiate_agent(args)


class AgentCompletionRequest(ChatCompletionRequest):
@@ -76,7 +76,7 @@ async def llm_generate(input: AgentCompletionRequest):
if isinstance(input.messages, str):
messages = input.messages
else:
-        # TODO: need handle multi-turn messages
+        # last user message
messages = input.messages[-1]["content"]

# 2. prepare the input for the agent
@@ -90,7 +90,6 @@ async def llm_generate(input: AgentCompletionRequest):
else:
logger.info("-----------NOT STREAMING-------------")
response = await agent_inst.non_streaming_run(messages, config)
logger.info("-----------Response-------------")
return GeneratedDoc(text=response, prompt=messages)


@@ -100,14 +99,14 @@ class RedisConfig(BaseModel):

class AgentConfig(BaseModel):
stream: Optional[bool] = False
agent_name: Optional[str] = "OPEA_Default_Agent"
agent_name: Optional[str] = "OPEA_Agent"
strategy: Optional[str] = "react_llama"
role_description: Optional[str] = "LLM enhanced agent"
role_description: Optional[str] = "AI assistant"
tools: Optional[str] = None
recursion_limit: Optional[int] = 5

model: Optional[str] = "meta-llama/Meta-Llama-3-8B-Instruct"
llm_engine: Optional[str] = None
model: Optional[str] = "meta-llama/Llama-3.3-70B-Instruct"
llm_engine: Optional[str] = "vllm"
llm_endpoint_url: Optional[str] = None
max_new_tokens: Optional[int] = 1024
top_k: Optional[int] = 10
@@ -117,10 +116,14 @@ class AgentConfig(BaseModel):
return_full_text: Optional[bool] = False
custom_prompt: Optional[str] = None

-    # short/long term memory
-    with_memory: Optional[bool] = False
-    # persistence
-    with_store: Optional[bool] = False
+    # short/long term memory
+    with_memory: Optional[bool] = True
+    # agent memory config
+    # chat_completion api: only supports checkpointer memory
+    # assistants api: supports checkpointer and store memory
+    # checkpointer: in-memory checkpointer - MemorySaver()
+    # store: redis store
+    memory_type: Optional[str] = "checkpointer"  # choices: checkpointer, store
store_config: Optional[RedisConfig] = None

timeout: Optional[int] = 60
@@ -147,18 +150,17 @@ class CreateAssistant(CreateAssistantsRequest):
)
def create_assistants(input: CreateAssistant):
# 1. initialize the agent
-    agent_inst = instantiate_agent(
-        input.agent_config, input.agent_config.strategy, with_memory=input.agent_config.with_memory
-    )
+    print("@@@ Initializing agent with config: ", input.agent_config)
+    agent_inst = instantiate_agent(input.agent_config)
assistant_id = agent_inst.id
created_at = int(datetime.now().timestamp())
with assistants_global_kv as g_assistants:
g_assistants[assistant_id] = (agent_inst, created_at)
logger.info(f"Record assistant inst {assistant_id} in global KV")

-    if input.agent_config.with_store:
+    if input.agent_config.memory_type == "store":
logger.info("Save Agent Config to database")
-        agent_inst.with_store = input.agent_config.with_store
+        # agent_inst.memory_type = input.agent_config.memory_type
print(input)
global db_client
if db_client is None:
@@ -172,6 +174,7 @@ def create_assistants(input: CreateAssistant):
return AssistantsObject(
id=assistant_id,
created_at=created_at,
+        model=input.agent_config.model,
)


@@ -211,7 +214,7 @@ def create_messages(thread_id, input: CreateMessagesRequest):
if isinstance(input.content, str):
query = input.content
else:
-        query = input.content[-1]["text"]
+        query = input.content[-1]["text"]  # content is a list of MessageContent
msg_id, created_at = thread_inst.add_query(query)

structured_content = MessageContent(text=query)
@@ -224,15 +227,18 @@
assistant_id=input.assistant_id,
)

-    # save messages using assistant_id as key
+    # save messages using assistant_id_thread_id as key
if input.assistant_id is not None:
with assistants_global_kv as g_assistants:
agent_inst, _ = g_assistants[input.assistant_id]
-        if agent_inst.with_store:
-            logger.info(f"Save Agent Messages, assistant_id: {input.assistant_id}, thread_id: {thread_id}")
+        if agent_inst.memory_type == "store":
+            logger.info(f"Save Messages, assistant_id: {input.assistant_id}, thread_id: {thread_id}")
# if with store, db_client initialized already
global db_client
-            db_client.put(msg_id, message.model_dump_json(), input.assistant_id)
+            namespace = f"{input.assistant_id}_{thread_id}"
+            # put(key: str, val: dict, collection: str = DEFAULT_COLLECTION)
+            db_client.put(msg_id, message.model_dump_json(), namespace)
+            logger.info(f"@@@ Save message to db: {msg_id}, {message.model_dump_json()}, {namespace}")

return message

@@ -254,15 +260,24 @@ def create_run(thread_id, input: CreateRunResponse):
with assistants_global_kv as g_assistants:
agent_inst, _ = g_assistants[assistant_id]

config = {"recursion_limit": args.recursion_limit}
config = {
"recursion_limit": args.recursion_limit,
"configurable": {"session_id": thread_id, "thread_id": thread_id, "user_id": assistant_id},
}

-    if agent_inst.with_store:
-        # assemble multi-turn messages
+    if agent_inst.memory_type == "store":
global db_client
-        input_query = assemble_store_messages(db_client.get_all(assistant_id))
+        namespace = f"{assistant_id}_{thread_id}"
+        # get the latest human message from store in the namespace
+        input_query = get_latest_human_message_from_store(db_client, namespace)
+        print("@@@@ Input_query from store: ", input_query)
else:
input_query = thread_inst.get_query()
print("@@@@ Input_query from thread_inst: ", input_query)

print("@@@ Agent instance:")
print(agent_inst.id)
print(agent_inst.args)
try:
return StreamingResponse(
thread_completion_callback(agent_inst.stream_generator(input_query, config, thread_id), thread_id),
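Note: with the reworked AgentConfig above, multi-turn memory is now selected through with_memory plus memory_type instead of the old with_store flag. Below is a minimal sketch of creating an assistant that uses the Redis-backed "store" memory; the service URL, port, endpoint path, Redis URI, and tools path are assumptions for illustration, while the agent_config field names mirror the new schema in this diff.

import requests

# Hypothetical agent microservice URL; adjust to your deployment.
AGENT_URL = "http://localhost:9090"

agent_config = {
    "strategy": "react_llama",
    "with_memory": True,
    "memory_type": "store",                                   # "checkpointer" (in-memory) or "store" (Redis)
    "store_config": {"redis_uri": "redis://localhost:6379"},  # RedisConfig; URI value is an assumption
    "llm_engine": "vllm",
    "llm_endpoint_url": "http://localhost:8086",              # assumed vLLM endpoint
    "tools": "/home/user/tools/custom_tools.yaml",            # assumed tools file
}

# Endpoint path is an assumption based on the OpenAI-style Assistants API.
resp = requests.post(f"{AGENT_URL}/v1/assistants", json={"agent_config": agent_config})
assistant_id = resp.json()["id"]  # used together with thread_id as the store namespace
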
8 changes: 6 additions & 2 deletions comps/agent/src/integrations/agent.py
@@ -1,9 +1,13 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
+from .storage.persistence_redis import RedisPersistence
from .utils import load_python_prompt


-def instantiate_agent(args, strategy="react_langchain", with_memory=False):
+def instantiate_agent(args):
+    strategy = args.strategy
+    with_memory = args.with_memory

if args.custom_prompt is not None:
print(f">>>>>> custom_prompt enabled, {args.custom_prompt}")
custom_prompt = load_python_prompt(args.custom_prompt)
@@ -22,7 +26,7 @@ def instantiate_agent(args, strategy="react_langchain", with_memory=False):
print("Initializing ReAct Agent with LLAMA")
from .strategy.react import ReActAgentLlama

-        return ReActAgentLlama(args, with_memory, custom_prompt=custom_prompt)
+        return ReActAgentLlama(args, custom_prompt=custom_prompt)
elif strategy == "plan_execute":
from .strategy.planexec import PlanExecuteAgentWithLangGraph

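For reference, the simplified factory now reads strategy and memory settings off the config object itself. A rough sketch of the new call shape, using SimpleNamespace as a stand-in for the parsed args/AgentConfig (any field beyond those visible in this diff is an assumption):

from types import SimpleNamespace

from comps.agent.src.integrations.agent import instantiate_agent

# Stand-in config; strategy, with_memory, memory_type and custom_prompt are the
# attributes read directly by the factory and base agent in this diff.
args = SimpleNamespace(
    strategy="react_llama",
    with_memory=True,
    memory_type="checkpointer",
    custom_prompt=None,
    # ...plus the usual llm_engine / llm_endpoint_url / tools fields
)

agent_inst = instantiate_agent(args)  # was: instantiate_agent(args, strategy, with_memory=...)
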
26 changes: 20 additions & 6 deletions comps/agent/src/integrations/strategy/base_agent.py
@@ -3,6 +3,9 @@

from uuid import uuid4

+from langgraph.checkpoint.memory import MemorySaver
+
+from ..storage.persistence_redis import RedisPersistence
from ..tools import get_tools_descriptions
from ..utils import adapt_custom_prompt, setup_chat_model

@@ -12,11 +15,25 @@ def __init__(self, args, local_vars=None, **kwargs) -> None:
self.llm = setup_chat_model(args)
self.tools_descriptions = get_tools_descriptions(args.tools)
self.app = None
-        self.memory = None
self.id = f"assistant_{self.__class__.__name__}_{uuid4()}"
self.args = args
adapt_custom_prompt(local_vars, kwargs.get("custom_prompt"))
-        print(self.tools_descriptions)
+        print("Registered tools: ", self.tools_descriptions)

+        if args.with_memory:
+            if args.memory_type == "checkpointer":
+                self.memory_type = "checkpointer"
+                self.checkpointer = MemorySaver()
+                self.store = None
+            elif args.memory_type == "store":
+                # print("Using Redis as store: ", args.store_config.redis_uri)
+                self.store = RedisPersistence(args.store_config.redis_uri)
+                self.memory_type = "store"
+            else:
+                raise ValueError("Invalid memory type!")
+        else:
+            self.store = None
+            self.checkpointer = None

@property
def is_vllm(self):
@@ -60,10 +77,7 @@ async def non_streaming_run(self, query, config):
try:
async for s in self.app.astream(initial_state, config=config, stream_mode="values"):
message = s["messages"][-1]
-                if isinstance(message, tuple):
-                    print(message)
-                else:
-                    message.pretty_print()
+                message.pretty_print()

last_message = s["messages"][-1]
print("******Response: ", last_message.content)
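The checkpointer/store attributes set up in BaseAgent are meant to be handed to the LangGraph workflow compiled by each strategy class (e.g. ReActAgentLlama); that graph construction is not part of the hunks shown here. A minimal, self-contained sketch of the checkpointer path, assuming a standard LangGraph MessagesState graph:

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, MessagesState, StateGraph


def chat_node(state: MessagesState):
    # Placeholder node; the real agent calls the LLM and tools here.
    return {"messages": [("assistant", "ok")]}


builder = StateGraph(MessagesState)
builder.add_node("chat", chat_node)
builder.add_edge(START, "chat")
builder.add_edge("chat", END)

# Compiling with a checkpointer gives each thread_id its own conversation
# history across turns, matching the "configurable" keys passed by create_run.
app = builder.compile(checkpointer=MemorySaver())

config = {"configurable": {"thread_id": "thread_1"}}
app.invoke({"messages": [("user", "hi")]}, config=config)
app.invoke({"messages": [("user", "what did I just say?")]}, config=config)
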
(Diffs for the remaining 15 changed files are not rendered here.)
