diff --git a/agent/src/og_agent/codellama_agent.py b/agent/src/og_agent/codellama_agent.py
index 61ab616..820dcfe 100644
--- a/agent/src/og_agent/codellama_agent.py
+++ b/agent/src/og_agent/codellama_agent.py
@@ -11,7 +11,7 @@
 import io
 import time
 from .codellama_client import CodellamaClient
-from og_proto.agent_server_pb2 import OnAgentAction, TaskRespond, OnAgentActionEnd, FinalRespond
+from og_proto.agent_server_pb2 import OnStepActionStart, TaskResponse, OnStepActionEnd, FinalAnswer
 from .base_agent import BaseAgent, TypingState, TaskContext
 from .tokenizer import tokenize
 import tiktoken
@@ -84,14 +84,14 @@ async def handle_bash_code(self, json_response, queue, context, task_context):
             "code": commands,
             "explanation": explanation,
             "saved_filenames": saved_filenames,
-            "language": json_response.get("language", "text"),
+            "language": json_response.get("language"),
         })
         await queue.put(
             TaskRespond(
                 state=task_context.to_task_state_proto(),
                 respond_type=TaskRespond.OnAgentActionType,
                 on_agent_action=OnAgentAction(
-                    input=tool_input, tool="execute_python_code"
+                    input=tool_input, tool="execute_bash_code"
                 ),
             )
         )
@@ -105,7 +105,9 @@ async def handle_bash_code(self, json_response, queue, context, task_context):
             await queue.put(respond)
         return function_result

-    async def handle_function(self, json_response, queue, context, task_context):
+    async def handle_function(
+        self, json_response, queue, context, task_context, task_opt
+    ):
         code = json_response["action_input"]
         explanation = json_response["explanation"]
         saved_filenames = json_response.get("saved_filenames", [])
@@ -113,15 +115,15 @@ async def handle_bash_code(self, json_response, queue, context, task_context):
             "code": code,
             "explanation": explanation,
             "saved_filenames": saved_filenames,
-            "language": json_response.get("language", "text"),
+            "language": json_response.get("language"),
         })
