feat: move the prompt to role module
imotai committed Nov 12, 2023
1 parent 2c09d1a commit 819fccf
Showing 9 changed files with 416 additions and 46 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -159,3 +159,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

.cosine
45 changes: 0 additions & 45 deletions agent/src/og_agent/llama_agent.py
@@ -36,51 +36,6 @@ def _output_exception(self):
"Sorry, the LLM did return nothing, You can use a better performance model"
)


def _format_output(self, json_response):
"""
format the response and send it to the user
"""
answer = json_response["explanation"]
if json_response["action"] == "no_action":
return answer
elif json_response["action"] == "show_sample_code":
return ""
else:
code = json_response.get("code", None)
answer_code = """%s
```%s
%s
```
""" % (
answer,
json_response.get("language", "python"),
code if code else "",
)
return answer_code

async def handle_show_sample_code(
self, json_response, queue, context, task_context
):
code = json_response["code"]
explanation = json_response["explanation"]
saved_filenames = json_response.get("saved_filenames", [])
tool_input = json.dumps({
"code": code,
"explanation": explanation,
"saved_filenames": saved_filenames,
"language": json_response.get("language", "text"),
})
await queue.put(
TaskResponse(
state=task_context.to_context_state_proto(),
response_type=TaskResponse.OnStepActionStart,
on_step_action_start=OnStepActionStart(
input=tool_input, tool="show_sample_code"
),
)
)

async def handle_bash_code(
self, json_response, queue, context, task_context, task_opt
):
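For context, the removed _format_output composed a markdown reply: the explanation followed by a fenced code block. A minimal sketch of the old behavior on illustrative input (the action value is a placeholder; any action other than no_action and show_sample_code took the fenced-block branch):

json_response = {
    "explanation": "prints a greeting",
    "action": "execute",  # illustrative action name
    "language": "python",
    "code": "print('hello')",
}
# _format_output(json_response) returned the explanation plus a fenced block:
#   prints a greeting
#   ```python
#   print('hello')
#   ```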
2 changes: 1 addition & 1 deletion agent/src/og_agent/llama_client.py
@@ -20,7 +20,7 @@ def __init__(self, endpoint, key, grammar):
super().__init__(endpoint + "/v1/chat/completions", key)
self.grammar = grammar

-    async def chat(self, messages, model, temperature=0, max_tokens=1024, stop=[]):
+    async def chat(self, messages, model, temperature=0, max_tokens=1024, stop=['\n']):
data = {
"messages": messages,
"temperature": temperature,
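The behavioral effect of this one-line change: generation now stops at the first newline by default. A minimal call sketch, assuming the client class in llama_client.py is named LlamaClient and that chat() returns the parsed completion when awaited; the endpoint, key, grammar, and model values are placeholders:

import asyncio
from og_agent.llama_client import LlamaClient  # assumed class name

async def main():
    client = LlamaClient("http://127.0.0.1:8080", "test-key", grammar="")
    messages = [{"role": "user", "content": "say hello"}]
    # pass stop=[] explicitly to restore the old no-stop-sequence behavior
    response = await client.chat(messages, model="llama", stop=[])
    print(response)

asyncio.run(main())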
288 changes: 288 additions & 0 deletions agent/tests/llama_agent_tests.py
@@ -0,0 +1,288 @@
# vim:fenc=utf-8

# SPDX-FileCopyrightText: 2023 imotai <[email protected]>
# SPDX-FileContributor: imotai
#
# SPDX-License-Identifier: Elastic-2.0

""" """

import json
import logging
import pytest
from og_sdk.kernel_sdk import KernelSDK
from og_agent import openai_agent
from og_proto.agent_server_pb2 import ProcessOptions, TaskResponse, ProcessTaskRequest
from openai.openai_object import OpenAIObject
import asyncio
import pytest_asyncio

api_base = "127.0.0.1:9528"
api_key = "ZCeI9cYtOCyLISoi488BgZHeBkHWuFUH"

logger = logging.getLogger(__name__)

class PayloadStream:

def __init__(self, payload):
self.payload = payload

def __aiter__(self):
# iterate over the payload one character at a time
self.iter_keys = iter(self.payload)
return self

async def __anext__(self):
try:
k = next(self.iter_keys)
obj = OpenAIObject()
delta = OpenAIObject()
content = OpenAIObject()
content.content = k
delta.delta = content
obj.choices = [delta]
return obj
except StopIteration:
# raise StopAsyncIteration when the payload is exhausted
raise StopAsyncIteration
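
# A minimal usage sketch (the collect helper is illustrative): PayloadStream
# mimics openai.ChatCompletion.acreate(stream=True), emitting one chunk per
# character, with the character at chunk.choices[0].delta.content:
#
#   async def collect(payload):
#       return [c.choices[0].delta.content async for c in PayloadStream(payload)]
#
#   asyncio.run(collect("hi"))  # == ["h", "i"]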


class FunctionCallPayloadStream:

def __init__(self, name, arguments):
self.name = name
self.arguments = arguments

def __aiter__(self):
# iterate over the arguments string one character at a time
self.iter_keys = iter(self.arguments)
return self

async def __anext__(self):
try:
k = next(self.iter_keys)
obj = OpenAIObject()
delta = OpenAIObject()
function_para = OpenAIObject()
function_para.name = self.name
function_para.arguments = k
function_call = OpenAIObject()
function_call.function_call = function_para
delta.delta = function_call
obj.choices = [delta]
return obj
except StopIteration:
# raise StopAsyncIteration when the arguments are exhausted
raise StopAsyncIteration


class MockContext:

def done(self):
return False


class MultiCallMock:

def __init__(self, responses):
self.responses = responses
self.index = 0

def call(self, *args, **kwargs):
if self.index >= len(self.responses):
raise Exception("no more response")
self.index += 1
logger.debug("call index %d", self.index)
return self.responses[self.index - 1]
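
# A minimal usage sketch: MultiCallMock serves its responses in order, one per
# call, and raises once they are exhausted, so it can stand in for
# openai.ChatCompletion.acreate across consecutive agent turns:
#
#   mock = MultiCallMock([stream1, stream2])
#   mock.call()  # -> stream1
#   mock.call()  # -> stream2
#   mock.call()  # raises Exception("no more response")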


@pytest.fixture
def kernel_sdk():
endpoint = (
"localhost:9527" # Replace with the actual endpoint of your test gRPC server
)
return KernelSDK(endpoint, "ZCeI9cYtOCyLISoi488BgZHeBkHWuFUH")


@pytest.mark.asyncio
async def test_openai_agent_call_execute_bash_code(mocker, kernel_sdk):
kernel_sdk.connect()
arguments = {
"explanation": "the hello world in bash",
"code": "echo 'hello world'",
"saved_filenames": [],
"language": "bash",
}
stream1 = FunctionCallPayloadStream("execute", json.dumps(arguments))
sentence = "The output 'hello world' is the result"
stream2 = PayloadStream(sentence)
call_mock = MultiCallMock([stream1, stream2])
with mocker.patch(
"og_agent.openai_agent.openai.ChatCompletion.acreate",
side_effect=call_mock.call,
) as mock_openai:
agent = openai_agent.OpenaiAgent("gpt4", kernel_sdk, is_azure=False)
queue = asyncio.Queue()
task_opt = ProcessOptions(
streaming=True,
llm_name="gpt4",
input_token_limit=100000,
output_token_limit=100000,
timeout=5,
)
request = ProcessTaskRequest(
input_files=[],
task="write a hello world in bash",
context_id="",
options=task_opt,
)
await agent.arun(request, queue, MockContext(), task_opt)
responses = []
while True:
try:
response = await queue.get()
if not response:
break
responses.append(response)
except asyncio.QueueEmpty:
break
logger.info(responses)
console_output = list(
filter(
lambda x: x.response_type == TaskResponse.OnStepActionStreamStdout,
responses,
)
)
assert len(console_output) == 1, "bad console output count"
assert console_output[0].console_stdout == "hello world\n", "bad console output"

@pytest.mark.asyncio
async def test_openai_agent_direct_message(mocker, kernel_sdk):
kernel_sdk.connect()
arguments = {
"message": "hello world",
}
stream1 = FunctionCallPayloadStream("direct_message", json.dumps(arguments))
call_mock = MultiCallMock([stream1])
with mocker.patch(
"og_agent.openai_agent.openai.ChatCompletion.acreate",
side_effect=call_mock.call,
) as mock_openai:
agent = openai_agent.OpenaiAgent("gpt4", kernel_sdk, is_azure=False)
queue = asyncio.Queue()
task_opt = ProcessOptions(
streaming=False,
llm_name="gpt4",
input_token_limit=100000,
output_token_limit=100000,
timeout=5,
)
request = ProcessTaskRequest(
input_files=[],
task="say hello world",
context_id="",
options=task_opt,
)
await agent.arun(request, queue, MockContext(), task_opt)
responses = []
while True:
try:
response = await queue.get()
if not response:
break
responses.append(response)
except asyncio.QueueEmpty:
break
logger.info(responses)
assert responses[0].final_answer.answer == "hello world"


@pytest.mark.asyncio
async def test_openai_agent_call_execute_python_code(mocker, kernel_sdk):
kernel_sdk.connect()
arguments = {
"explanation": "the hello world in python",
"code": "print('hello world')",
"language": "python",
"saved_filenames": [],
}
stream1 = FunctionCallPayloadStream("execute", json.dumps(arguments))
sentence = "The output 'hello world' is the result"
stream2 = PayloadStream(sentence)
call_mock = MultiCallMock([stream1, stream2])
with mocker.patch(
"og_agent.openai_agent.openai.ChatCompletion.acreate",
side_effect=call_mock.call,
) as mock_openai:
agent = openai_agent.OpenaiAgent("gpt4", kernel_sdk, is_azure=False)
queue = asyncio.Queue()
task_opt = ProcessOptions(
streaming=True,
llm_name="gpt4",
input_token_limit=100000,
output_token_limit=100000,
timeout=5,
)
request = ProcessTaskRequest(
input_files=[],
task="write a hello world in python",
context_id="",
options=task_opt,
)
await agent.arun(request, queue, MockContext(), task_opt)
responses = []
while True:
try:
response = await queue.get()
if not response:
break
responses.append(response)
except asyncio.QueueEmpty:
break
logger.info(responses)
console_output = list(
filter(
lambda x: x.response_type == TaskResponse.OnStepActionStreamStdout,
responses,
)
)
assert len(console_output) == 1, "bad console output count"
assert console_output[0].console_stdout == "hello world\n", "bad console output"


@pytest.mark.asyncio
async def test_openai_agent_smoke_test(mocker, kernel_sdk):
sentence = "Hello, how can I help you?"
stream = PayloadStream(sentence)
with mocker.patch(
"og_agent.openai_agent.openai.ChatCompletion.acreate", return_value=stream
) as mock_openai:
agent = openai_agent.OpenaiAgent("gpt4", kernel_sdk, is_azure=False)
queue = asyncio.Queue()
task_opt = ProcessOptions(
streaming=True,
llm_name="gpt4",
input_token_limit=100000,
output_token_limit=100000,
timeout=5,
)
request = ProcessTaskRequest(
input_files=[], task="hello", context_id="", options=task_opt
)
await agent.arun(request, queue, MockContext(), task_opt)
responses = []
while True:
try:
response = await queue.get()
if not response:
break
responses.append(response)
except asyncio.QueueEmpty:
break
logger.info(responses)
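# one streamed delta per character of the sentence, plus the final answer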
assert len(responses) == len(sentence) + 1, "bad response count"
assert (
responses[-1].response_type == TaskResponse.OnFinalAnswer
), "bad response type"
assert responses[-1].state.input_token_count == 325
assert responses[-1].state.output_token_count == 8
21 changes: 21 additions & 0 deletions agent/tests/tokenizer_test.py
@@ -0,0 +1,21 @@
# vim:fenc=utf-8

# SPDX-FileCopyrightText: 2023 imotai <[email protected]>
# SPDX-FileContributor: imotai
#
# SPDX-License-Identifier: Elastic-2.0

"""
"""

import logging
import io
from og_agent.tokenizer import tokenize

logger = logging.getLogger(__name__)


def test_parse_explanation():
    arguments = """{"function_call":"execute", "arguments": {"explanation":"h"""
for token_state, token in tokenize(io.StringIO(arguments)):
logger.info(f"token_state: {token_state}, token: {token}")

1 change: 1 addition & 0 deletions roles/README.md
@@ -0,0 +1 @@
# the role module
