Skip to content

Commit

Permalink
feat: support openai direct message
Browse files Browse the repository at this point in the history
  • Loading branch information
imotai committed Nov 4, 2023
1 parent c6cf285 commit c19bfd3
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 10 deletions.
56 changes: 48 additions & 8 deletions agent/src/og_agent/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class TypingState:
EXPLANATION = 1
CODE = 2
LANGUAGE = 3
MESSAGE = 4


class BaseAgent:
Expand Down Expand Up @@ -131,6 +132,7 @@ def _parse_arguments(
explanation_str = ""
code_str = ""
language_str = ""
message_str = ""
logger.debug(f"the arguments {arguments}")
for token_state, token in tokenize(io.StringIO(arguments)):
if token_state == None:
Expand All @@ -146,6 +148,8 @@ def _parse_arguments(
state = TypingState.CODE
if token[1] == "language":
state = TypingState.LANGUAGE
if token[1] == "message":
state = TypingState.MESSAGE
else:
# String
if token_state == 9 and state == TypingState.EXPLANATION:
Expand All @@ -154,7 +158,9 @@ def _parse_arguments(
code_str = "".join(token)
elif token_state == 9 and state == TypingState.LANGUAGE:
language_str = "".join(token)
return (state, explanation_str, code_str, language_str)
elif token_state == 9 and state == TypingState.MESSAGE:
message_str = "".join(token)
return (state, explanation_str, code_str, language_str, message_str)

def _get_message_token_count(self, message):
response_token_count = 0
Expand All @@ -171,6 +177,7 @@ async def _read_function_call_message(
queue,
old_text_content,
old_code_content,
old_message_str,
language_str,
task_context,
task_opt,
Expand All @@ -185,6 +192,7 @@ async def _read_function_call_message(
queue,
old_text_content,
old_code_content,
old_message_str,
language_str,
task_context,
task_opt,
Expand All @@ -197,6 +205,7 @@ async def _read_json_message(
queue,
old_text_content,
old_code_content,
old_message_str,
old_language_str,
task_context,
task_opt,
Expand All @@ -208,6 +217,7 @@ async def _read_json_message(
queue,
old_text_content,
old_code_content,
old_message_str,
old_language_str,
task_context,
task_opt,
Expand All @@ -220,16 +230,22 @@ async def _send_typing_message(
old_text_content,
old_code_content,
old_language_str,
old_message_str,
task_context,
task_opt,
is_code=False,
):
"""
send the typing message to the client
"""
(state, explanation_str, code_str, language_str) = self._parse_arguments(
arguments, is_code
)
(
state,
explanation_str,
code_str,
language_str,
message_str,
) = self._parse_arguments(arguments, is_code)

logger.debug(
f"argument explanation:{explanation_str} code:{code_str} language_str:{language_str} text_content:{old_text_content}"
)
Expand All @@ -244,7 +260,7 @@ async def _send_typing_message(
context_id=task_context.context_id,
)
await queue.put(task_response)
return new_text_content, old_code_content, old_language_str
return new_text_content, old_code_content, old_language_str, old_message_str
if code_str and old_code_content != code_str:
typed_chars = code_str[len(old_code_content) :]
code_content = code_str
Expand All @@ -259,7 +275,7 @@ async def _send_typing_message(
context_id=task_context.context_id,
)
)
return old_text_content, code_content, old_language_str
return old_text_content, code_content, old_language_str, old_message_str
if language_str and old_language_str != language_str:
typed_chars = language_str[len(old_language_str) :]
if task_opt.streaming and len(typed_chars) > 0:
Expand All @@ -271,8 +287,22 @@ async def _send_typing_message(
context_id=task_context.context_id,
)
)
return old_text_content, old_code_content, language_str
return old_text_content, old_code_content, old_language_str
return old_text_content, old_code_content, language_str, old_message_str
if message_str and old_message_str != message_str:
typed_chars = message_str[len(old_message_str) :]
if task_opt.streaming and len(typed_chars) > 0:
await queue.put(
TaskResponse(
state=task_context.to_context_state_proto(),
response_type=TaskResponse.OnModelTypeText,
typing_content=TypingContent(
content=typed_chars, language="text"
),
context_id=task_context.context_id,
)
)
return old_text_content, old_code_content, old_language_str, message_str
return old_text_content, old_code_content, old_language_str, old_message_str

async def extract_message(
self,
Expand All @@ -291,6 +321,7 @@ async def extract_message(
text_content = ""
code_content = ""
language_str = ""
message_str = ""
context_output_token_count = task_context.output_token_count
start_time = time.time()
async for chunk in response_generator:
Expand All @@ -317,17 +348,21 @@ async def extract_message(
new_text_content,
new_code_content,
new_language_str,
new_message_str,
) = await self._read_function_call_message(
message,
queue,
text_content,
code_content,
message_str,
language_str,
task_context,
task_opt,
)
text_content = new_text_content
code_content = new_code_content
message_str = new_message_str
language_str = new_language_str
else:
self._merge_delta_for_content(message, delta)
task_context.llm_response_duration += int(
Expand All @@ -343,17 +378,22 @@ async def extract_message(
(
new_text_content,
new_code_content,
new_language_str,
new_message_str,
) = await self._read_json_message(
message,
queue,
text_content,
code_content,
message_str,
language_str,
task_context,
task_opt,
)
text_content = new_text_content
code_content = new_code_content
message_str = new_message_str
language_str = new_language_str

elif task_opt.streaming and delta.get("content"):
await queue.put(
Expand Down
2 changes: 1 addition & 1 deletion agent/src/og_agent/openai_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ async def handle_function(self, message, queue, context, task_context, task_opt)
if not task_opt.streaming
else ""
),
context_id=context_id,
context_id=task_context.context_id,
)
)
return None
Expand Down
43 changes: 42 additions & 1 deletion agent/tests/openai_agent_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,47 @@ async def test_openai_agent_call_execute_bash_code(mocker, kernel_sdk):
assert console_output[0].console_stdout == "hello world\n", "bad console output"


@pytest.mark.asyncio
async def test_openai_agent_direct_message(mocker, kernel_sdk):
    """Verify that a `direct_message` function call from the model is
    surfaced as the task's final answer.

    The mocked OpenAI completion streams a single function-call payload
    whose "message" argument should come back verbatim as the final
    answer on the response queue.
    """
    kernel_sdk.connect()
    arguments = {
        "message": "hello world",
    }
    stream1 = FunctionCallPayloadStream("direct_message", json.dumps(arguments))
    call_mock = MultiCallMock([stream1])
    # pytest-mock's `mocker.patch` is not a context manager: it patches
    # immediately and the mocker fixture undoes it at test teardown, so
    # call it directly instead of using `with`.
    mocker.patch(
        "og_agent.openai_agent.openai.ChatCompletion.acreate",
        side_effect=call_mock.call,
    )
    agent = openai_agent.OpenaiAgent("gpt4", kernel_sdk, is_azure=False)
    queue = asyncio.Queue()
    task_opt = ProcessOptions(
        streaming=False,
        llm_name="gpt4",
        input_token_limit=100000,
        output_token_limit=100000,
        timeout=5,
    )
    request = ProcessTaskRequest(
        input_files=[],
        task="say hello world",
        context_id="",
        options=task_opt,
    )
    await agent.arun(request, queue, MockContext(), task_opt)
    # `arun` has completed, so every response is already enqueued. Drain
    # with get_nowait(): an awaited Queue.get() never raises QueueEmpty,
    # so the original `await queue.get()` loop would hang forever if the
    # falsy end-of-stream sentinel were ever missing.
    responses = []
    while True:
        try:
            response = queue.get_nowait()
        except asyncio.QueueEmpty:
            break
        if not response:
            # Falsy sentinel marks the end of the agent's output.
            break
        responses.append(response)
    logger.info(responses)
    assert responses[0].final_answer.answer == "hello world"


@pytest.mark.asyncio
async def test_openai_agent_call_execute_python_code(mocker, kernel_sdk):
kernel_sdk.connect()
Expand Down Expand Up @@ -245,5 +286,5 @@ async def test_openai_agent_smoke_test(mocker, kernel_sdk):
assert (
responses[-1].response_type == TaskResponse.OnFinalAnswer
), "bad response type"
assert responses[-1].state.input_token_count == 2
assert responses[-1].state.input_token_count == 325
assert responses[-1].state.output_token_count == 8
1 change: 1 addition & 0 deletions chat/src/og_terminal/ui_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def add_code(self, code, language):
self.values.append(code)
else:
last_block.append(code)
last_block.language = language
self.values[last_block.get_index()] = last_block.content
else:
last_block.finish()
Expand Down

0 comments on commit c19bfd3

Please sign in to comment.