fix: add codellama test
imotai committed Oct 19, 2023
1 parent ee6a684 commit 50e3c31
Showing 5 changed files with 203 additions and 125 deletions.
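
For orientation, the diffs below apply one consistent renaming across the agent-server proto messages and the TaskContext bookkeeping fields. The mapping, collected here from the hunks that follow (the dict itself is just a reading aid, not code from the commit):

```python
# Renames applied throughout this commit, old name -> new name.
RENAMES = {
    "OnAgentAction": "OnStepActionStart",             # proto: a tool step begins
    "TaskRespond": "TaskResponse",                    # proto: top-level response message
    "OnAgentActionEnd": "OnStepActionEnd",            # proto: a tool step finishes
    "FinalRespond": "FinalAnswer",                    # proto: terminal answer
    "generated_token_count": "output_token_count",    # TaskContext counter
    "sent_token_count": "input_token_count",          # TaskContext counter
    "model_name": "llm_name",                         # TaskContext field
    "model_respond_duration": "llm_respond_duration", # TaskContext field
}
```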
151 changes: 83 additions & 68 deletions agent/src/og_agent/codellama_agent.py
@@ -11,7 +11,7 @@
 import io
 import time
 from .codellama_client import CodellamaClient
-from og_proto.agent_server_pb2 import OnAgentAction, TaskRespond, OnAgentActionEnd, FinalRespond
+from og_proto.agent_server_pb2 import OnStepActionStart, TaskResponse, OnStepActionEnd, FinalAnswer
 from .base_agent import BaseAgent, TypingState, TaskContext
 from .tokenizer import tokenize
 import tiktoken
@@ -84,14 +84,14 @@ async def handle_bash_code(self, json_response, queue, context, task_context):
             "code": commands,
             "explanation": explanation,
             "saved_filenames": saved_filenames,
-            "language": json_response.get("language", "text"),
+            "language": json_response.get("language"),
         })
         await queue.put(
             TaskRespond(
                 state=task_context.to_task_state_proto(),
                 respond_type=TaskRespond.OnAgentActionType,
                 on_agent_action=OnAgentAction(
-                    input=tool_input, tool="execute_python_code"
+                    input=tool_input, tool="execute_bash_code"
                 ),
             )
         )
@@ -105,23 +105,24 @@ async def handle_bash_code(self, json_response, queue, context, task_context):
                 await queue.put(respond)
         return function_result
 
-    async def handle_function(self, json_response, queue, context, task_context):
+    async def handle_function(
+        self, json_response, queue, context, task_context, task_opt
+    ):
         code = json_response["action_input"]
         explanation = json_response["explanation"]
         saved_filenames = json_response.get("saved_filenames", [])
         tool_input = json.dumps({
             "code": code,
             "explanation": explanation,
             "saved_filenames": saved_filenames,
-            "language": json_response.get("language", "text"),
+            "language": json_response.get("language"),
         })
 
         await queue.put(
-            TaskRespond(
-                state=task_context.to_task_state_proto(),
-                respond_type=TaskRespond.OnAgentActionType,
-                on_agent_action=OnAgentAction(
-                    input=tool_input, tool="execute_python_code"
+            TaskResponse(
+                state=task_context.to_context_state_proto(),
+                response_type=TaskResponse.OnStepActionStart,
+                on_agent_action_start=OnStepActionStart(
+                    input=tool_input, tool=json_response["action"]
                 ),
             )
         )
@@ -131,7 +132,7 @@ async def handle_function(self, json_response, queue, context, task_context):
                 logger.debug("the client has cancelled the request")
                 break
             function_result = result
-            if respond:
+            if respond and task_opt.streaming:
                 await queue.put(respond)
         return function_result
 
@@ -160,14 +161,15 @@ def _get_argument_new_typing(self, message):
         return (state, explanation_str, action_input_str)
 
     async def call_codellama(
-        self, question, chat_history, queue, context, task_context
+        self, question, chat_history, queue, context, task_context, task_opt
     ):
         """
         call codellama api
         """
         start_time = time.time()
         num_tokens = len(encoding.encode(question)) + len(encoding.encode(chat_history))
-        task_context.sent_token_count += num_tokens
+        task_context.input_token_count += num_tokens
+        output_token_count = task_context.output_token_count
         state = None
         message = ""
         text_content = ""
@@ -179,14 +181,11 @@ async def call_codellama(
                 logger.debug("the client has cancelled the request")
                 break
             respond = json.loads(line[6:])
-            task_context.generated_token_count += len(
-                encoding.encode(respond["content"])
-            )
-            task_context.model_respond_duration += int(
-                (time.time() - start_time) * 1000
-            )
+            task_context.llm_response_duration += int((time.time() - start_time) * 1000)
             start_time = time.time()
             message += respond["content"]
+            response_token_count = len(encoding.encode(message))
+            task_context.output_token_count = output_token_count + response_token_count
             logger.debug(f" message {message}")
             (
                 state,
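
The hunk above changes the token accounting: instead of summing per-chunk counts, call_codellama now re-encodes the whole accumulated message on every chunk and rebases on the count captured before the stream started. A standalone sketch of that pattern (assuming the module-level `encoding` is a tiktoken encoding; the exact encoding name here is a guess):

```python
import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")  # assumption: matches the module's encoding

def output_count(base_count: int, accumulated_message: str) -> int:
    """Recompute the running output-token count from the full message so far."""
    return base_count + len(encoding.encode(accumulated_message))

# Mirroring the loop above: each streamed chunk re-counts from scratch,
# so partial-token artifacts at chunk boundaries never accumulate as error.
message = ""
for chunk in ("pri", "nt('hel", "lo')"):
    message += chunk
    running = output_count(0, message)
```

Re-encoding the full message on every chunk costs O(n) per update, but it avoids the drift that comes from tokenizing chunk fragments independently.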
@@ -196,79 +195,92 @@ async def arun(self, question, queue, context, max_iteration=5):
             if explanation_str and text_content != explanation_str:
                 typed_chars = explanation_str[len(text_content) :]
                 text_content = explanation_str
-                await queue.put(
-                    TaskRespond(
-                        state=task_context.to_task_state_proto(),
-                        respond_type=TaskRespond.OnAgentTextTyping,
-                        typing_content=typed_chars,
+                if task_opt.streaming:
+                    await queue.put(
+                        TaskResponse(
+                            state=task_context.to_context_state_proto(),
+                            response_type=TaskResponse.OnModelTypeText,
+                            typing_content=typed_chars,
+                        )
                     )
-                )
             if action_input_str and code_content != action_input_str:
                 typed_chars = action_input_str[len(code_content) :]
                 code_content = action_input_str
                 await queue.put(
-                    TaskRespond(
-                        state=task_context.to_task_state_proto(),
-                        respond_type=TaskRespond.OnAgentCodeTyping,
+                    TaskResponse(
+                        state=task_context.to_context_state_proto(),
+                        response_type=TaskResponse.OnModelTypeCode,
                         typing_content=typed_chars,
                     )
                 )
             logger.debug(
                 f"argument explanation:{explanation_str} code:{action_input_str}"
             )
-            if respond["stop"]:
+            if respond.get("stop", ""):
                 state = respond
 
         return (message, state)
 
-    async def arun(self, question, queue, context, max_iteration=5):
+    async def arun(self, question, queue, context, task_opt):
         """
         run the agent
         """
         history = []
         current_question = question
         task_context = TaskContext(
             start_time=time.time(),
-            generated_token_count=0,
-            sent_token_count=0,
-            model_name="codellama",
-            iteration_count=0,
-            model_respond_duration=0,
+            output_token_count=0,
+            input_token_count=0,
+            llm_name="codellama",
+            llm_respond_duration=0,
         )
         iteration = 0
         try:
-            while iteration < max_iteration:
-                if context.done():
-                    logger.debug("the client has cancelled the request")
+            while not context.done():
+                if task_context.input_token_count >= task_opt.input_token_limit:
+                    await queue.put(
+                        TaskResponse(
+                            state=task_context.to_context_state_proto(),
+                            response_type=TaskResponse.OnInputTokenLimitExceed,
+                            error_msg="input token limit reached",
+                        )
+                    )
                     break
+                if task_context.output_token_count >= task_opt.output_token_limit:
+                    await queue.put(
+                        TaskResponse(
+                            state=task_context.to_context_state_proto(),
+                            response_type=TaskResponse.OnOutputTokenLimitExceed,
+                            error_msg="output token limit reached",
+                        )
+                    )
+                    break
                 iteration += 1
                 task_context.iteration_count = iteration
                 chat_history = "\n".join(history)
                 (message, state) = await self.call_codellama(
-                    current_question, chat_history, queue, context, task_context
+                    current_question,
+                    chat_history,
+                    queue,
+                    context,
+                    task_context,
+                    task_opt,
                 )
                 try:
                     json_response = json.loads(message)
                     if not json_response:
                         await queue.put(
-                            TaskRespond(
-                                state=task_context.to_task_state_proto(),
-                                respond_type=TaskRespond.OnFinalAnswerType,
-                                final_respond=FinalRespond(
-                                    answer=self._output_exception()
-                                ),
+                            TaskResponse(
+                                state=task_context.to_context_state_proto(),
+                                response_type=TaskResponse.OnModelOutputError,
+                                error_msg=self._output_exception(),
                             )
                         )
                         break
                 except Exception as ex:
                     logger.exception(f"fail to load message the message is {message}")
                     await queue.put(
-                        TaskRespond(
-                            state=task_context.to_task_state_proto(),
-                            respond_type=TaskRespond.OnFinalAnswerType,
-                            final_respond=FinalRespond(
-                                answer="The model made an invalid respone"
-                            ),
+                        TaskResponse(
+                            state=task_context.to_context_state_proto(),
+                            response_type=TaskResponse.OnModelOutputError,
+                            error_msg=str(ex),
                         )
                     )
                     break
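
The streamed typing in the hunk above sends only the suffix that has not been sent yet, so the client receives each character exactly once. The pattern in isolation (illustrative names, not code from the commit):

```python
def typing_delta(previous: str, current: str) -> str:
    """Return the not-yet-sent suffix, as the typing loop above computes it."""
    return current[len(previous):]

sent = ""
for parsed in ("Hel", "Hello, wo", "Hello, world"):
    print(typing_delta(sent, parsed), end="")
    sent = parsed
# prints "Hello, world" exactly once, piecewise
```

This relies on the parser only ever appending to the previously parsed string; if a re-parse ever shortened it, the slice would yield an empty string rather than a correction.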
@@ -283,15 +295,18 @@ async def arun(self, question, queue, context, max_iteration=5):
                     "execute_bash_code": self.handle_bash_code,
                 }
                 function_result = await tools_mapping[json_response["action"]](
-                    json_response, queue, context, task_context
+                    json_response, queue, context, task_context, task_opt
                 )
                 logger.debug(f"the function result {function_result}")
                 await queue.put(
-                    TaskRespond(
-                        state=task_context.to_task_state_proto(),
-                        respond_type=TaskRespond.OnAgentActionEndType,
+                    TaskResponse(
+                        state=task_context.to_context_state_proto(),
+                        response_type=TaskRespond.OnStepActionEnd,
                         on_agent_action_end=OnAgentActionEnd(
-                            output="",
+                            output=""
+                            if task_opt.streaming
+                            else function_result.console_stderr
+                            + function_result.console_stdout,
                             output_files=function_result.saved_filenames,
                             has_error=function_result.has_error,
                         ),
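
The OnStepActionEnd change above encodes a deliberate choice: with task_opt.streaming set, the execution output has already been delivered chunk by chunk, so the end-of-step message carries an empty output; without streaming, the whole console output is attached there instead. The decision in isolation (field names as in the hunk above; a sketch, not the commit's code):

```python
def step_end_output(function_result, streaming: bool) -> str:
    """Pick the output to attach to OnStepActionEnd, mirroring the diff above."""
    if streaming:
        return ""  # already streamed incrementally; avoid sending it twice
    # buffered mode: deliver everything at once, stderr before stdout as above
    return function_result.console_stderr + function_result.console_stdout
```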
@@ -333,20 +348,20 @@ async def arun(self, question, queue, context, max_iteration=5):
                     )
                     result = self._format_output(json_response)
                     await queue.put(
-                        TaskRespond(
-                            state=task_context.to_task_state_proto(),
-                            respond_type=TaskRespond.OnFinalAnswerType,
-                            final_respond=FinalRespond(answer=result),
+                        TaskResponse(
+                            state=task_context.to_context_state_proto(),
+                            response_type=TaskRespond.OnFinalAnswer,
+                            final_answer=FinalAnswer(answer=result),
                         )
                     )
                     break
                 else:
                     result = self._format_output(json_response)
                     await queue.put(
-                        TaskRespond(
-                            state=task_context.to_task_state_proto(),
-                            respond_type=TaskRespond.OnFinalAnswerType,
-                            final_respond=FinalRespond(answer=result),
+                        TaskResponse(
+                            state=task_context.to_context_state_proto(),
+                            response_type=TaskRespond.OnFinalAnswer,
+                            final_answer=FinalAnswer(answer=result),
                         )
                     )
                     break
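Taken together, the new message types give a client one stream to consume. A minimal, hypothetical consumer of the queue that arun() fills — not part of this commit, and assuming the TaskResponse enum values and fields exactly as they appear in the diff above:

```python
import asyncio

from og_proto.agent_server_pb2 import TaskResponse

async def consume(queue: asyncio.Queue) -> None:
    """Drain TaskResponse messages until a terminal one arrives."""
    while True:
        response = await queue.get()
        if response.response_type in (
            TaskResponse.OnModelTypeText,  # streamed explanation text
            TaskResponse.OnModelTypeCode,  # streamed code
        ):
            print(response.typing_content, end="", flush=True)
        elif response.response_type == TaskResponse.OnStepActionStart:
            print(f"\n[running {response.on_agent_action_start.tool}]")
        elif response.response_type == TaskResponse.OnFinalAnswer:
            print(f"\n{response.final_answer.answer}")
            break
        else:
            # limit violations and model-output errors all carry error_msg
            print(f"\n[error] {response.error_msg}")
            break
```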
12 changes: 5 additions & 7 deletions agent/src/og_agent/mock_agent.py
@@ -8,7 +8,7 @@
 import time
 import logging
 from .base_agent import BaseAgent, TypingState, TaskContext
-from og_proto.agent_server_pb2 import OnAgentAction, TaskRespond, OnAgentActionEnd, FinalRespond
+from og_proto.agent_server_pb2 import OnStepActionStart, TaskResponse, OnStepActionEnd, FinalAnswer
 from .tokenizer import tokenize
 
 logger = logging.getLogger(__name__)
@@ -74,14 +74,12 @@ async def arun(self, task, queue, context, max_iteration=5):
         """
         run the agent
         """
-        iteration = 0
         task_context = TaskContext(
             start_time=time.time(),
-            generated_token_count=10,
-            sent_token_count=10,
-            model_name="mock",
-            iteration_count=1,
-            model_respond_duration=1000,
+            output_token_count=10,
+            input_token_count=10,
+            llm_name="mock",
+            llm_respond_duration=1000,
         )
         try:
             while iteration < max_iteration:
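
The mock agent now builds its TaskContext with the renamed fields. Assuming the constructor keywords match the diff (the import path is a guess from the package layout), a minimal construction looks like:

```python
import time

from og_agent.base_agent import TaskContext  # assumed import path

task_context = TaskContext(
    start_time=time.time(),
    output_token_count=10,      # was generated_token_count
    input_token_count=10,       # was sent_token_count
    llm_name="mock",            # was model_name
    llm_respond_duration=1000,  # was model_respond_duration
)
```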
