Skip to content

Commit

Permalink
fix: add start api server
Browse files Browse the repository at this point in the history
  • Loading branch information
imotai committed Oct 19, 2023
1 parent 2ba9fa3 commit 4513e71
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 31 deletions.
60 changes: 32 additions & 28 deletions agent/src/og_agent/agent_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,21 @@ class StepResponseType(str, Enum):


class ContextState(BaseModel):
generated_token_count: int
iteration_count: int
llm_model_name: str
output_token_count: int
llm_name: str
total_duration: int
sent_token_count: int
llm_model_response_duration: int
output_token_count: int
llm_response_duration: int
context_id: str | None = None

@classmethod
def new_from(cls, state):
return cls(
generated_token_count=state.generated_token_count,
iteration_count=state.iteration_count,
llm_model_name=state.model_name,
output_token_count=state.output_token_count,
llm_name=state.llm_name,
total_duration=state.total_duration,
sent_token_count=state.sent_token_count,
llm_model_response_duration=state.model_respond_duration,
input_token_count=state.input_token_count,
llm_response_duration=state.llm_response_duration,
)


Expand All @@ -73,7 +71,7 @@ class StepActionEnd(BaseModel):
has_error: bool

@classmethod
def new_from(cls, step_action_end: agent_server_pb2.OnAgentActionEnd):
def new_from(cls, step_action_end: agent_server_pb2.OnStepActionEnd):
return cls(
output=step_action_end.output,
output_files=step_action_end.output_files,
Expand All @@ -85,7 +83,7 @@ class FinalAnswer(BaseModel):
answer: str

@classmethod
def new_from(cls, final_answer: agent_server_pb2.FinalRespond):
def new_from(cls, final_answer: agent_server_pb2.FinalAnswer):
return cls(answer=final_answer.answer)


Expand All @@ -94,7 +92,7 @@ class StepActionStart(BaseModel):
tool: str

@classmethod
def new_from(cls, step_action_start: agent_server_pb2.OnAgentAction):
def new_from(cls, step_action_start: agent_server_pb2.OnStepActionStart):
return cls(input=step_action_start.input, tool=step_action_start.tool)


Expand All @@ -109,49 +107,57 @@ class StepResponse(BaseModel):
final_answer: FinalAnswer | None = None

@classmethod
def new_from(cls, response: agent_server_pb2.TaskRespond):
if response.respond_type == agent_server_pb2.TaskRespond.OnAgentActionType:
def new_from(cls, response: agent_server_pb2.TaskResponse):
if response.response_type == agent_server_pb2.TaskResponse.OnStepActionStart:
return cls(
step_type=StepResponseType.OnStepActionStart,
step_state=ContextState.new_from(response.state),
step_action_start=StepActionStart.new_from(response.on_agent_action),
step_action_start=StepActionStart.new_from(
response.on_step_action_start
),
)
elif response.respond_type == agent_server_pb2.TaskRespond.OnAgentCodeTyping:
elif response.response_type == agent_server_pb2.TaskResponse.OnModelTypeCode:
return cls(
step_type=StepResponseType.OnStepCodeTyping,
step_state=ContextState.new_from(response.state),
typing_content=response.typing_content,
)

elif response.respond_type == agent_server_pb2.TaskRespond.OnAgentTextTyping:
elif response.response_type == agent_server_pb2.TaskResponse.OnModelTypeText:
return cls(
step_type=StepResponseType.OnStepTextTyping,
step_state=ContextState.new_from(response.state),
typing_content=response.typing_content,
)
elif response.respond_type == agent_server_pb2.TaskRespond.OnAgentActionStdout:
elif (
response.response_type
== agent_server_pb2.TaskResponse.OnStepActionStreamStdout
):
return cls(
step_type=StepResponseType.OnStepActionStdout,
step_state=ContextState.new_from(response.state),
step_action_stdout=response.console_stdout,
)
elif response.respond_type == agent_server_pb2.TaskRespond.OnAgentActionStderr:
elif (
response.response_type
== agent_server_pb2.TaskResponse.OnStepActionStreamStderr
):
return cls(
step_type=StepResponseType.OnStepActionStderr,
step_state=ContextState.new_from(response.state),
step_action_stderr=response.console_stderr,
)
elif response.respond_type == agent_server_pb2.TaskRespond.OnAgentActionEndType:
elif response.response_type == agent_server_pb2.TaskResponse.OnStepActionEnd:
return cls(
step_type=StepResponseType.OnStepActionEnd,
step_state=ContextState.new_from(response.state),
step_action_end=StepActionEnd.new_from(response.on_agent_action_end),
step_action_end=StepActionEnd.new_from(response.on_step_action_end),
)
elif response.respond_type == agent_server_pb2.TaskRespond.OnFinalAnswerType:
elif response.response_type == agent_server_pb2.TaskResponse.OnFinalAnswer:
return cls(
step_type=StepResponseType.OnFinalAnswer,
step_state=ContextState.new_from(response.state),
final_answer=FinalAnswer.new_from(response.final_respond),
final_answer=FinalAnswer.new_from(response.final_answer),
)


Expand All @@ -164,11 +170,9 @@ class TaskRequest(BaseModel):


async def run_task(task: TaskRequest, key):
index = 0
async for respond in agent_sdk.prompt(task.prompt, key, files=task.input_files):
response = StepResponse.new_from(respond).model_dump(exclude_none=True)
yield "\n" + json.dumps(response) if index > 0 else json.dumps(response)
index += 1
yield "data: %s\n" % json.dumps(response)


@app.post("/process")
Expand All @@ -181,7 +185,7 @@ async def process_task(
response.status_code = status.HTTP_401_UNAUTHORIZED
return
response.status_code = status.HTTP_200_OK
response.media_type = "application/json"
response.media_type = "text/event-stream"
agent_sdk.connect()
return StreamingResponse(run_task(task, api_token))

Expand Down
6 changes: 3 additions & 3 deletions agent/tests/agent_api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from pathlib import Path
from og_sdk.agent_sdk import AgentProxySDK
from og_sdk.utils import random_str
from og_proto.agent_server_pb2 import TaskRespond
from og_agent import agent_api_server

logger = logging.getLogger(__name__)
Expand All @@ -40,7 +39,8 @@ async def test_helloworld_test(agent_sdk):
)
responds = []
async for respond in agent_api_server.run_task(request, api_key):
responds.append(json.loads(respond))
json_data = respond[6:]
responds.append(json.loads(respond[6:]))
logger.debug(f"{responds}")
assert len(responds) > 0, "no responds for the prompt"
assert (
Expand All @@ -67,7 +67,7 @@ async def test_run_code_test(agent_sdk):
)
responds = []
async for respond in agent_api_server.run_task(request, api_key):
responds.append(json.loads(respond))
responds.append(json.loads(respond[6:]))
logger.debug(f"{responds}")
assert len(responds) > 0, "no responds for the prompt"
assert (
Expand Down
1 change: 1 addition & 0 deletions start_sandbox.sh
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ cases_path=${WORKDIR}/sdk/tests/mock_messages.json
EOF
og_agent_rpc_server > agent_rpc.log 2>&1 &
og_agent_http_server > agent_http.log 2>&1 &
sleep 2
echo "add a kernel"
og_agent_setup --kernel_endpoint=127.0.0.1:9527 --kernel_api_key=${KERNEL_RPC_KEY} --agent_endpoint=127.0.0.1:9528 --admin_key=${AGENT_RPC_KEY}
Expand Down

0 comments on commit 4513e71

Please sign in to comment.