
Commit

feat:support llama agent
imotai committed Nov 5, 2023
1 parent c19bfd3 commit c1e0de3
Showing 7 changed files with 93 additions and 157 deletions.
1 change: 0 additions & 1 deletion agent/src/og_agent/agent_builder.py
@@ -5,7 +5,6 @@

""" """
import json
from .prompt import OCTOGEN_FUNCTION_SYSTEM, OCTOGEN_CODELLAMA_SYSTEM
from .llama_agent import LlamaAgent
from .openai_agent import OpenaiAgent
from .llama_client import LlamaClient
30 changes: 21 additions & 9 deletions agent/src/og_agent/base_agent.py
@@ -61,7 +61,7 @@ class TypingState:
CODE = 2
LANGUAGE = 3
MESSAGE = 4

OTHER = 5

class BaseAgent:

@@ -70,7 +70,7 @@ def __init__(self, sdk):
self.model_name = ""
self.agent_memories = {}

def create_new_memory_with_default_prompt(self, user_name, user_id):
def create_new_memory_with_default_prompt(self, user_name, user_id, actions = ACTIONS):
"""
create a new memory for the user
"""
@@ -79,7 +79,7 @@ def create_new_memory_with_default_prompt(self, user_name, user_id):
agent_prompt = AgentPrompt(
role=ROLE,
rules=RULES,
actions=ACTIONS,
actions=actions,
output_format=OUTPUT_FORMAT,
)
agent_memory = MemoryAgentMemory(memory_id, user_name, user_id)
@@ -127,7 +127,7 @@ def _parse_arguments(
parse the partial key with string value from json
"""
if is_code:
return TypingState.CODE, "", arguments, "python"
return TypingState.CODE, "", arguments, "python", ""
state = TypingState.START
explanation_str = ""
code_str = ""
@@ -142,6 +142,15 @@
if state == TypingState.CODE and token[0] == 1:
code_str = token[1]
state = TypingState.START
if state == TypingState.LANGUAGE and token[0] == 1:
language_str = token[1]
state = TypingState.START
if state == TypingState.MESSAGE and token[0] == 1:
message_str = token[1]
state = TypingState.START
if state == TypingState.OTHER and token[0] == 1:
state = TypingState.START

if token[1] == "explanation":
state = TypingState.EXPLANATION
if token[1] == "code":
@@ -150,6 +159,8 @@
state = TypingState.LANGUAGE
if token[1] == "message":
state = TypingState.MESSAGE
if token[1] == "saved_filenames":
state = TypingState.OTHER
else:
# String
if token_state == 9 and state == TypingState.EXPLANATION:
@@ -210,10 +221,8 @@ async def _read_json_message(
task_context,
task_opt,
):
arguments = message.get("content", "")
typing_language = "text"
return await self._send_typing_message(
arguments,
message.get("content", ""),
queue,
old_text_content,
old_code_content,
@@ -229,8 +238,8 @@ async def _send_typing_message(
queue,
old_text_content,
old_code_content,
old_language_str,
old_message_str,
old_language_str,
task_context,
task_opt,
is_code=False,
@@ -247,8 +256,9 @@
) = self._parse_arguments(arguments, is_code)

logger.debug(
f"argument explanation:{explanation_str} code:{code_str} language_str:{language_str} text_content:{old_text_content}"
f"argument explanation:{explanation_str} code:{code_str} language_str:{language_str} text_content:{old_text_content} old_message_str:{old_message_str}"
)

if explanation_str and old_text_content != explanation_str:
typed_chars = explanation_str[len(old_text_content) :]
new_text_content = explanation_str
@@ -301,6 +311,7 @@ async def _send_typing_message(
context_id=task_context.context_id,
)
)

return old_text_content, old_code_content, old_language_str, message_str
return old_text_content, old_code_content, old_language_str, old_message_str

@@ -375,6 +386,7 @@ async def extract_message(
response_token_count + context_output_token_count
)
if is_json_format:

(
new_text_content,
new_code_content,
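
Note: the base_agent.py changes above extend the streaming argument parser to track the language, message, and saved_filenames keys in addition to explanation and code, and _parse_arguments now returns a fifth element for the partial message string. Below is a minimal, non-streaming sketch of the kind of arguments payload being pulled apart; the field set is inferred from the keys handled above, and the real implementation parses partial JSON token-by-token via the project's tokenize helper rather than waiting for a complete object.

import json

# Hypothetical complete "arguments" payload; the streaming parser receives it
# incrementally and surfaces partial values for each key as they arrive.
arguments = json.dumps({
    "explanation": "Count the lines in data.csv",
    "code": "wc -l data.csv",
    "language": "bash",
    "message": "",
    "saved_filenames": [],
})

def parse_complete_arguments(raw):
    # Non-streaming stand-in for _parse_arguments: return the same
    # (explanation, code, language, message) fields once the JSON is complete.
    obj = json.loads(raw)
    return (
        obj.get("explanation", ""),
        obj.get("code", ""),
        obj.get("language", ""),
        obj.get("message", ""),
    )

print(parse_complete_arguments(arguments))
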
105 changes: 57 additions & 48 deletions agent/src/og_agent/llama_agent.py
@@ -13,8 +13,9 @@
from .llama_client import LlamaClient
from og_proto.agent_server_pb2 import OnStepActionStart, TaskResponse, OnStepActionEnd, FinalAnswer, TypingContent
from .base_agent import BaseAgent, TypingState, TaskContext
from og_memory.memory import AgentMemoryOption
from .prompt import FUNCTION_DIRECT_MESSAGE, FUNCTION_EXECUTE
from .tokenizer import tokenize
from .prompt import OCTOGEN_CODELLAMA_SYSTEM
import tiktoken

logger = logging.getLogger(__name__)
@@ -26,12 +27,16 @@ class LlamaAgent(BaseAgent):
def __init__(self, client, kernel_sdk):
super().__init__(kernel_sdk)
self.client = client
self.memory_option = AgentMemoryOption(
show_function_instruction=True, disable_output_format=False
)

def _output_exception(self):
return (
"Sorry, the LLM did return nothing, You can use a better performance model"
)


def _format_output(self, json_response):
"""
format the response and send it to the user
@@ -94,7 +99,7 @@ async def handle_bash_code(
state=task_context.to_context_state_proto(),
response_type=TaskResponse.OnStepActionStart,
on_step_action_start=OnStepActionStart(
input=tool_input, tool="execute_bash_code"
input=tool_input, tool="execute"
),
)
)
@@ -108,7 +113,7 @@
await queue.put(respond)
return function_result

async def handle_function(
async def handle_python_function(
self, json_response, queue, context, task_context, task_opt
):
code = json_response["code"]
@@ -125,7 +130,7 @@ def handle_function(
state=task_context.to_context_state_proto(),
response_type=TaskResponse.OnStepActionStart,
on_step_action_start=OnStepActionStart(
input=tool_input, tool=json_response["action"]
input=tool_input, tool='execute'
),
)
)
@@ -139,18 +144,19 @@
await queue.put(respond)
return function_result

async def call_llama(self, messages, queue, context, task_context, task_opt):
async def call_llama(self, agent_memory, queue, context, task_context, task_opt):
"""
call llama api
"""
input_token_count = 0
messages = agent_memory.to_messages()
for message in messages:
if not message["content"]:
continue
input_token_count += len(encoding.encode(message["content"]))
task_context.input_token_count += input_token_count
start_time = time.time()
response = self.client.chat(messages, "codellama", max_tokens=2048)
response = self.client.chat(messages, "llama", max_tokens=2048)
message = await self.extract_message(
response,
queue,
@@ -162,19 +168,39 @@ async def call_llama(self, messages, queue, context, task_context, task_opt):
)
return message

async def arun(self, question, queue, context, task_opt):
async def arun(self, request, queue, context, task_opt):
"""
run the agent
"""
messages = [
{"role": "system", "content": OCTOGEN_CODELLAMA_SYSTEM},
question = request.task
context_id = (
request.context_id
if request.context_id
else self.create_new_memory_with_default_prompt("", "", actions=[FUNCTION_EXECUTE,
FUNCTION_DIRECT_MESSAGE])
)

if context_id not in self.agent_memories:
await queue.put(
TaskResponse(
state=task_context.to_context_state_proto(),
response_type=TaskResponse.OnSystemError,
error_msg="invalid context id",
context_id=context_id,
)
)
return

agent_memory = self.agent_memories[context_id]
agent_memory.update_options(self.memory_option)
agent_memory.append_chat_message(
{"role": "user", "content": question},
]
)
task_context = TaskContext(
start_time=time.time(),
output_token_count=0,
input_token_count=0,
llm_name="codellama",
llm_name="llama",
llm_respond_duration=0,
)
try:
@@ -198,7 +224,7 @@ async def arun(self, question, queue, context, task_opt):
)
break
message = await self.call_llama(
messages,
agent_memory,
queue,
context,
task_context,
@@ -225,21 +251,20 @@ async def arun(self, question, queue, context, task_opt):
)
)
break

logger.debug(f" llama response {json_response}")
if (
json_response["action"]
in ["execute_python_code", "execute_bash_code"]
and json_response["code"]
'function_call' in json_response and json_response["function_call"] == "execute"
):
messages.append(message)
agent_memory.append_chat_message(message)
tools_mapping = {
"execute_python_code": self.handle_function,
"execute_bash_code": self.handle_bash_code,
"python": self.handle_python_function,
"bash": self.handle_bash_code,
}
function_result = await tools_mapping[json_response["action"]](
json_response, queue, context, task_context, task_opt

function_result = await tools_mapping[json_response["arguments"]['language']](
json_response['arguments'], queue, context, task_context, task_opt
)

logger.debug(f"the function result {function_result}")
await queue.put(
TaskResponse(
@@ -255,52 +280,36 @@
),
)
)

action_output = "the output of %s:" % json_response["action"]
action_output = "the output of %s:" % json_response["function_call"]
current_question = "Give me the final answer summary if the above output of action meets the goal Otherwise try a new step"
if function_result.has_result:
messages.append({
agent_memory.append_chat_message({
"role": "user",
"content": f"{action_output} \n {function_result.console_stdout}",
})
messages.append({"role": "user", "content": current_question})
agent_memory.append_chat_message({"role": "user", "content": current_question})
elif function_result.has_error:
messages.append({
agent_memory.append_chat_message({
"role": "user",
"content": f"{action_output} \n {function_result.console_stderr}",
})
current_question = f"Generate a new step to fix the above error"
messages.append({"role": "user", "content": current_question})
agent_memory.append_chat_message({"role": "user", "content": current_question})
else:
messages.append({
agent_memory.append_chat_message({
"role": "user",
"content": f"{action_output} \n {function_result.console_stdout}",
})
messages.append({"role": "user", "content": current_question})
elif (
json_response["action"] == "show_sample_code"
and json_response["code"]
):
await self.handle_show_sample_code(
json_response, queue, context, task_context
)
result = self._format_output(json_response)
await queue.put(
TaskResponse(
state=task_context.to_context_state_proto(),
response_type=TaskResponse.OnFinalAnswer,
final_answer=FinalAnswer(answer=result),
)
)
break
else:
result = self._format_output(json_response)
agent_memory.append_chat_message({
"role": "user", "content": current_question})
elif 'function_call' in json_response and json_response["function_call"] == "direct_message":
message = json_response['arguments']['message']
await queue.put(
TaskResponse(
state=task_context.to_context_state_proto(),
response_type=TaskResponse.OnFinalAnswer,
final_answer=FinalAnswer(
answer=result if not task_opt.streaming else ""
answer=message if not task_opt.streaming else ""
),
)
)
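
Note: the rewritten arun loop keys off a function_call field instead of the old action field and reads conversation state from agent memory. A rough sketch of the two response shapes the agent now handles, with field names taken from the diff above (the authoritative schema lives in the FUNCTION_EXECUTE and FUNCTION_DIRECT_MESSAGE prompts and may carry additional fields):

# Hypothetical "execute" response: routed to handle_python_function or
# handle_bash_code based on arguments["language"].
execute_response = {
    "function_call": "execute",
    "arguments": {
        "explanation": "List the files in the workspace",
        "language": "bash",  # "python" or "bash"
        "code": "ls -la",
    },
}

# Hypothetical "direct_message" response: its message becomes the final answer.
direct_message_response = {
    "function_call": "direct_message",
    "arguments": {"message": "There are 3 files in the workspace."},
}

def dispatch(json_response, tools_mapping):
    # Simplified version of the dispatch in arun, without queues or task context.
    if json_response.get("function_call") == "execute":
        language = json_response["arguments"]["language"]
        return tools_mapping[language](json_response["arguments"])
    if json_response.get("function_call") == "direct_message":
        return json_response["arguments"]["message"]
    return None

print(dispatch(direct_message_response, {}))
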
1 change: 1 addition & 0 deletions agent/src/og_agent/llama_client.py
@@ -37,6 +37,7 @@ async def chat(self, messages, model, temperature=0, max_tokens=1024, stop=[]):
continue
try:
content = line[6:]
logger.debug(f"llama response content: {content}")
message = json.loads(content)
yield message
except Exception as e:
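
Note: the one-line addition above logs the decoded payload of each streamed line. For reference, line[6:] presumably strips a server-sent-events style "data: " prefix (six characters) before the JSON is parsed; a standalone sketch of that step, under this assumption:

import json
import logging

logger = logging.getLogger(__name__)

def parse_stream_line(line):
    # Decode one streamed line of the form 'data: {...}'; return None for
    # keep-alives or lines that are not valid JSON.
    if not line.startswith("data: "):
        return None
    content = line[6:]  # len("data: ") == 6
    logger.debug(f"llama response content: {content}")
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        return None

print(parse_stream_line('data: {"content": "hello"}'))
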
4 changes: 2 additions & 2 deletions agent/src/og_agent/openai_agent.py
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: 2023 imotai <[email protected]>
# SPDX-FileCopyrightText: 2023 ghf5t565698```\\\\\\\\\-=[-[9oi86y53e12motai <[email protected]>
# SPDX-FileContributor: imotai
#
# SPDX-License-Identifier: Elastic-2.0
@@ -29,7 +29,7 @@ def __init__(self, model, sdk, is_azure=True):
self.is_azure = is_azure
self.model_name = model if not is_azure else ""
self.memory_option = AgentMemoryOption(
show_function_instruction=False, disable_output_forat=True
show_function_instruction=False, disable_output_format=True
)

async def call_openai(self, agent_memory, queue, context, task_context, task_opt):

