From 9890c142b54e82c465c382911d9b07dbdebdd08c Mon Sep 17 00:00:00 2001 From: imotai Date: Wed, 1 Nov 2023 00:06:02 +0800 Subject: [PATCH] feat: openai agent passed --- agent/src/og_agent/agent_builder.py | 4 +- agent/src/og_agent/agent_server.py | 7 +- agent/src/og_agent/base_agent.py | 103 +++++++++++-------- agent/src/og_agent/openai_agent.py | 114 +++++----------------- agent/src/og_agent/prompt.py | 5 +- agent/tests/openai_agent_tests.py | 36 ++++--- chat/src/og_terminal/terminal_chat.py | 105 ++++++++++---------- memory/src/og_memory/memory.py | 17 +++- memory/src/og_memory/template/agent.jinja | 2 +- sdk/src/og_sdk/agent_sdk.py | 68 ++++++++++++- 10 files changed, 255 insertions(+), 206 deletions(-) diff --git a/agent/src/og_agent/agent_builder.py b/agent/src/og_agent/agent_builder.py index f6a1f79..9589976 100644 --- a/agent/src/og_agent/agent_builder.py +++ b/agent/src/og_agent/agent_builder.py @@ -18,9 +18,7 @@ def build_llama_agent(endpoint, key, sdk, grammer_path): """ with open(grammer_path, "r") as fd: grammar = fd.read() - client = LlamaClient(endpoint, key, grammar) - # init the agent return LlamaAgent(client, sdk) @@ -30,7 +28,7 @@ def build_openai_agent(sdk, model_name, is_azure=True): # TODO a data dir per user # init the agent - agent = OpenaiAgent(model_name, OCTOGEN_FUNCTION_SYSTEM, sdk, is_azure=is_azure) + agent = OpenaiAgent(model_name, sdk, is_azure=is_azure) return agent diff --git a/agent/src/og_agent/agent_server.py b/agent/src/og_agent/agent_server.py index 70c99a9..41237da 100644 --- a/agent/src/og_agent/agent_server.py +++ b/agent/src/og_agent/agent_server.py @@ -146,8 +146,8 @@ async def process_task( sdk = self.agents[metadata["api_key"]]["sdk"] queue = asyncio.Queue() - async def worker(task, agent, queue, context, task_opt): - return await agent.arun(task, queue, context, task_opt) + async def worker(request, agent, queue, context, task_opt): + return await agent.arun(request, queue, context, task_opt) options = ( request.options @@ -166,8 +166,9 @@ async def worker(task, agent, queue, context, task_opt): timeout=10, ) ) + logger.debug("create the agent task") - task = asyncio.create_task(worker(request.task, agent, queue, context, options)) + task = asyncio.create_task(worker(request, agent, queue, context, options)) while True: logger.debug("start wait the queue message") # TODO add timeout diff --git a/agent/src/og_agent/base_agent.py b/agent/src/og_agent/base_agent.py index 346a718..5a79b6e 100644 --- a/agent/src/og_agent/base_agent.py +++ b/agent/src/og_agent/base_agent.py @@ -34,6 +34,7 @@ class FunctionResult(BaseModel): has_result: bool = False has_error: bool = False + class TaskContext(BaseModel): start_time: float = 0 output_token_count: int = 0 @@ -41,6 +42,7 @@ class TaskContext(BaseModel): llm_name: str = "" llm_response_duration: int = 0 context_id: str = "" + def to_context_state_proto(self): # in ms total_duration = int((time.time() - self.start_time) * 1000) @@ -52,10 +54,12 @@ def to_context_state_proto(self): llm_response_duration=self.llm_response_duration, ) + class TypingState: START = 0 EXPLANATION = 1 CODE = 2 + LANGUAGE = 3 class BaseAgent: @@ -117,17 +121,16 @@ def _parse_arguments( self, arguments, is_code=False, - first_field_name="explanation", - second_field_name="code", ): """ parse the partial key with string value from json """ if is_code: - return TypingState.CODE, "", arguments + return TypingState.CODE, "", arguments, "python" state = TypingState.START explanation_str = "" code_str = "" + language_str = "" 
logger.debug(f"the arguments {arguments}") for token_state, token in tokenize(io.StringIO(arguments)): if token_state == None: @@ -137,17 +140,21 @@ def _parse_arguments( if state == TypingState.CODE and token[0] == 1: code_str = token[1] state = TypingState.START - if token[1] == first_field_name: + if token[1] == "explanation": state = TypingState.EXPLANATION - if token[1] == second_field_name: + if token[1] == "code": state = TypingState.CODE + if token[1] == "language": + state = TypingState.LANGUAGE else: # String if token_state == 9 and state == TypingState.EXPLANATION: explanation_str = "".join(token) elif token_state == 9 and state == TypingState.CODE: code_str = "".join(token) - return (state, explanation_str, code_str) + elif token_state == 9 and state == TypingState.LANGUAGE: + language_str = "".join(token) + return (state, explanation_str, code_str, language_str) def _get_message_token_count(self, message): response_token_count = 0 @@ -159,16 +166,16 @@ def _get_message_token_count(self, message): return response_token_count async def _read_function_call_message( - self, message, queue, old_text_content, old_code_content, task_context, task_opt + self, + message, + queue, + old_text_content, + old_code_content, + language_str, + task_context, + task_opt, ): typing_language = "text" - if message["function_call"].get("name", "") in [ - "execute_python_code", - "python", - ]: - typing_language = "python" - elif message["function_call"].get("name", "") == "execute_bash_code": - typing_language = "bash" is_code = False if message["function_call"].get("name", "") == "python": is_code = True @@ -178,28 +185,30 @@ async def _read_function_call_message( queue, old_text_content, old_code_content, - typing_language, + language_str, task_context, task_opt, is_code=is_code, ) async def _read_json_message( - self, message, queue, old_text_content, old_code_content, task_context, task_opt + self, + message, + queue, + old_text_content, + old_code_content, + old_language_str, + task_context, + task_opt, ): arguments = message.get("content", "") typing_language = "text" - if arguments.find("execute_python_code") >= 0: - typing_language = "python" - elif arguments.find("execute_bash_code") >= 0: - typing_language = "bash" - return await self._send_typing_message( arguments, queue, old_text_content, old_code_content, - typing_language, + old_language_str, task_context, task_opt, ) @@ -210,7 +219,7 @@ async def _send_typing_message( queue, old_text_content, old_code_content, - language, + old_language_str, task_context, task_opt, is_code=False, @@ -218,10 +227,11 @@ async def _send_typing_message( """ send the typing message to the client """ - task_opt = request - (state, explanation_str, code_str) = self._parse_arguments(arguments, is_code) + (state, explanation_str, code_str, language_str) = self._parse_arguments( + arguments, is_code + ) logger.debug( - f"argument explanation:{explanation_str} code:{code_str} text_content:{old_text_content}" + f"argument explanation:{explanation_str} code:{code_str} language_str:{language_str} text_content:{old_text_content}" ) if explanation_str and old_text_content != explanation_str: typed_chars = explanation_str[len(old_text_content) :] @@ -231,10 +241,10 @@ async def _send_typing_message( state=task_context.to_context_state_proto(), response_type=TaskResponse.OnModelTypeText, typing_content=TypingContent(content=typed_chars, language="text"), - context_id=task_context.context_id + context_id=task_context.context_id, ) await queue.put(task_response) - 
return new_text_content, old_code_content + return new_text_content, old_code_content, old_language_str if code_str and old_code_content != code_str: typed_chars = code_str[len(old_code_content) :] code_content = code_str @@ -244,13 +254,25 @@ async def _send_typing_message( state=task_context.to_context_state_proto(), response_type=TaskResponse.OnModelTypeCode, typing_content=TypingContent( - content=typed_chars, language=language + content=typed_chars, language="text" ), - context_id=task_context.context_id + context_id=task_context.context_id, + ) + ) + return old_text_content, code_content, old_language_str + if language_str and old_language_str != language_str: + typed_chars = language_str[len(old_language_str) :] + if task_opt.streaming and len(typed_chars) > 0: + await queue.put( + TaskResponse( + state=task_context.to_context_state_proto(), + response_type=TaskResponse.OnModelTypeCode, + typing_content=TypingContent(content="", language=language_str), + context_id=task_context.context_id, ) ) - return old_text_content, code_content - return old_text_content, old_code_content + return old_text_content, old_code_content, language_str + return old_text_content, old_code_content, old_language_str async def extract_message( self, @@ -258,17 +280,17 @@ async def extract_message( queue, rpc_context, task_context, - request, + task_opt, start_time, is_json_format=False, ): """ extract the chunk from the response generator """ - task_opt = request.options message = {} text_content = "" code_content = "" + language_str = "" context_output_token_count = task_context.output_token_count start_time = time.time() async for chunk in response_generator: @@ -294,11 +316,13 @@ async def extract_message( ( new_text_content, new_code_content, + new_language_str, ) = await self._read_function_call_message( message, queue, text_content, code_content, + language_str, task_context, task_opt, ) @@ -324,6 +348,7 @@ async def extract_message( queue, text_content, code_content, + language_str, task_context, task_opt, ) @@ -338,7 +363,7 @@ async def extract_message( typing_content=TypingContent( content=delta["content"], language="text" ), - context_id=task_context.context_id + context_id=task_context.context_id, ) ) logger.info( @@ -375,7 +400,7 @@ async def call_function(self, code, context, task_context): state=task_context.to_context_state_proto(), response_type=TaskResponse.OnStepActionStreamStdout, console_stdout=kernel_output, - context_id=task_context.context_id + context_id=task_context.context_id, ), ) # process the stderr @@ -390,7 +415,7 @@ async def call_function(self, code, context, task_context): state=task_context.to_context_state_proto(), response_type=TaskResponse.OnStepActionStreamStderr, console_stderr=kernel_err, - context_id=task_context.context_id + context_id=task_context.context_id, ), ) elif kernel_respond.output_type == ExecuteResponse.TracebackType: @@ -404,7 +429,7 @@ async def call_function(self, code, context, task_context): state=task_context.to_context_state_proto(), response_type=TaskResponse.OnStepActionStreamStderr, console_stderr=traceback, - context_id=task_context.context_id + context_id=task_context.context_id, ), ) else: @@ -428,7 +453,7 @@ async def call_function(self, code, context, task_context): state=task_context.to_context_state_proto(), response_type=TaskResponse.OnStepActionStreamStdout, console_stdout=console_stdout, - context_id=task_context.context_id + context_id=task_context.context_id, ), ) output_files = [] diff --git a/agent/src/og_agent/openai_agent.py 
b/agent/src/og_agent/openai_agent.py index 550e18e..e3cc941 100644 --- a/agent/src/og_agent/openai_agent.py +++ b/agent/src/og_agent/openai_agent.py @@ -18,54 +18,7 @@ logger = logging.getLogger(__name__) encoding = tiktoken.encoding_for_model("gpt-3.5-turbo") -OCTOGEN_FUNCTIONS = [ - { - "name": "execute_python_code", - "description": "Safely execute arbitrary Python code and return the result, stdout, and stderr. ", - "parameters": { - "type": "object", - "properties": { - "explanation": { - "type": "string", - "description": "the explanation about the python code", - }, - "code": { - "type": "string", - "description": "the python code to be executed", - }, - "saved_filenames": { - "type": "array", - "items": {"type": "string"}, - "description": "A list of filenames that were created by the code", - }, - }, - "required": ["explanation", "code"], - }, - }, - { - "name": "execute_bash_code", - "description": "Safely execute arbitrary Bash code and return the result, stdout, and stderr. sudo is not supported.", - "parameters": { - "type": "object", - "properties": { - "explanation": { - "type": "string", - "description": "the explanation about the bash code", - }, - "code": { - "type": "string", - "description": "the bash code to be executed", - }, - "saved_filenames": { - "type": "array", - "items": {"type": "string"}, - "description": "A list of filenames that were created by the code", - }, - }, - "required": ["explanation", "code"], - }, - }, -] + class OpenaiAgent(BaseAgent): @@ -75,14 +28,17 @@ def __init__(self, model, sdk, is_azure=True): logger.info(f"use openai model {model} is_azure {is_azure}") self.is_azure = is_azure self.model_name = model if not is_azure else "" - self.memory_option = AgentMemoryOption(show_function_instruction=False) + self.memory_option = AgentMemoryOption( + show_function_instruction=False, disable_output_format=True + ) - async def call_openai(self, messages, queue, context, task_context, request): + async def call_openai(self, agent_memory, queue, context, task_context, task_opt): """ call the openai api """ - logger.debug(f"call openai with messages {messages}") input_token_count = 0 + messages = agent_memory.to_messages() + logger.debug(f"call openai with messages {messages}") for message in messages: if not message["content"]: continue @@ -94,7 +50,7 @@ async def call_openai(self, messages, queue, context, task_context, request): engine=self.model, messages=messages, temperature=0, - functions=OCTOGEN_FUNCTIONS, + functions=agent_memory.get_functions(), function_call="auto", stream=True, ) @@ -103,17 +59,16 @@ async def call_openai(self, messages, queue, context, task_context, request): model=self.model, messages=messages, temperature=0, - functions=OCTOGEN_FUNCTIONS, + functions=agent_memory.get_functions(), function_call="auto", stream=True, ) message = await self.extract_message( - response, queue, context, task_context, request, start_time + response, queue, context, task_context, task_opt, start_time ) return message - async def handle_function(self, message, queue, context, task_context, request): - task_opt = request.options + async def handle_function(self, message, queue, context, task_context, task_opt): if "function_call" in message: if context.done(): logging.debug("the client has cancelled the request") @@ -130,14 +85,14 @@ async def handle_function(self, message, queue, context, task_context, request): code = raw_code else: arguments = json.loads(message["function_call"]["arguments"]) - logger.debug(f"call function {function_name} with 
args {arguments}") raw_code = arguments["code"] code = raw_code explanation = arguments["explanation"] saved_filenames = arguments.get("saved_filenames", []) - if function_name == "execute_bash_code": - language = "bash" + language = arguments.get("language") + if language == "bash": code = f"%%bash\n{raw_code}" + tool_input = json.dumps({ "code": raw_code, "explanation": explanation, @@ -152,10 +107,9 @@ async def handle_function(self, message, queue, context, task_context, request): on_step_action_start=OnStepActionStart( input=tool_input, tool=function_name ), - context_id=task_context.context_id + context_id=task_context.context_id, ) ) - function_result = None async for (result, respond) in self.call_function( code, context, task_context @@ -170,12 +124,11 @@ async def handle_function(self, message, queue, context, task_context, request): else: raise Exception("bad message, function message expected") - async def arun(self, request, queue, context): + async def arun(self, request, queue, context, task_opt): """ process the task """ task = request.task - task_opt = request.options context_id = ( request.context_id if request.context_id @@ -195,16 +148,15 @@ async def arun(self, request, queue, context): state=task_context.to_context_state_proto(), response_type=TaskResponse.OnSystemError, error_msg="invalid context id", - context_id = context_id, + context_id=context_id, ) ) return - agent_memory = self.agent_memories[context_id] - agent_memory.update_option(self.memory_option) - agent_memory.append_chat_message({ + agent_memory.update_options(self.memory_option) + agent_memory.append_chat_message( {"role": "user", "content": task}, - }) + ) try: while not context.done(): if task_context.input_token_count >= task_opt.input_token_limit: @@ -213,7 +165,7 @@ async def arun(self, request, queue, context): state=task_context.to_context_state_proto(), response_type=TaskResponse.OnInputTokenLimitExceed, error_msg="input token limit reached", - context_id = context_id, + context_id=context_id, ) ) break @@ -223,13 +175,12 @@ async def arun(self, request, queue, context): state=task_context.to_context_state_proto(), response_type=TaskResponse.OnOutputTokenLimitExceed, error_msg="output token limit reached", - context_id = context_id, + context_id=context_id, ) ) break - logger.debug(f" the input messages {messages}") chat_message = await self.call_openai( - agent_memories.to_messages(), queue, context, task_context, request + agent_memory, queue, context, task_context, task_opt ) logger.debug(f"the response {chat_message}") if "function_call" in chat_message: @@ -238,18 +189,6 @@ async def arun(self, request, queue, context): if "role" not in chat_message: chat_message["role"] = "assistant" agent_memory.append_chat_message(chat_message) - function_name = chat_message["function_call"]["name"] - if function_name not in [ - "execute_python_code", - "python", - "execute_bash_code", - ]: - agent_memory.append_chat_message({ - "role": "function", - "name": function_name, - "content": "You can use the execute_python_code or execute_bash_code", - }) - continue function_result = await self.handle_function( chat_message, queue, context, task_context, task_opt ) @@ -265,9 +204,10 @@ async def arun(self, request, queue, context): output_files=function_result.saved_filenames, has_error=function_result.has_error, ), - context_id = context_id, + context_id=context_id, ) ) + function_name = chat_message["function_call"]["name"] # TODO optimize the token limitation if function_result.has_result: 
agent_memory.append_chat_message({ @@ -298,7 +238,7 @@ async def arun(self, request, queue, context): if not task_opt.streaming else "" ), - context_id = context_id, + context_id=context_id, ) ) break @@ -307,7 +247,7 @@ async def arun(self, request, queue, context): response = TaskResponse( response_type=TaskResponse.OnSystemError, error_msg=str(ex), - context_id = context_id, + context_id=context_id, ) await queue.put(response) finally: diff --git a/agent/src/og_agent/prompt.py b/agent/src/og_agent/prompt.py index 751f739..1f45a96 100644 --- a/agent/src/og_agent/prompt.py +++ b/agent/src/og_agent/prompt.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: Elastic-2.0 import json -from og_proto.memory_pb2 import ActionDesc +from og_proto.prompt_pb2 import ActionDesc ROLE = """You are the Programming Copilot called Octogen, a world-class programmer to complete any goal by executing code""" @@ -16,7 +16,6 @@ "You can install new package with pip", "Use `execute` action to execute any code and `direct_message` action to send message to user", ] - ACTIONS = [ ActionDesc( name="execute", @@ -34,7 +33,7 @@ }, "language": { "type": "string", - "description": "the language of the code", + "description": "the language of the code, only python and bash are supported", }, "saved_filenames": { "type": "array", diff --git a/agent/tests/openai_agent_tests.py b/agent/tests/openai_agent_tests.py index d7e92cc..c6c29f4 100644 --- a/agent/tests/openai_agent_tests.py +++ b/agent/tests/openai_agent_tests.py @@ -12,7 +12,7 @@ import pytest from og_sdk.kernel_sdk import KernelSDK from og_agent import openai_agent -from og_proto.agent_server_pb2 import ProcessOptions, TaskResponse +from og_proto.agent_server_pb2 import ProcessOptions, TaskResponse, ProcessTaskRequest from openai.openai_object import OpenAIObject import asyncio import pytest_asyncio @@ -112,8 +112,9 @@ async def test_openai_agent_call_execute_bash_code(mocker, kernel_sdk): "explanation": "the hello world in bash", "code": "echo 'hello world'", "saved_filenames": [], + "language": "bash", } - stream1 = FunctionCallPayloadStream("execute_bash_code", json.dumps(arguments)) + stream1 = FunctionCallPayloadStream("execute", json.dumps(arguments)) sentence = "The output 'hello world' is the result" stream2 = PayloadStream(sentence) call_mock = MultiCallMock([stream1, stream2]) @@ -121,7 +122,7 @@ async def test_openai_agent_call_execute_bash_code(mocker, kernel_sdk): "og_agent.openai_agent.openai.ChatCompletion.acreate", side_effect=call_mock.call, ) as mock_openai: - agent = openai_agent.OpenaiAgent("gpt4", "prompt", kernel_sdk, is_azure=False) + agent = openai_agent.OpenaiAgent("gpt4", kernel_sdk, is_azure=False) queue = asyncio.Queue() task_opt = ProcessOptions( streaming=True, @@ -130,7 +131,13 @@ async def test_openai_agent_call_execute_bash_code(mocker, kernel_sdk): output_token_limit=100000, timeout=5, ) - await agent.arun("write a hello world in bash", queue, MockContext(), task_opt) + request = ProcessTaskRequest( + input_files=[], + task="write a hello world in bash", + context_id="", + options=task_opt, + ) + await agent.arun(request, queue, MockContext(), task_opt) responses = [] while True: try: @@ -157,9 +164,10 @@ async def test_openai_agent_call_execute_python_code(mocker, kernel_sdk): arguments = { "explanation": "the hello world in python", "code": "print('hello world')", + "language": "python", "saved_filenames": [], } - stream1 = FunctionCallPayloadStream("execute_python_code", json.dumps(arguments)) + stream1 = 
FunctionCallPayloadStream("execute", json.dumps(arguments)) sentence = "The output 'hello world' is the result" stream2 = PayloadStream(sentence) call_mock = MultiCallMock([stream1, stream2]) @@ -167,7 +175,7 @@ async def test_openai_agent_call_execute_python_code(mocker, kernel_sdk): "og_agent.openai_agent.openai.ChatCompletion.acreate", side_effect=call_mock.call, ) as mock_openai: - agent = openai_agent.OpenaiAgent("gpt4", "prompt", kernel_sdk, is_azure=False) + agent = openai_agent.OpenaiAgent("gpt4", kernel_sdk, is_azure=False) queue = asyncio.Queue() task_opt = ProcessOptions( streaming=True, @@ -176,10 +184,13 @@ async def test_openai_agent_call_execute_python_code(mocker, kernel_sdk): output_token_limit=100000, timeout=5, ) - - await agent.arun( - "write a hello world in python", queue, MockContext(), task_opt + request = ProcessTaskRequest( + input_files=[], + task="write a hello world in python", + context_id="", + options=task_opt, ) + await agent.arun(request, queue, MockContext(), task_opt) responses = [] while True: try: @@ -207,7 +218,7 @@ async def test_openai_agent_smoke_test(mocker, kernel_sdk): with mocker.patch( "og_agent.openai_agent.openai.ChatCompletion.acreate", return_value=stream ) as mock_openai: - agent = openai_agent.OpenaiAgent("gpt4", "prompt", kernel_sdk, is_azure=False) + agent = openai_agent.OpenaiAgent("gpt4", kernel_sdk, is_azure=False) queue = asyncio.Queue() task_opt = ProcessOptions( streaming=True, @@ -216,7 +227,10 @@ async def test_openai_agent_smoke_test(mocker, kernel_sdk): output_token_limit=100000, timeout=5, ) - await agent.arun("hello", queue, MockContext(), task_opt) + request = ProcessTaskRequest( + input_files=[], task="hello", context_id="", options=task_opt + ) + await agent.arun(request, queue, MockContext(), task_opt) responses = [] while True: try: diff --git a/chat/src/og_terminal/terminal_chat.py b/chat/src/og_terminal/terminal_chat.py index d3895e3..cee41e6 100644 --- a/chat/src/og_terminal/terminal_chat.py +++ b/chat/src/og_terminal/terminal_chat.py @@ -330,7 +330,7 @@ def upload_file(prompt, console, history_prompt, sdk, values): return real_prompt -def run_chat(prompt, sdk, session, console, values, filedir=None): +def run_chat(prompt, session, console, values, filedir=None): """ run the chat """ @@ -341,7 +341,7 @@ def run_chat(prompt, sdk, session, console, values, filedir=None): with Live(Group(*[]), console=console) as live: refresh(live, task_blocks) task_state = None - for respond in sdk.prompt(prompt): + for respond in session.prompt(prompt): if not respond: break if respond.response_type in [ @@ -360,7 +360,6 @@ def run_chat(prompt, sdk, session, console, values, filedir=None): task_state = respond.state refresh(live, task_blocks, task_state=respond.state) refresh(live, task_blocks, task_state=task_state) - if error_responses: task_blocks = TaskBlocks(values) task_blocks.begin() @@ -389,54 +388,54 @@ def app(octogen_dir): os.makedirs(filedir, exist_ok=True) sdk = AgentSyncSDK(octopus_config["endpoint"], octopus_config["api_key"]) sdk.connect() - history = FileHistory(real_octogen_dir + "/history") - values = [] - completer = OctogenCompleter(values) - session = PromptSession( - history=history, - completer=completer, - complete_in_thread=True, - complete_while_typing=True, - complete_style=CompleteStyle.MULTI_COLUMN, - ) - index = 0 - show_welcome(console) - while True: - index = index + 1 - real_prompt = session.prompt( - "[%s]%s>" % (index, "🎧"), - multiline=True, - prompt_continuation=prompt_continuation, - ) - if not 
"".join(real_prompt.strip().split("\n")): - continue - if real_prompt.find("/help") >= 0: - show_help(console) - continue - if real_prompt.find("/exit") >= 0: - console.print("πŸ‘‹πŸ‘‹!") - return - if real_prompt.find("/clear") >= 0: - clear() - continue - if real_prompt.find("/cc") >= 0: - # handle copy - for number in parse_numbers(real_prompt): - num = int(number) - if num < len(values): - clipboard.copy(values[num]) - console.print(f"πŸ‘ /cc{number} has been copied to clipboard!") - break - else: - console.print(f"❌ /cc{number} was not found!") - continue - # try to upload firstβŒ›β³βŒ - real_prompt = upload_file(real_prompt, console, history, sdk, values) - run_chat( - real_prompt, - sdk, - session, - console, - values, - filedir=filedir, + with sdk.create_session() as agent_session: + history = FileHistory(real_octogen_dir + "/history") + values = [] + completer = OctogenCompleter(values) + session = PromptSession( + history=history, + completer=completer, + complete_in_thread=True, + complete_while_typing=True, + complete_style=CompleteStyle.MULTI_COLUMN, ) + index = 0 + show_welcome(console) + while True: + index = index + 1 + real_prompt = session.prompt( + "[%s]%s>" % (index, "🎧"), + multiline=True, + prompt_continuation=prompt_continuation, + ) + if not "".join(real_prompt.strip().split("\n")): + continue + if real_prompt.find("/help") >= 0: + show_help(console) + continue + if real_prompt.find("/exit") >= 0: + console.print("πŸ‘‹πŸ‘‹!") + return + if real_prompt.find("/clear") >= 0: + clear() + continue + if real_prompt.find("/cc") >= 0: + # handle copy + for number in parse_numbers(real_prompt): + num = int(number) + if num < len(values): + clipboard.copy(values[num]) + console.print(f"πŸ‘ /cc{number} has been copied to clipboard!") + break + else: + console.print(f"❌ /cc{number} was not found!") + continue + # try to upload firstβŒ›β³βŒ + real_prompt = upload_file(real_prompt, console, history, sdk, values) + run_chat( + real_prompt, + agent_session, + console, + values, + filedir=filedir, + ) diff --git a/memory/src/og_memory/memory.py b/memory/src/og_memory/memory.py index c39fa8d..1e5203a 100644 --- a/memory/src/og_memory/memory.py +++ b/memory/src/og_memory/memory.py @@ -79,12 +79,20 @@ def reset_memory(self): """ pass + @abstractmethod + def get_functions(self): + """ + return the function definitions for model that supports the function_call + """ + pass + class AgentMemoryOption(BaseModel): """ The agent memory option """ show_function_instruction: bool = Field(False, description="Show the function instruction") + disable_output_format: bool = Field(False, description="Disable the output format") class MemoryAgentMemory(BaseAgentMemory): """ @@ -115,11 +123,16 @@ def append_chat_message(self, message): def swap_instruction(self, instruction): self.instruction = instruction + def get_functions(self): + return [{"name": action.name, "description": action.desc, "parameters": + json.loads(action.parameters)} for action in self.instruction.actions] + def to_messages(self): system_message = { "role":"system", - "content":agent_memory_to_context(self.instruction, self.guide_memory) + "content":agent_memory_to_context(self.instruction, self.guide_memory, options = self.options) } - logging.degug(f"system message: {system_message}") + logging.debug(f"system message: {system_message}") return [system_message] + self.chat_memory + diff --git a/memory/src/og_memory/template/agent.jinja b/memory/src/og_memory/template/agent.jinja index 200f163..c4eb807 100644 --- 
a/memory/src/og_memory/template/agent.jinja
+++ b/memory/src/og_memory/template/agent.jinja
@@ -15,6 +15,6 @@ Use the following actions to help you finishing your task
 {% endfor -%}
 {% endif -%}{%if guides -%}The instructions for the tools and libraries you recently used.
 {% for guide in guides if guide -%}{{loop.index}}.{{guide.name}}{{guide.what_it_can_do}}{{guide.how_to_use}}
-{% endfor -%}{% endif -%}{%if prompt.output_format -%}
+{% endfor -%}{% endif -%}{%if prompt.output_format and not options.disable_output_format -%}
 {{prompt.output_format}}
 {% endif -%}
diff --git a/sdk/src/og_sdk/agent_sdk.py b/sdk/src/og_sdk/agent_sdk.py
index 9d82a93..14ac5ad 100644
--- a/sdk/src/og_sdk/agent_sdk.py
+++ b/sdk/src/og_sdk/agent_sdk.py
@@ -16,6 +16,50 @@
 logger = logging.getLogger(__name__)
 
 
+class AgentSyncSession:
+
+    def __init__(self, agent_sdk):
+        self.context_id = None
+        self.agent_sdk = agent_sdk
+
+    def prompt(self, prompt, files=[]):
+        """
+        ask the ai with prompt and uploaded files
+        """
+        for respond in self.agent_sdk.prompt(prompt, files, self.context_id):
+            if respond.context_id:
+                self.context_id = respond.context_id
+            yield respond
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.context_id = None
+
+
+class AgentAsyncSession:
+
+    def __init__(self, agent_sdk):
+        self.context_id = None
+        self.agent_sdk = agent_sdk
+
+    async def prompt(self, prompt, files=[]):
+        """
+        ask the ai with prompt and uploaded files
+        """
+        async for respond in self.agent_sdk.prompt(prompt, files, self.context_id):
+            if respond.context_id:
+                self.context_id = respond.context_id
+            yield respond
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        self.context_id = None
+
+
 class AgentBaseSDK:
 
     def __init__(self, endpoint):
@@ -65,6 +109,12 @@ def __init__(self, endpoint, api_key):
     def connect(self):
         self.connect_sync()
 
+    def create_session(self):
+        """
+        create a session for the agent
+        """
+        return AgentSyncSession(self)
+
     def add_kernel(self, key, endpoint):
         """
         add kernel instance to the agent and only admin can call this method
@@ -96,11 +146,13 @@ def upload_file(self, filepath, filename):
             generate_chunk(filepath, filename), metadata=self.metadata
         )
 
-    def prompt(self, prompt, files=[]):
+    def prompt(self, prompt, files=[], context_id=None):
         """
         ask the ai with prompt and uploaded files
         """
-        request = agent_server_pb2.ProcessTaskRequest(task=prompt, input_files=files)
+        request = agent_server_pb2.ProcessTaskRequest(
+            task=prompt, input_files=files, context_id=context_id
+        )
         for respond in self.stub.process_task(request, metadata=self.metadata):
             yield respond
 
@@ -168,11 +220,19 @@ async def add_kernel(self, key, endpoint):
         response = await self.stub.add_kernel(request, metadata=self.metadata)
         return response
 
-    async def prompt(self, prompt, files=[]):
+    def create_session(self):
+        """
+        create a session for the agent
+        """
+        return AgentAsyncSession(self)
+
+    async def prompt(self, prompt, files=[], context_id=None):
         """
         ask the ai with prompt and uploaded files
         """
-        request = agent_server_pb2.ProcessTaskRequest(task=prompt, input_files=files)
+        request = agent_server_pb2.ProcessTaskRequest(
+            task=prompt, input_files=files, context_id=context_id
+        )
         async for respond in self.stub.process_task(request, metadata=self.metadata):
             yield respond
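
A few illustrative sketches of the behavior introduced above follow; any name or value not present in the diff is an assumption made for illustration.

First, the incremental typing logic in base_agent.py: _send_typing_message compares each streamed snapshot of the explanation (or code) field against what was already sent and pushes only the new suffix to the client. The snippet below isolates just that diffing step; the real method additionally tracks the new LANGUAGE parser state and wraps each chunk in a TaskResponse.

    def typed_suffix(old: str, new: str) -> str:
        """Return only the characters that have not been sent yet."""
        return new[len(old):] if new and old != new else ""

    # simulate three streamed snapshots of a growing `explanation` value
    sent = ""
    for snapshot in ["print", "print hello", "print hello world"]:
        chunk = typed_suffix(sent, snapshot)
        if chunk:
            print(repr(chunk))  # 'print', then ' hello', then ' world'
            sent = snapshot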
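
Second, the shape of the single `execute` function call that replaces execute_python_code and execute_bash_code. This mirrors the ACTIONS schema in prompt.py and the `%%bash` wrapping in OpenaiAgent.handle_function; the argument values are made up for the example.

    # arguments as the model would emit them for the `execute` action
    arguments = {
        "explanation": "print hello world from bash",
        "code": "echo 'hello world'",
        "language": "bash",  # only python and bash are supported
        "saved_filenames": [],
    }

    raw_code = arguments["code"]
    language = arguments.get("language")
    # handle_function routes bash through the kernel as an IPython cell magic
    code = f"%%bash\n{raw_code}" if language == "bash" else raw_code
    print(code)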
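
Third, the effect of the new disable_output_format flag on the agent template. The snippet reduces agent.jinja to the one branch this patch touches and renders it with dict stand-ins for the real context objects, so it is a sketch of the gating rather than the actual template pipeline.

    from jinja2 import Template

    tpl = Template(
        "{% if prompt.output_format and not options.disable_output_format %}"
        "{{ prompt.output_format }}"
        "{% endif %}"
    )
    prompt = {"output_format": "Respond with a JSON object"}
    # OpenaiAgent now builds AgentMemoryOption(disable_output_format=True),
    # so the output-format section is dropped from the system message
    print(tpl.render(prompt=prompt, options={"disable_output_format": True}) == "")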
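
Finally, a minimal usage sketch of the new session API, assuming a running agent server; the endpoint and key below are placeholders. AgentSyncSession stores the context_id carried by the responses and sends it back on later prompts, which is how the agent reloads the same conversation memory across turns (terminal_chat.py now drives run_chat through such a session).

    from og_sdk.agent_sdk import AgentSyncSDK

    sdk = AgentSyncSDK("127.0.0.1:9528", "your-api-key")  # placeholders
    sdk.connect()

    with sdk.create_session() as session:
        # the first response carries a context_id; the session keeps it
        for respond in session.prompt("write a hello world in bash"):
            print(respond.response_type)
        # this prompt reuses session.context_id, continuing the conversation
        for respond in session.prompt("now do the same in python"):
            print(respond.response_type)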