-        await queue.put(
-            TaskRespond(
-                state=task_context.to_task_state_proto(),
-                respond_type=TaskRespond.OnAgentActionType,
-                on_agent_action=OnAgentAction(
-                    input=tool_input, tool="execute_python_code"
+        await queue.put(
+            TaskResponse(
+                state=task_context.to_context_state_proto(),
+                response_type=TaskResponse.OnStepActionStart,
+                on_agent_action_start=OnStepActionStart(
+                    input=tool_input, tool=json_response["action"]
                 ),
             )
         )
@@ -131,7 +132,7 @@
             if context.done():
                 logger.debug("the client has cancelled the request")
                 break
             function_result = result
-            if respond:
+            if respond and task_opt.streaming:
                 await queue.put(respond)
         return function_result
@@ -160,14 +161,15 @@ def _get_argument_new_typing(self, message):
         return (state, explanation_str, action_input_str)

     async def call_codellama(
-        self, question, chat_history, queue, context, task_context
+        self, question, chat_history, queue, context, task_context, task_opt
     ):
         """
         call codellama api
         """
         start_time = time.time()
         num_tokens = len(encoding.encode(question)) + len(encoding.encode(chat_history))
-        task_context.sent_token_count += num_tokens
+        task_context.input_token_count += num_tokens
+        output_token_count = task_context.output_token_count
         state = None
         message = ""
         text_content = ""
@@ -179,14 +181,11 @@ async def call_codellama(
                 logger.debug("the client has cancelled the request")
                 break
             respond = json.loads(line[6:])
-            task_context.generated_token_count += len(
-                encoding.encode(respond["content"])
-            )
-            task_context.model_respond_duration += int(
-                (time.time() - start_time) * 1000
-            )
+            task_context.llm_response_duration += int((time.time() - start_time) * 1000)
             start_time = time.time()
             message += respond["content"]
+            response_token_count = len(encoding.encode(message))
+            task_context.output_token_count = output_token_count + response_token_count
             logger.debug(f" message {message}")
             (
                 state,
@@ -196,32 +195,33 @@ async def call_codellama(
             if explanation_str and text_content != explanation_str:
                 typed_chars = explanation_str[len(text_content) :]
                 text_content = explanation_str
-                await queue.put(
-                    TaskRespond(
-                        state=task_context.to_task_state_proto(),
-                        respond_type=TaskRespond.OnAgentTextTyping,
-                        typing_content=typed_chars,
+                if task_opt.streaming:
+                    await queue.put(
+                        TaskResponse(
+                            state=task_context.to_context_state_proto(),
+                            response_type=TaskResponse.OnModelTypeText,
+                            typing_content=typed_chars,
+                        )
                     )
-                )
             if action_input_str and code_content != action_input_str:
                 typed_chars = action_input_str[len(code_content) :]
                 code_content = action_input_str
                 await queue.put(
-                    TaskRespond(
-                        state=task_context.to_task_state_proto(),
-                        respond_type=TaskRespond.OnAgentCodeTyping,
+                    TaskResponse(
+                        state=task_context.to_context_state_proto(),
+                        response_type=TaskResponse.OnModelTypeCode,
                         typing_content=typed_chars,
                     )
                 )
             logger.debug(
                 f"argument explanation:{explanation_str} code:{action_input_str}"
             )
-            if respond["stop"]:
+            if respond.get("stop", ""):
                 state = respond
         return (message, state)

-    async def arun(self, question, queue, context, max_iteration=5):
+    async def arun(self, question, queue, context, task_opt):
         """
         run the agent
         """
@@ -229,46 +229,58 @@
         current_question = question
         task_context = TaskContext(
             start_time=time.time(),
-            generated_token_count=0,
-            sent_token_count=0,
-            model_name="codellama",
-            iteration_count=0,
-            model_respond_duration=0,
+            output_token_count=0,
+            input_token_count=0,
+            llm_name="codellama",
+            llm_respond_duration=0,
         )
-        iteration = 0
         try:
-            while iteration < max_iteration:
-                if context.done():
-                    logger.debug("the client has cancelled the request")
+            while not context.done():
+                if task_context.input_token_count >= task_opt.input_token_limit:
+                    await queue.put(
+                        TaskResponse(
+                            state=task_context.to_context_state_proto(),
+                            response_type=TaskResponse.OnInputTokenLimitExceed,
+                            error_msg="input token limit reached",
+                        )
+                    )
+                    break
+                if task_context.output_token_count >= task_opt.output_token_limit:
+                    await queue.put(
+                        TaskResponse(
+                            state=task_context.to_context_state_proto(),
+                            response_type=TaskResponse.OnOutputTokenLimitExceed,
+                            error_msg="output token limit reached",
+                        )
+                    )
                     break
-                iteration += 1
-                task_context.iteration_count = iteration
                 chat_history = "\n".join(history)
                 (message, state) = await self.call_codellama(
-                    current_question, chat_history, queue, context, task_context
+                    current_question,
+                    chat_history,
+                    queue,
+                    context,
+                    task_context,
+                    task_opt,
                 )
                 try:
                     json_response = json.loads(message)
                     if not json_response:
                         await queue.put(
-                            TaskRespond(
-                                state=task_context.to_task_state_proto(),
-                                respond_type=TaskRespond.OnFinalAnswerType,
-                                final_respond=FinalRespond(
-                                    answer=self._output_exception()
-                                ),
+                            TaskResponse(
+                                state=task_context.to_context_state_proto(),
+                                response_type=TaskResponse.OnModelOutputError,
+                                error_msg=self._output_exception(),
                             )
                         )
                         break
                 except Exception as ex:
                     logger.exception(f"fail to load message the message is {message}")
                     await queue.put(
-                        TaskRespond(
-                            state=task_context.to_task_state_proto(),
-                            respond_type=TaskRespond.OnFinalAnswerType,
-                            final_respond=FinalRespond(
-                                answer="The model made an invalid respone"
-                            ),
+                        TaskResponse(
+                            state=task_context.to_context_state_proto(),
+                            response_type=TaskResponse.OnModelOutputError,
+                            error_msg=str(ex),
                         )
                     )
                     break
@@ -283,15 +295,18 @@
                     "execute_bash_code": self.handle_bash_code,
                 }
                 function_result = await tools_mapping[json_response["action"]](
-                    json_response, queue, context, task_context
+                    json_response, queue, context, task_context, task_opt
                 )
                 logger.debug(f"the function result {function_result}")
                 await queue.put(
-                    TaskRespond(
-                        state=task_context.to_task_state_proto(),
-                        respond_type=TaskRespond.OnAgentActionEndType,
+                    TaskResponse(
+                        state=task_context.to_context_state_proto(),
+                        response_type=TaskResponse.OnStepActionEnd,
                         on_agent_action_end=OnAgentActionEnd(
-                            output="",
+                            output=""
+                            if task_opt.streaming
+                            else function_result.console_stderr
+                            + function_result.console_stdout,
                             output_files=function_result.saved_filenames,
                             has_error=function_result.has_error,
                         ),
@@ -333,20 +348,20 @@
                 )
                 result = self._format_output(json_response)
                 await queue.put(
-                    TaskRespond(
-                        state=task_context.to_task_state_proto(),
-                        respond_type=TaskRespond.OnFinalAnswerType,
-                        final_respond=FinalRespond(answer=result),
+                    TaskResponse(
+                        state=task_context.to_context_state_proto(),
+                        response_type=TaskResponse.OnFinalAnswer,
+                        final_answer=FinalAnswer(answer=result),
                     )
                 )
                 break
             else:
                 result = self._format_output(json_response)
                 await queue.put(
-                    TaskRespond(
-                        state=task_context.to_task_state_proto(),
-                        respond_type=TaskRespond.OnFinalAnswerType,
-                        final_respond=FinalRespond(answer=result),
+                    TaskResponse(
+                        state=task_context.to_context_state_proto(),
+                        response_type=TaskResponse.OnFinalAnswer,
+                        final_answer=FinalAnswer(answer=result),
                     )
                 )
                 break
diff --git a/agent/src/og_agent/mock_agent.py b/agent/src/og_agent/mock_agent.py
index 1a06785..b986869 100644
--- a/agent/src/og_agent/mock_agent.py
+++ b/agent/src/og_agent/mock_agent.py
@@ -8,7 +8,7 @@
 import time
 import logging
 from .base_agent import BaseAgent, TypingState, TaskContext
-from og_proto.agent_server_pb2 import OnAgentAction, TaskRespond, OnAgentActionEnd, FinalRespond
+from og_proto.agent_server_pb2 import OnStepActionStart, TaskResponse, OnStepActionEnd, FinalAnswer
 from .tokenizer import tokenize

 logger = logging.getLogger(__name__)
@@ -74,14 +74,12 @@ async def arun(self, task, queue, context, max_iteration=5):
         """
         run the agent
         """
-        iteration = 0
         task_context = TaskContext(
             start_time=time.time(),
-            generated_token_count=10,
-            sent_token_count=10,
-            model_name="mock",
-            iteration_count=1,
-            model_respond_duration=1000,
+            output_token_count=10,
+            input_token_count=10,
+            llm_name="mock",
+            llm_respond_duration=1000,
         )
         try:
             while iteration < max_iteration:
diff --git a/agent/tests/codellama_agent_tests.py b/agent/tests/codellama_agent_tests.py
new file mode 100644
index 0000000..5602373
--- /dev/null
+++ b/agent/tests/codellama_agent_tests.py
@@ -0,0 +1,104 @@
+# vim:fenc=utf-8
+
+# SPDX-FileCopyrightText: 2023 imotai
+# SPDX-FileContributor: imotai
+#
+# SPDX-License-Identifier: Elastic-2.0
+
+""" """
+
+import json
+import logging
+import pytest
+from og_sdk.agent_sdk import AgentSDK
+from og_agent.codellama_agent import CodellamaAgent
+from og_proto.agent_server_pb2 import ProcessOptions, TaskResponse
+import asyncio
+import pytest_asyncio
+
+api_base = "127.0.0.1:9528"
+api_key = "ZCeI9cYtOCyLISoi488BgZHeBkHWuFUH"
+logger = logging.getLogger(__name__)
+
+
+class PayloadStream:
+
+    def __init__(self, payload):
+        self.payload = payload
+
+    def __aiter__(self):
+        # create an iterator of the input keys
+        self.iter_keys = iter(self.payload)
+        return self
+
+    async def __anext__(self):
+        try:
+            k = {"content": next(self.iter_keys)}
+            output = "data: %s\n" % json.dumps(k)
+            return output
+        except StopIteration:
+            # raise StopAsyncIteration at the end of the iterator
+            raise StopAsyncIteration
+
+
+class MockContext:
+
+    def done(self):
+        return False
+
+
+class CodellamaMockClient:
+
+    def __init__(self, payload):
+        self.payload = payload
+
+    async def prompt(self, question, chat_history=[]):
+        async for line in PayloadStream(self.payload):
+            yield line
+
+
+@pytest_asyncio.fixture
+async def agent_sdk():
+    sdk = AgentSDK(api_base, api_key)
+    sdk.connect()
+    yield sdk
+    await sdk.close()
+
+
+@pytest.mark.asyncio
+async def test_codellama_agent_smoke_test(agent_sdk):
+    sentence = {
+        "explanation": "Hello, how can I help you?",
+        "action": "no_action",
+        "action_input": "",
+        "saved_filenames": [],
+        "language": "en",
+        "is_final_answer": True,
+    }
+    client = CodellamaMockClient(json.dumps(sentence))
+    agent = CodellamaAgent(client, agent_sdk)
+    task_opt = ProcessOptions(
+        streaming=True,
+        llm_name="codellama",
+        input_token_limit=100000,
+        output_token_limit=100000,
+        timeout=5,
+    )
+    queue = asyncio.Queue()
+    await agent.arun("hello", queue, MockContext(), task_opt)
+    responses = []
+    while True:
+        try:
+            response = await queue.get()
+            if not response:
+                break
+            responses.append(response)
+        except asyncio.QueueEmpty:
+            break
+    logger.info(responses)
+    assert len(responses) == len(sentence["explanation"]) + 1, "bad response count"
+    assert (
+        responses[-1].response_type == TaskResponse.OnFinalAnswer
+    ), "bad response type"
+    assert responses[-1].state.input_token_count == 1
+    assert responses[-1].state.output_token_count == 43
diff --git a/agent/tests/openai_agent_tests.py b/agent/tests/openai_agent_tests.py
index f37a730..9fbc237 100644
--- a/agent/tests/openai_agent_tests.py
+++ b/agent/tests/openai_agent_tests.py
@@ -14,6 +14,7 @@
 from og_proto.agent_server_pb2 import ProcessOptions, TaskResponse
 from openai.openai_object import OpenAIObject
 import asyncio
+import pytest_asyncio

 api_base = "127.0.0.1:9528"
 api_key = "ZCeI9cYtOCyLISoi488BgZHeBkHWuFUH"
@@ -53,17 +54,15 @@ def done(self):
         return False


-@pytest.fixture
-def agent_sdk():
+@pytest_asyncio.fixture
+async def agent_sdk():
     sdk = AgentSDK(api_base, api_key)
     sdk.connect()
     yield sdk
-    sdk.close()
-
+    await sdk.close()

 @pytest.mark.asyncio
 async def test_openai_agent_smoke_test(mocker, agent_sdk):
-    await agent_sdk.add_kernel(api_key, "127.0.0.1:9527")
     sentence = "Hello, how can I help you?"
     stream = PayloadStream(sentence)
     with mocker.patch(
diff --git a/sdk/src/og_sdk/agent_sdk.py b/sdk/src/og_sdk/agent_sdk.py
index 2b48d98..ac76f25 100644
--- a/sdk/src/og_sdk/agent_sdk.py
+++ b/sdk/src/og_sdk/agent_sdk.py
@@ -65,27 +65,6 @@ def __init__(self, endpoint, api_key):
     def connect(self):
         self.connect_sync()

-    def assemble(self, name, code, language, desc="", saved_filenames=[]):
-        request = agent_server_pb2.AssembleAppRequest(
-            name=name,
-            language=language,
-            code=code,
-            saved_filenames=saved_filenames,
-            desc=desc,
-        )
-        response = self.stub.assemble(request, metadata=self.metadata)
-        return response
-
-    def run(self, name):
-        # TODO support input files
-        request = agent_server_pb2.RunAppRequest(name=name)
-        for respond in self.stub.run(request, metadata=self.metadata):
-            yield respond
-
-    def query_apps(self):
-        request = agent_server_pb2.QueryAppsRequest()
-        return self.stub.query_apps(request, metadata=self.metadata)
-
     def add_kernel(self, key, endpoint):
         """
         add kernel instance to the agent and only admin can call this method
@@ -154,6 +133,11 @@ async def prompt(self, prompt, api_key, files=[]):
         async for respond in self.stub.send_task(request, metadata=metadata):
             yield respond

+    async def close(self):
+        if self.channel:
+            await self.channel.close()
+            self.channel = None
+

 class AgentSDK(AgentBaseSDK):

@@ -171,28 +155,6 @@ async def ping(self):
         response = await self.stub.ping(request, metadata=self.metadata)
         return response

-    async def assemble(self, name, code, language, desc="", saved_filenames=[]):
-        request = agent_server_pb2.AssembleAppRequest(
-            name=name,
-            language=language,
-            code=code,
-            saved_filenames=saved_filenames,
-            desc=desc,
-        )
-        response = await self.stub.assemble(request, metadata=self.metadata)
-        return response
-
-    async def run(self, name):
-        # TODO support input files
-        request = agent_server_pb2.RunAppRequest(name=name)
-        async for respond in self.stub.run(request, metadata=self.metadata):
-            yield respond
-
-    async def query_apps(self):
-        """query all apps"""
-        request = agent_server_pb2.QueryAppsRequest()
-        return await self.stub.query_apps(request, metadata=self.metadata)
-
     async def add_kernel(self, key, endpoint):
         """
         add kernel instance to the agent and only admin can call this method
@@ -229,7 +191,7 @@ async def upload_file(self, filepath, filename):
         # TODO limit the file size
         return await self.upload_binary(generate_async_chunk(filepath, filename))

-    def close(self):
+    async def close(self):
         if self.channel:
-            self.channel.close()
+            await self.channel.close()
             self.channel = None