feat: openai agent passed
imotai committed Oct 31, 2023
1 parent f5e0e87 commit 9890c14
Showing 10 changed files with 255 additions and 206 deletions.
4 changes: 1 addition & 3 deletions agent/src/og_agent/agent_builder.py
@@ -18,9 +18,7 @@ def build_llama_agent(endpoint, key, sdk, grammer_path):
"""
with open(grammer_path, "r") as fd:
grammar = fd.read()

client = LlamaClient(endpoint, key, grammar)

# init the agent
return LlamaAgent(client, sdk)

@@ -30,7 +28,7 @@ def build_openai_agent(sdk, model_name, is_azure=True):
# TODO a data dir per user
# init the agent

agent = OpenaiAgent(model_name, OCTOGEN_FUNCTION_SYSTEM, sdk, is_azure=is_azure)
agent = OpenaiAgent(model_name, sdk, is_azure=is_azure)
return agent


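A minimal usage sketch of the updated builder, not part of this commit: build_openai_agent no longer receives a system prompt, so presumably OpenaiAgent now owns its prompt internally (that part is outside this hunk). The classes below are simplified stand-ins for illustration, not the repository's real OpenaiAgent or kernel SDK.

# Illustrative stand-ins only: the real OpenaiAgent and kernel SDK live in the
# og_agent / og_sdk packages; this sketch just shows the new call shape.
OCTOGEN_FUNCTION_SYSTEM = "You are Octogen, a coding assistant."  # placeholder prompt text

class OpenaiAgentSketch:
    def __init__(self, model_name, sdk, is_azure=True):
        self.model_name = model_name
        self.sdk = sdk
        self.is_azure = is_azure
        # the system prompt is now an internal detail, not a constructor argument
        self.system_prompt = OCTOGEN_FUNCTION_SYSTEM

def build_openai_agent_sketch(sdk, model_name, is_azure=True):
    # mirrors the new builder signature: (sdk, model_name, is_azure)
    return OpenaiAgentSketch(model_name, sdk, is_azure=is_azure)

agent = build_openai_agent_sketch(sdk=None, model_name="gpt-3.5-turbo", is_azure=False)
print(agent.system_prompt)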
7 changes: 4 additions & 3 deletions agent/src/og_agent/agent_server.py
@@ -146,8 +146,8 @@ async def process_task(
sdk = self.agents[metadata["api_key"]]["sdk"]
queue = asyncio.Queue()

async def worker(task, agent, queue, context, task_opt):
return await agent.arun(task, queue, context, task_opt)
async def worker(request, agent, queue, context, task_opt):
return await agent.arun(request, queue, context, task_opt)

options = (
request.options
@@ -166,8 +166,8 @@ async def worker(task, agent, queue, context, task_opt):
timeout=10,
)
)

logger.debug("create the agent task")
task = asyncio.create_task(worker(request.task, agent, queue, context, options))
task = asyncio.create_task(worker(request, agent, queue, context, options))
while True:
logger.debug("start wait the queue message")
# TODO add timeout
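The agent_server change above hands the whole request object to the worker instead of only request.task. Below is a standalone sketch of the same producer/consumer pattern, with toy Request and Agent classes standing in for the real gRPC types; the None sentinel for end-of-stream is an assumption of this sketch, not the server's actual protocol.

import asyncio
from dataclasses import dataclass

@dataclass
class Request:
    task: str

class ToyAgent:
    async def arun(self, request, queue, context=None, task_opt=None):
        # emit a couple of fake responses, then a sentinel marking completion
        await queue.put(f"working on: {request.task}")
        await queue.put("done")
        await queue.put(None)

async def process_task(agent, request):
    queue: asyncio.Queue = asyncio.Queue()

    async def worker(request, agent, queue, context, task_opt):
        # same shape as the new worker: forward the full request to arun()
        return await agent.arun(request, queue, context, task_opt)

    task = asyncio.create_task(worker(request, agent, queue, None, None))
    while True:
        message = await queue.get()
        if message is None:
            break
        print(message)
    await task  # surface any exception raised inside the worker

asyncio.run(process_task(ToyAgent(), Request(task="print hello")))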
103 changes: 64 additions & 39 deletions agent/src/og_agent/base_agent.py
@@ -34,13 +34,15 @@ class FunctionResult(BaseModel):
has_result: bool = False
has_error: bool = False


class TaskContext(BaseModel):
start_time: float = 0
output_token_count: int = 0
input_token_count: int = 0
llm_name: str = ""
llm_response_duration: int = 0
context_id: str = ""

def to_context_state_proto(self):
# in ms
total_duration = int((time.time() - self.start_time) * 1000)
@@ -52,10 +54,12 @@ def to_context_state_proto(self):
llm_response_duration=self.llm_response_duration,
)


class TypingState:
START = 0
EXPLANATION = 1
CODE = 2
LANGUAGE = 3


class BaseAgent:
@@ -117,17 +121,16 @@ def _parse_arguments(
self,
arguments,
is_code=False,
first_field_name="explanation",
second_field_name="code",
):
"""
parse the partial key with string value from json
"""
if is_code:
return TypingState.CODE, "", arguments
return TypingState.CODE, "", arguments, "python"
state = TypingState.START
explanation_str = ""
code_str = ""
language_str = ""
logger.debug(f"the arguments {arguments}")
for token_state, token in tokenize(io.StringIO(arguments)):
if token_state == None:
@@ -137,17 +140,21 @@ def _parse_arguments(
if state == TypingState.CODE and token[0] == 1:
code_str = token[1]
state = TypingState.START
if token[1] == first_field_name:
if token[1] == "explanation":
state = TypingState.EXPLANATION
if token[1] == second_field_name:
if token[1] == "code":
state = TypingState.CODE
if token[1] == "language":
state = TypingState.LANGUAGE
else:
# String
if token_state == 9 and state == TypingState.EXPLANATION:
explanation_str = "".join(token)
elif token_state == 9 and state == TypingState.CODE:
code_str = "".join(token)
return (state, explanation_str, code_str)
elif token_state == 9 and state == TypingState.LANGUAGE:
language_str = "".join(token)
return (state, explanation_str, code_str, language_str)

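With this change _parse_arguments returns a fourth element, the partially typed language field, parsed from the still-streaming JSON arguments. The following standalone, regex-based illustration shows that goal only; it is not the repository's tokenizer-driven implementation, just a sketch of pulling partial string values for "explanation", "code" and "language" out of an incomplete JSON document.

import re

def parse_partial_arguments(arguments: str):
    # extract the (possibly unfinished) string values of the three fields
    values = {"explanation": "", "code": "", "language": ""}
    for field in values:
        # match `"field": "<value...>` even when the closing quote has not arrived yet
        m = re.search(rf'"{field}"\s*:\s*"((?:[^"\\]|\\.)*)', arguments)
        if m:
            values[field] = m.group(1)
    return values["explanation"], values["code"], values["language"]

# the model has only emitted part of the "code" value so far
chunk = '{"language": "python", "explanation": "add two numbers", "code": "print(1 + 1'
print(parse_partial_arguments(chunk))
# ('add two numbers', 'print(1 + 1', 'python')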
def _get_message_token_count(self, message):
response_token_count = 0
@@ -159,16 +166,16 @@ def _get_message_token_count(self, message):
return response_token_count

async def _read_function_call_message(
self, message, queue, old_text_content, old_code_content, task_context, task_opt
self,
message,
queue,
old_text_content,
old_code_content,
language_str,
task_context,
task_opt,
):
typing_language = "text"
if message["function_call"].get("name", "") in [
"execute_python_code",
"python",
]:
typing_language = "python"
elif message["function_call"].get("name", "") == "execute_bash_code":
typing_language = "bash"
is_code = False
if message["function_call"].get("name", "") == "python":
is_code = True
@@ -178,28 +185,30 @@ async def _read_function_call_message(
queue,
old_text_content,
old_code_content,
typing_language,
language_str,
task_context,
task_opt,
is_code=is_code,
)

async def _read_json_message(
self, message, queue, old_text_content, old_code_content, task_context, task_opt
self,
message,
queue,
old_text_content,
old_code_content,
old_language_str,
task_context,
task_opt,
):
arguments = message.get("content", "")
typing_language = "text"
if arguments.find("execute_python_code") >= 0:
typing_language = "python"
elif arguments.find("execute_bash_code") >= 0:
typing_language = "bash"

return await self._send_typing_message(
arguments,
queue,
old_text_content,
old_code_content,
typing_language,
old_language_str,
task_context,
task_opt,
)
@@ -210,18 +219,19 @@ async def _send_typing_message(
queue,
old_text_content,
old_code_content,
language,
old_language_str,
task_context,
task_opt,
is_code=False,
):
"""
send the typing message to the client
"""
task_opt = request
(state, explanation_str, code_str) = self._parse_arguments(arguments, is_code)
(state, explanation_str, code_str, language_str) = self._parse_arguments(
arguments, is_code
)
logger.debug(
f"argument explanation:{explanation_str} code:{code_str} text_content:{old_text_content}"
f"argument explanation:{explanation_str} code:{code_str} language_str:{language_str} text_content:{old_text_content}"
)
if explanation_str and old_text_content != explanation_str:
typed_chars = explanation_str[len(old_text_content) :]
@@ -231,10 +241,10 @@ async def _send_typing_message(
state=task_context.to_context_state_proto(),
response_type=TaskResponse.OnModelTypeText,
typing_content=TypingContent(content=typed_chars, language="text"),
context_id=task_context.context_id
context_id=task_context.context_id,
)
await queue.put(task_response)
return new_text_content, old_code_content
return new_text_content, old_code_content, old_language_str
if code_str and old_code_content != code_str:
typed_chars = code_str[len(old_code_content) :]
code_content = code_str
@@ -244,31 +254,43 @@ async def _send_typing_message(
state=task_context.to_context_state_proto(),
response_type=TaskResponse.OnModelTypeCode,
typing_content=TypingContent(
content=typed_chars, language=language
content=typed_chars, language="text"
),
context_id=task_context.context_id
context_id=task_context.context_id,
)
)
return old_text_content, code_content, old_language_str
if language_str and old_language_str != language_str:
typed_chars = language_str[len(old_language_str) :]
if task_opt.streaming and len(typed_chars) > 0:
await queue.put(
TaskResponse(
state=task_context.to_context_state_proto(),
response_type=TaskResponse.OnModelTypeCode,
typing_content=TypingContent(content="", language=language_str),
context_id=task_context.context_id,
)
)
return old_text_content, code_content
return old_text_content, old_code_content
return old_text_content, old_code_content, language_str
return old_text_content, old_code_content, old_language_str

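_send_typing_message streams only the characters that are new since the last chunk, via typed_chars = new_value[len(old_value):], and now also threads the language through its return value. Here is a tiny self-contained sketch of that delta logic without the TaskResponse/queue plumbing; the helper name and sentinel values are illustrative only.

def typing_delta(old: str, new: str) -> str:
    # return only the characters that have not been streamed yet
    return new[len(old):] if new != old and new.startswith(old) else ""

sent = ""
for snapshot in ["pri", "print(", "print(1 + 1)"]:  # successive partial parses
    delta = typing_delta(sent, snapshot)
    if delta:
        print(f"stream -> {delta!r}")
        sent = snapshot
# stream -> 'pri'
# stream -> 'nt('
# stream -> '1 + 1)'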
async def extract_message(
self,
response_generator,
queue,
rpc_context,
task_context,
request,
task_opt,
start_time,
is_json_format=False,
):
"""
extract the chunk from the response generator
"""
task_opt = request.options
message = {}
text_content = ""
code_content = ""
language_str = ""
context_output_token_count = task_context.output_token_count
start_time = time.time()
async for chunk in response_generator:
@@ -294,11 +316,13 @@ async def extract_message(
(
new_text_content,
new_code_content,
new_language_str,
) = await self._read_function_call_message(
message,
queue,
text_content,
code_content,
language_str,
task_context,
task_opt,
)
@@ -324,6 +348,7 @@ async def extract_message(
queue,
text_content,
code_content,
language_str,
task_context,
task_opt,
)
@@ -338,7 +363,7 @@ async def extract_message(
typing_content=TypingContent(
content=delta["content"], language="text"
),
context_id=task_context.context_id
context_id=task_context.context_id,
)
)
logger.info(
@@ -375,7 +400,7 @@ async def call_function(self, code, context, task_context):
state=task_context.to_context_state_proto(),
response_type=TaskResponse.OnStepActionStreamStdout,
console_stdout=kernel_output,
context_id=task_context.context_id
context_id=task_context.context_id,
),
)
# process the stderr
@@ -390,7 +415,7 @@ async def call_function(self, code, context, task_context):
state=task_context.to_context_state_proto(),
response_type=TaskResponse.OnStepActionStreamStderr,
console_stderr=kernel_err,
context_id=task_context.context_id
context_id=task_context.context_id,
),
)
elif kernel_respond.output_type == ExecuteResponse.TracebackType:
@@ -404,7 +429,7 @@ async def call_function(self, code, context, task_context):
state=task_context.to_context_state_proto(),
response_type=TaskResponse.OnStepActionStreamStderr,
console_stderr=traceback,
context_id=task_context.context_id
context_id=task_context.context_id,
),
)
else:
@@ -428,7 +453,7 @@ async def call_function(self, code, context, task_context):
state=task_context.to_context_state_proto(),
response_type=TaskResponse.OnStepActionStreamStdout,
console_stdout=console_stdout,
context_id=task_context.context_id
context_id=task_context.context_id,
),
)
output_files = []