From c19bfd34a6c66d4575a85ab93cf05e0d713becbd Mon Sep 17 00:00:00 2001
From: imotai
Date: Sat, 4 Nov 2023 12:36:24 +0800
Subject: [PATCH] feat: support openai direct message

---
 agent/src/og_agent/base_agent.py   | 56 +++++++++++++++++++++++++-----
 agent/src/og_agent/openai_agent.py |  2 +-
 agent/tests/openai_agent_tests.py  | 43 ++++++++++++++++++++++-
 chat/src/og_terminal/ui_block.py   |  1 +
 4 files changed, 92 insertions(+), 10 deletions(-)

diff --git a/agent/src/og_agent/base_agent.py b/agent/src/og_agent/base_agent.py
index 5a79b6e..f915fe5 100644
--- a/agent/src/og_agent/base_agent.py
+++ b/agent/src/og_agent/base_agent.py
@@ -60,6 +60,7 @@ class TypingState:
     EXPLANATION = 1
     CODE = 2
     LANGUAGE = 3
+    MESSAGE = 4
 
 
 class BaseAgent:
@@ -131,6 +132,7 @@ def _parse_arguments(
         explanation_str = ""
         code_str = ""
         language_str = ""
+        message_str = ""
         logger.debug(f"the arguments {arguments}")
         for token_state, token in tokenize(io.StringIO(arguments)):
             if token_state == None:
@@ -146,6 +148,8 @@
                     state = TypingState.CODE
                 if token[1] == "language":
                     state = TypingState.LANGUAGE
+                if token[1] == "message":
+                    state = TypingState.MESSAGE
             else:
                 # String
                 if token_state == 9 and state == TypingState.EXPLANATION:
@@ -154,7 +158,9 @@
                     code_str = "".join(token)
                 elif token_state == 9 and state == TypingState.LANGUAGE:
                     language_str = "".join(token)
-        return (state, explanation_str, code_str, language_str)
+                elif token_state == 9 and state == TypingState.MESSAGE:
+                    message_str = "".join(token)
+        return (state, explanation_str, code_str, language_str, message_str)
 
     def _get_message_token_count(self, message):
         response_token_count = 0
@@ -171,6 +177,7 @@ async def _read_function_call_message(
         queue,
         old_text_content,
         old_code_content,
+        old_message_str,
         language_str,
         task_context,
         task_opt,
@@ -185,6 +192,7 @@
             queue,
             old_text_content,
             old_code_content,
+            old_message_str,
             language_str,
             task_context,
             task_opt,
@@ -197,6 +205,7 @@ async def _read_json_message(
         queue,
         old_text_content,
         old_code_content,
+        old_message_str,
         old_language_str,
         task_context,
         task_opt,
@@ -208,6 +217,7 @@
             queue,
             old_text_content,
             old_code_content,
+            old_message_str,
             old_language_str,
             task_context,
             task_opt,
@@ -220,6 +230,7 @@ async def _send_typing_message(
         old_text_content,
         old_code_content,
         old_language_str,
+        old_message_str,
         task_context,
         task_opt,
         is_code=False,
@@ -227,9 +238,14 @@
         """
         send the typing message to the client
        """
-        (state, explanation_str, code_str, language_str) = self._parse_arguments(
-            arguments, is_code
-        )
+        (
+            state,
+            explanation_str,
+            code_str,
+            language_str,
+            message_str,
+        ) = self._parse_arguments(arguments, is_code)
+
         logger.debug(
             f"argument explanation:{explanation_str} code:{code_str} language_str:{language_str} text_content:{old_text_content}"
         )
@@ -244,7 +260,7 @@
                     context_id=task_context.context_id,
                 )
                 await queue.put(task_response)
-            return new_text_content, old_code_content, old_language_str
+            return new_text_content, old_code_content, old_language_str, old_message_str
         if code_str and old_code_content != code_str:
             typed_chars = code_str[len(old_code_content) :]
             code_content = code_str
@@ -259,7 +275,7 @@
                         context_id=task_context.context_id,
                     )
                 )
-            return old_text_content, code_content, old_language_str
+            return old_text_content, code_content, old_language_str, old_message_str
         if language_str and old_language_str != language_str:
             typed_chars = language_str[len(old_language_str) :]
             if task_opt.streaming and len(typed_chars) > 0:
@@ -271,8 +287,22 @@
                         context_id=task_context.context_id,
                     )
                 )
