diff --git a/agent/src/og_agent/agent_api_server.py b/agent/src/og_agent/agent_api_server.py index 011f050..a5ac299 100644 --- a/agent/src/og_agent/agent_api_server.py +++ b/agent/src/og_agent/agent_api_server.py @@ -47,23 +47,21 @@ class StepResponseType(str, Enum): class ContextState(BaseModel): - generated_token_count: int - iteration_count: int - llm_model_name: str + output_token_count: int + llm_name: str total_duration: int - sent_token_count: int - llm_model_response_duration: int + input_token_count: int + llm_response_duration: int context_id: str | None = None @classmethod def new_from(cls, state): return cls( - generated_token_count=state.generated_token_count, - iteration_count=state.iteration_count, - llm_model_name=state.model_name, + output_token_count=state.output_token_count, + llm_name=state.llm_name, total_duration=state.total_duration, - sent_token_count=state.sent_token_count, - llm_model_response_duration=state.model_respond_duration, + input_token_count=state.input_token_count, + llm_response_duration=state.llm_response_duration, ) @@ -73,7 +71,7 @@ class StepActionEnd(BaseModel): has_error: bool @classmethod - def new_from(cls, step_action_end: agent_server_pb2.OnAgentActionEnd): + def new_from(cls, step_action_end: agent_server_pb2.OnStepActionEnd): return cls( output=step_action_end.output, output_files=step_action_end.output_files, @@ -85,7 +83,7 @@ class FinalAnswer(BaseModel): answer: str @classmethod - def new_from(cls, final_answer: agent_server_pb2.FinalRespond): + def new_from(cls, final_answer: agent_server_pb2.FinalAnswer): return cls(answer=final_answer.answer) @@ -94,7 +92,7 @@ class StepActionStart(BaseModel): tool: str @classmethod - def new_from(cls, step_action_start: agent_server_pb2.OnAgentAction): + def new_from(cls, step_action_start: agent_server_pb2.OnStepActionStart): return cls(input=step_action_start.input, tool=step_action_start.tool) @@ -109,49 +107,57 @@ class StepResponse(BaseModel): final_answer: FinalAnswer 
| None = None @classmethod - def new_from(cls, response: agent_server_pb2.TaskRespond): - if response.respond_type == agent_server_pb2.TaskRespond.OnAgentActionType: + def new_from(cls, response: agent_server_pb2.TaskResponse): + if response.response_type == agent_server_pb2.TaskResponse.OnStepActionStart: return cls( step_type=StepResponseType.OnStepActionStart, step_state=ContextState.new_from(response.state), - step_action_start=StepActionStart.new_from(response.on_agent_action), + step_action_start=StepActionStart.new_from( + response.on_step_action_start + ), ) - elif response.respond_type == agent_server_pb2.TaskRespond.OnAgentCodeTyping: + elif response.response_type == agent_server_pb2.TaskResponse.OnModelTypeCode: return cls( step_type=StepResponseType.OnStepCodeTyping, step_state=ContextState.new_from(response.state), typing_content=response.typing_content, ) - elif response.respond_type == agent_server_pb2.TaskRespond.OnAgentTextTyping: + elif response.response_type == agent_server_pb2.TaskResponse.OnModelTypeText: return cls( step_type=StepResponseType.OnStepTextTyping, step_state=ContextState.new_from(response.state), typing_content=response.typing_content, ) - elif response.respond_type == agent_server_pb2.TaskRespond.OnAgentActionStdout: + elif ( + response.response_type + == agent_server_pb2.TaskResponse.OnStepActionStreamStdout + ): return cls( step_type=StepResponseType.OnStepActionStdout, step_state=ContextState.new_from(response.state), step_action_stdout=response.console_stdout, ) - elif response.respond_type == agent_server_pb2.TaskRespond.OnAgentActionStderr: + elif ( + response.response_type + == agent_server_pb2.TaskResponse.OnStepActionStreamStderr + ): return cls( step_type=StepResponseType.OnStepActionStderr, step_state=ContextState.new_from(response.state), step_action_stderr=response.console_stderr, ) - elif response.respond_type == agent_server_pb2.TaskRespond.OnAgentActionEndType: + elif response.response_type == 
agent_server_pb2.TaskResponse.OnStepActionEnd: return cls( step_type=StepResponseType.OnStepActionEnd, step_state=ContextState.new_from(response.state), - step_action_end=StepActionEnd.new_from(response.on_agent_action_end), + step_action_end=StepActionEnd.new_from(response.on_step_action_end), ) - elif response.respond_type == agent_server_pb2.TaskRespond.OnFinalAnswerType: + elif response.response_type == agent_server_pb2.TaskResponse.OnFinalAnswer: return cls( step_type=StepResponseType.OnFinalAnswer, step_state=ContextState.new_from(response.state), - final_answer=FinalAnswer.new_from(response.final_respond), + final_answer=FinalAnswer.new_from(response.final_answer), ) @@ -164,11 +170,9 @@ class TaskRequest(BaseModel): async def run_task(task: TaskRequest, key): - index = 0 async for respond in agent_sdk.prompt(task.prompt, key, files=task.input_files): response = StepResponse.new_from(respond).model_dump(exclude_none=True) - yield "\n" + json.dumps(response) if index > 0 else json.dumps(response) - index += 1 + yield "data: %s\n" % json.dumps(response) @app.post("/process") @@ -181,7 +185,7 @@ async def process_task( response.status_code = status.HTTP_401_UNAUTHORIZED return response.status_code = status.HTTP_200_OK - response.media_type = "application/json" + response.media_type = "text/event-stream" agent_sdk.connect() return StreamingResponse(run_task(task, api_token)) diff --git a/agent/tests/agent_api_tests.py b/agent/tests/agent_api_tests.py index dbcf790..0c166f2 100644 --- a/agent/tests/agent_api_tests.py +++ b/agent/tests/agent_api_tests.py @@ -16,7 +16,6 @@ from pathlib import Path from og_sdk.agent_sdk import AgentProxySDK from og_sdk.utils import random_str -from og_proto.agent_server_pb2 import TaskRespond from og_agent import agent_api_server logger = logging.getLogger(__name__) @@ -40,7 +39,8 @@ async def test_helloworld_test(agent_sdk): ) responds = [] async for respond in agent_api_server.run_task(request, api_key): - 
responds.append(json.loads(respond)) + json_data = respond[6:] + responds.append(json.loads(respond[6:])) logger.debug(f"{responds}") assert len(responds) > 0, "no responds for the prompt" assert ( @@ -67,7 +67,7 @@ async def test_run_code_test(agent_sdk): ) responds = [] async for respond in agent_api_server.run_task(request, api_key): - responds.append(json.loads(respond)) + responds.append(json.loads(respond[6:])) logger.debug(f"{responds}") assert len(responds) > 0, "no responds for the prompt" assert ( diff --git a/start_sandbox.sh b/start_sandbox.sh index ff9d65a..63ec67c 100644 --- a/start_sandbox.sh +++ b/start_sandbox.sh @@ -41,6 +41,7 @@ cases_path=${WORKDIR}/sdk/tests/mock_messages.json EOF og_agent_rpc_server > agent_rpc.log 2>&1 & +og_agent_http_server > agent_http.log 2>&1 & sleep 2 echo "add a kernel" og_agent_setup --kernel_endpoint=127.0.0.1:9527 --kernel_api_key=${KERNEL_RPC_KEY} --agent_endpoint=127.0.0.1:9528 --admin_key=${AGENT_RPC_KEY}