Skip to content

Commit

Permalink
fix: fix the mock environment bug
Browse files Browse the repository at this point in the history
  • Loading branch information
imotai committed Oct 19, 2023
1 parent 50e3c31 commit 2ba9fa3
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 130 deletions.
1 change: 0 additions & 1 deletion agent/src/og_agent/agent_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,6 @@ async def worker(task, agent, queue, context, task_opt):
task = asyncio.create_task(
worker(request.task, agent, queue, context, request.options)
)

while True:
try:
logger.debug("start wait the queue message")
Expand Down
42 changes: 22 additions & 20 deletions agent/src/og_agent/mock_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,17 @@ async def call_ai(self, prompt, queue, iteration, task_context):
message = self.messages.get(prompt)[iteration]
if message.get("explanation", None):
await queue.put(
TaskRespond(
state=task_context.to_task_state_proto(),
respond_type=TaskRespond.OnAgentTextTyping,
TaskResponse(
state=task_context.to_context_state_proto(),
response_type=TaskResponse.OnModelTypeText,
typing_content=message["explanation"],
)
)
if message.get("code", None):
await queue.put(
TaskRespond(
state=task_context.to_task_state_proto(),
respond_type=TaskRespond.OnAgentCodeTyping,
TaskResponse(
state=task_context.to_context_state_proto(),
response_type=TaskResponse.OnModelTypeCode,
typing_content=message["code"],
)
)
Expand All @@ -53,12 +53,13 @@ async def handle_call_function(
"code": code,
"explanation": explanation,
"saved_filenames": saved_filenames,
"language": "python",
})
await queue.put(
TaskRespond(
state=task_context.to_task_state_proto(),
respond_type=TaskRespond.OnAgentActionType,
on_agent_action=OnAgentAction(
TaskResponse(
state=task_context.to_context_state_proto(),
response_type=TaskResponse.OnStepActionStart,
on_step_action_start=OnStepActionStart(
input=tool_input, tool="execute_python_code"
),
)
Expand All @@ -70,7 +71,7 @@ async def handle_call_function(
await queue.put(respond)
return function_result

async def arun(self, task, queue, context, max_iteration=5):
async def arun(self, task, queue, context, task_opt):
"""
run the agent
"""
Expand All @@ -81,8 +82,9 @@ async def arun(self, task, queue, context, max_iteration=5):
llm_name="mock",
llm_respond_duration=1000,
)
iteration = 0
try:
while iteration < max_iteration:
while iteration <= 10:
message = await self.call_ai(task, queue, iteration, task_context)
iteration = iteration + 1
if message.get("code", None):
Expand All @@ -95,10 +97,10 @@ async def arun(self, task, queue, context, max_iteration=5):
message.get("saved_filenames", []),
)
await queue.put(
TaskRespond(
state=task_context.to_task_state_proto(),
respond_type=TaskRespond.OnAgentActionEndType,
on_agent_action_end=OnAgentActionEnd(
TaskResponse(
state=task_context.to_context_state_proto(),
response_type=TaskResponse.OnStepActionEnd,
on_step_action_end=OnStepActionEnd(
output="",
output_files=function_result.saved_filenames,
has_error=function_result.has_error,
Expand All @@ -107,10 +109,10 @@ async def arun(self, task, queue, context, max_iteration=5):
)
else:
await queue.put(
TaskRespond(
state=task_context.to_task_state_proto(),
respond_type=TaskRespond.OnFinalAnswerType,
final_respond=FinalRespond(answer=message["explanation"]),
TaskResponse(
state=task_context.to_context_state_proto(),
response_type=TaskResponse.OnFinalAnswer,
final_answer=FinalAnswer(answer=message["explanation"]),
)
)
break
Expand Down
1 change: 1 addition & 0 deletions agent/tests/openai_agent_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ async def agent_sdk():
yield sdk
await sdk.close()


@pytest.mark.asyncio
async def test_openai_agent_smoke_test(mocker, agent_sdk):
sentence = "Hello, how can I help you?"
Expand Down
12 changes: 6 additions & 6 deletions sdk/src/og_sdk/agent_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def prompt(self, prompt, files=[]):
"""
ask the ai with prompt and uploaded files
"""
request = agent_server_pb2.SendTaskRequest(task=prompt, input_files=files)
for respond in self.stub.send_task(request, metadata=self.metadata):
request = agent_server_pb2.ProcessTaskRequest(task=prompt, input_files=files)
for respond in self.stub.process_task(request, metadata=self.metadata):
yield respond


Expand Down Expand Up @@ -129,8 +129,8 @@ async def prompt(self, prompt, api_key, files=[]):
metadata = aio.Metadata(
("api_key", api_key),
)
request = agent_server_pb2.SendTaskRequest(task=prompt, input_files=files)
async for respond in self.stub.send_task(request, metadata=metadata):
request = agent_server_pb2.ProcessTaskRequest(task=prompt, input_files=files)
async for respond in self.stub.process_task(request, metadata=metadata):
yield respond

async def close(self):
Expand Down Expand Up @@ -167,8 +167,8 @@ async def prompt(self, prompt, files=[]):
"""
ask the ai with prompt and uploaded files
"""
request = agent_server_pb2.SendTaskRequest(task=prompt, input_files=files)
async for respond in self.stub.send_task(request, metadata=self.metadata):
request = agent_server_pb2.ProcessTaskRequest(task=prompt, input_files=files)
async for respond in self.stub.process_task(request, metadata=self.metadata):
yield respond

async def download_file(self, filename, parent_path):
Expand Down
116 changes: 13 additions & 103 deletions sdk/tests/agent_sdk_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,20 @@
from pathlib import Path
from og_sdk.agent_sdk import AgentSDK
from og_sdk.utils import random_str
from og_proto.agent_server_pb2 import TaskRespond
from og_proto.agent_server_pb2 import TaskResponse
import pytest_asyncio

logger = logging.getLogger(__name__)
api_base = "127.0.0.1:9528"
api_key = "ZCeI9cYtOCyLISoi488BgZHeBkHWuFUH"


@pytest.fixture
def agent_sdk():
@pytest_asyncio.fixture
async def agent_sdk():
sdk = AgentSDK(api_base, api_key)
sdk.connect()
yield sdk
sdk.close()
await sdk.close()


def test_connect_bad_endpoint():
Expand Down Expand Up @@ -78,9 +79,9 @@ async def test_prompt_smoke_test(agent_sdk):
responds.append(respond)
logger.debug(f"{responds}")
assert len(responds) > 0, "no responds for the prompt"
assert responds[len(responds) - 1].respond_type == TaskRespond.OnFinalAnswerType
assert responds[len(responds) - 1].response_type == TaskResponse.OnFinalAnswer
assert (
responds[len(responds) - 1].final_respond.answer
responds[len(responds) - 1].final_answer.answer
== "how can I help you today?"
)
except Exception as ex:
Expand All @@ -97,9 +98,9 @@ async def test_run_code_test(agent_sdk):
responds.append(respond)
logger.debug(f"{responds}")
assert len(responds) > 0, "no responds for the prompt"
assert responds[len(responds) - 1].respond_type == TaskRespond.OnFinalAnswerType
assert responds[len(responds) - 1].response_type == TaskResponse.OnFinalAnswer
assert (
responds[len(responds) - 1].final_respond.answer
responds[len(responds) - 1].final_answer.answer
== "this code prints 'hello world'"
)
except Exception as ex:
Expand All @@ -116,105 +117,14 @@ async def test_run_code_with_error(agent_sdk):
responds.append(respond)
logger.debug(f"{responds}")
assert len(responds) > 0, "no responds for the prompt"
assert (
responds[len(responds) - 3].respond_type == TaskRespond.OnAgentActionEndType
)
assert responds[len(responds) - 3].response_type == TaskResponse.OnStepActionEnd
assert responds[
len(responds) - 3
].on_agent_action_end.has_error, "bad has error result"
assert responds[len(responds) - 1].respond_type == TaskRespond.OnFinalAnswerType
].on_step_action_end.has_error, "bad has error result"
assert responds[len(responds) - 1].response_type == TaskResponse.OnFinalAnswer
assert (
responds[len(responds) - 1].final_respond.answer
responds[len(responds) - 1].final_answer.answer
== "this code prints 'hello world'"
)
except Exception as ex:
assert 0, str(ex)


@pytest.mark.asyncio
async def test_assemble_test(agent_sdk):
sdk = agent_sdk
await sdk.add_kernel(api_key, "127.0.0.1:9527")
try:
code = "print('hello')"
name = random_str(10)
response = await sdk.assemble(name, code, "python")
assert response.code == 0, "fail to assemble app"
apps = await sdk.query_apps()
app = list(filter(lambda x: x.name == name, apps.apps))
assert len(app) == 1, "fail to get the app with name " + name
responds = []
async for respond in sdk.run(name):
responds.append(respond)
assert len(responds) == 3, "bad responds for run application"
assert responds[0].respond_type == TaskRespond.OnAgentActionType
assert responds[1].respond_type == TaskRespond.OnAgentActionStdout
assert responds[2].respond_type == TaskRespond.OnAgentActionEndType
assert (
json.loads(responds[0].on_agent_action.input)["code"] == code
), "remote code !eq the local code"
assert responds[1].console_stdout.find("hello") >= 0, "bad output"
except Exception as ex:
assert 0, str(ex)


display_image_test_code = """import matplotlib.pyplot as plt
import numpy as np
# Step 2: Create the dataset
categories = ['Category 1', 'Category 2', 'Category 3']
group1_values = [10, 15, 12]
group2_values = [8, 11, 9]
# Step 3: Set the width and positions
bar_width = 0.35
index = np.arange(len(categories))
# Step 4: Create the figure and axis
fig, ax = plt.subplots()
# Step 5: Plot the bars
rects1 = ax.bar(index, group1_values, bar_width, label='Group 1')
rects2 = ax.bar(index + bar_width, group2_values, bar_width, label='Group 2')
# Step 6: Set labels, title, and legend
ax.set_xlabel('Categories')
ax.set_ylabel('Values')
ax.set_title('Grouped Bar Chart')
ax.set_xticks(index + bar_width / 2)
ax.set_xticklabels(categories)
ax.legend()
# Step 7: Show the chart
plt.show()"""


@pytest.mark.asyncio
async def test_assemble_image_test(agent_sdk):
sdk = agent_sdk
await sdk.add_kernel(api_key, "127.0.0.1:9527")
try:
name = random_str(10)
response = await sdk.assemble(name, display_image_test_code, "python")
assert response.code == 0, "fail to assemble app"
apps = await sdk.query_apps()
app = list(filter(lambda x: x.name == name, apps.apps))
assert len(app) == 1, "fail to get the app with name " + name
responds = []
async for respond in sdk.run(name):
responds.append(respond)
assert len(responds) == 3, "bad responds for run application"
assert responds[0].respond_type == TaskRespond.OnAgentActionType
assert responds[1].respond_type == TaskRespond.OnAgentActionStdout
assert responds[2].respond_type == TaskRespond.OnAgentActionEndType
assert (
json.loads(responds[0].on_agent_action.input)["code"]
== display_image_test_code
)
assert responds[1].console_stdout.find("png") > 0, "should output the files"
assert (
len(responds[2].on_agent_action_end.output_files) == 1
), "should output the files"

except Exception as ex:
assert 0, str(ex)

0 comments on commit 2ba9fa3

Please sign in to comment.