-            return old_text_content, old_code_content, language_str
-        return old_text_content, old_code_content, old_language_str
+            return old_text_content, old_code_content, language_str, old_message_str
+        if message_str and old_message_str != message_str:
+            typed_chars = message_str[len(old_message_str) :]
+            if task_opt.streaming and len(typed_chars) > 0:
+                await queue.put(
+                    TaskResponse(
+                        state=task_context.to_context_state_proto(),
+                        response_type=TaskResponse.OnModelTypeText,
+                        typing_content=TypingContent(
+                            content=typed_chars, language="text"
+                        ),
+                        context_id=task_context.context_id,
+                    )
+                )
+            return old_text_content, old_code_content, old_language_str, message_str
+        return old_text_content, old_code_content, old_language_str, old_message_str
 
     async def extract_message(
         self,
@@ -291,6 +321,7 @@ async def extract_message(
         text_content = ""
         code_content = ""
         language_str = ""
+        message_str = ""
         context_output_token_count = task_context.output_token_count
         start_time = time.time()
         async for chunk in response_generator:
@@ -317,17 +348,21 @@
                         new_text_content,
                         new_code_content,
                         new_language_str,
+                        new_message_str,
                     ) = await self._read_function_call_message(
                         message,
                         queue,
                         text_content,
                         code_content,
+                        message_str,
                         language_str,
                         task_context,
                         task_opt,
                     )
                     text_content = new_text_content
                     code_content = new_code_content
+                    message_str = new_message_str
+                    language_str = new_language_str
                 else:
                     self._merge_delta_for_content(message, delta)
                     task_context.llm_response_duration += int(
@@ -343,17 +378,22 @@
                     (
                         new_text_content,
                         new_code_content,
+                        new_language_str,
+                        new_message_str,
                     ) = await self._read_json_message(
                         message,
                         queue,
                         text_content,
                         code_content,
+                        message_str,
                         language_str,
                         task_context,
                         task_opt,
                     )
                     text_content = new_text_content
                     code_content = new_code_content
+                    message_str = new_message_str
+                    language_str = new_language_str
                 elif task_opt.streaming and delta.get("content"):
                     await queue.put(
diff --git a/agent/src/og_agent/openai_agent.py b/agent/src/og_agent/openai_agent.py
index c5967cf..f6901d9 100644
--- a/agent/src/og_agent/openai_agent.py
+++ b/agent/src/og_agent/openai_agent.py
@@ -90,7 +90,7 @@ async def handle_function(self, message, queue, context, task_context, task_opt)
                     if not task_opt.streaming
                     else ""
                 ),
-                context_id=context_id,
+                context_id=task_context.context_id,
             )
         )
         return None
diff --git a/agent/tests/openai_agent_tests.py b/agent/tests/openai_agent_tests.py
index c6c29f4..d5926a0 100644
--- a/agent/tests/openai_agent_tests.py
+++ b/agent/tests/openai_agent_tests.py
@@ -158,6 +158,47 @@ async def test_openai_agent_call_execute_bash_code(mocker, kernel_sdk):
     assert console_output[0].console_stdout == "hello world\n", "bad console output"
 
 
+@pytest.mark.asyncio
+async def test_openai_agent_direct_message(mocker, kernel_sdk):
+    kernel_sdk.connect()
+    arguments = {
+        "message": "hello world",
+    }
+    stream1 = FunctionCallPayloadStream("direct_message", json.dumps(arguments))
+    call_mock = MultiCallMock([stream1])
+    with mocker.patch(
+        "og_agent.openai_agent.openai.ChatCompletion.acreate",
+        side_effect=call_mock.call,
+    ) as mock_openai:
+        agent = openai_agent.OpenaiAgent("gpt4", kernel_sdk, is_azure=False)
+        queue = asyncio.Queue()
+        task_opt = ProcessOptions(
+            streaming=False,
+            llm_name="gpt4",
+            input_token_limit=100000,
+            output_token_limit=100000,
+            timeout=5,
+        )
+        request = ProcessTaskRequest(
+            input_files=[],
+            task="say hello world",
+            context_id="",
+            options=task_opt,
+        )
+        await agent.arun(request, queue, MockContext(), task_opt)
+        responses = []
+        while True:
+            try:
+                response = await queue.get()
+                if not response:
+                    break
+                responses.append(response)
+            except asyncio.QueueEmpty:
+                break
+        logger.info(responses)
+        assert responses[0].final_answer.answer == "hello world"
+
+
 @pytest.mark.asyncio
 async def test_openai_agent_call_execute_python_code(mocker, kernel_sdk):
     kernel_sdk.connect()
@@ -245,5 +286,5 @@ async def test_openai_agent_smoke_test(mocker, kernel_sdk):
         assert (
             responses[-1].response_type == TaskResponse.OnFinalAnswer
         ), "bad response type"
-        assert responses[-1].state.input_token_count == 2
+        assert responses[-1].state.input_token_count == 325
         assert responses[-1].state.output_token_count == 8
diff --git a/chat/src/og_terminal/ui_block.py b/chat/src/og_terminal/ui_block.py
index 6dd098b..7011884 100644
--- a/chat/src/og_terminal/ui_block.py
+++ b/chat/src/og_terminal/ui_block.py
@@ -214,6 +214,7 @@ def add_code(self, code, language):
             self.values.append(code)
         else:
             last_block.append(code)
+            last_block.language = language
             self.values[last_block.get_index()] = last_block.content
     else:
         last_block.finish()