Skip to content

Commit

Permalink
fix: retry on schema error
Browse files Browse the repository at this point in the history
Signed-off-by: va <[email protected]>
  • Loading branch information
vabarbosa committed Feb 21, 2025
1 parent 09e23ff commit 67ea141
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 14 deletions.
1 change: 0 additions & 1 deletion python/beeai_framework/agents/runners/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ async def create_iteration(self) -> RunnerIteration:
BeeRunnerLLMInput(emitter=emitter, signal=self._run.signal, meta=meta)
)
self._iterations.append(iteration)

return RunnerIteration(emitter=emitter, state=iteration.state, meta=meta, signal=self._run.signal)

async def init(self, input: BeeRunInput) -> None:
Expand Down
10 changes: 10 additions & 0 deletions python/beeai_framework/agents/runners/default/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ class ToolInputErrorTemplateInput(BaseModel):
reason: str


class SchemaErrorTemplateInput(BaseModel):
pass


UserPromptTemplate = PromptTemplate(schema=UserPromptTemplateInput, template="Message: {{input}}")

AssistantPromptTemplate = PromptTemplate(
Expand Down Expand Up @@ -150,3 +154,9 @@ class ToolInputErrorTemplateInput(BaseModel):
schema=AssistantPromptTemplateInput,
template="""{{#thought}}Thought: {{&.}}\n{{/thought}}{{#tool_name}}Function Name: {{&.}}\n{{/tool_name}}{{#tool_input}}Function Input: {{&.}}\n{{/tool_input}}{{#tool_output}}Function Output: {{&.}}\n{{/tool_output}}{{#final_answer}}Final Answer: {{&.}}{{/final_answer}}""", # noqa: E501
)

SchemaErrorTemplate = PromptTemplate(
schema=SchemaErrorTemplateInput,
template="""Error: The generated response does not adhere to the communication structure mentioned in the system prompt.
You communicate only in instruction lines. Valid instruction lines are 'Thought' followed by either 'Function Name' + 'Function Input' + 'Function Output' or 'Final Answer'.""", # noqa: E501
)
32 changes: 24 additions & 8 deletions python/beeai_framework/agents/runners/default/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
)
from beeai_framework.agents.runners.default.prompts import (
AssistantPromptTemplate,
SchemaErrorTemplate,
SchemaErrorTemplateInput,
SystemPromptTemplate,
SystemPromptTemplateInput,
ToolDefinition,
Expand All @@ -37,13 +39,18 @@
BeeRunInput,
)
from beeai_framework.backend.chat import ChatModelInput, ChatModelOutput
from beeai_framework.backend.message import SystemMessage, UserMessage
from beeai_framework.backend.message import AssistantMessage, SystemMessage, UserMessage
from beeai_framework.emitter.emitter import EventMeta
from beeai_framework.errors import FrameworkError
from beeai_framework.memory.base_memory import BaseMemory
from beeai_framework.memory.token_memory import TokenMemory
from beeai_framework.parsers.field import ParserField
from beeai_framework.parsers.line_prefix import LinePrefixParser, LinePrefixParserNode, LinePrefixParserUpdate
from beeai_framework.parsers.line_prefix import (
LinePrefixParser,
LinePrefixParserError,
LinePrefixParserNode,
LinePrefixParserUpdate,
)
from beeai_framework.retryable import Retryable, RetryableConfig, RetryableContext, RetryableInput
from beeai_framework.tools import ToolError, ToolInputValidationError
from beeai_framework.tools.tool import StringToolOutput, Tool, ToolOutput
Expand All @@ -58,6 +65,7 @@ def default_templates(self) -> BeeAgentTemplates:
user=UserPromptTemplate,
tool_not_found_error=ToolNotFoundErrorTemplate,
tool_input_error=ToolInputErrorTemplate,
schema_error=SchemaErrorTemplate,
)

def create_parser(self) -> LinePrefixParser:
Expand Down Expand Up @@ -91,15 +99,19 @@ def create_parser(self) -> LinePrefixParser:
)

async def llm(self, input: BeeRunnerLLMInput) -> BeeAgentRunIteration:
def on_retry() -> None:
input.emitter.emit("retry", {"meta": input.meta})
async def on_retry(ctx: RetryableContext, last_error: Exception) -> None:
await input.emitter.emit("retry", {"meta": input.meta})

async def on_error(error: Exception, _: RetryableContext) -> None:
input.emitter.emit("error", {"error": error, "meta": input.meta})
await input.emitter.emit("error", {"error": error, "meta": input.meta})
self._failedAttemptsCounter.use(error)

# TODO: handle
# if isinstance(error, LinePrefixParserError)
if isinstance(error, LinePrefixParserError):
if error.reason == LinePrefixParserError.Reason.NoDataReceived:
await self.memory.add(AssistantMessage("\n", {"tempMessage": True}))
else:
schema_error_prompt: str = self.templates.schema_error.render(SchemaErrorTemplateInput())
await self.memory.add(UserMessage(schema_error_prompt, {"tempMessage": True}))

async def executor(_: RetryableContext) -> Awaitable[BeeAgentRunIteration]:
await input.emitter.emit("start", {"meta": input.meta, "tools": self._input.tools, "memory": self.memory})
Expand Down Expand Up @@ -254,7 +266,11 @@ async def executor(_: RetryableContext) -> Awaitable[BeeRunnerToolResult]:
max_retries = 0

retryable_state = await Retryable(
{"on_error": on_error, "executor": executor, "config": RetryableConfig(max_retries=max_retries)}
RetryableInput(
on_error=on_error,
executor=executor,
config=RetryableConfig(max_retries=max_retries),
)
).get()

return retryable_state.value
Expand Down
7 changes: 7 additions & 0 deletions python/beeai_framework/agents/runners/granite/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from beeai_framework.agents.runners.default.prompts import (
AssistantPromptTemplateInput,
SchemaErrorTemplateInput,
SystemPromptTemplateInput,
ToolInputErrorTemplateInput,
ToolNotFoundErrorTemplateInput,
Expand Down Expand Up @@ -92,3 +93,9 @@
HINT: If you're convinced that the input was correct but the tool cannot process it then use a different tool or say I don't know.""", # noqa: E501
)

GraniteSchemaErrorTemplate = PromptTemplate(
schema=SchemaErrorTemplateInput,
template="""Error: The generated response does not adhere to the communication structure mentioned in the system prompt.
You communicate only in instruction lines. Valid instruction lines are 'Thought' followed by either 'Function Name' + 'Function Input' + 'Function Output' or 'Final Answer'.""", # noqa: E501
)
2 changes: 2 additions & 0 deletions python/beeai_framework/agents/runners/granite/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from beeai_framework.agents.runners.default.runner import DefaultRunner
from beeai_framework.agents.runners.granite.prompts import (
GraniteAssistantPromptTemplate,
GraniteSchemaErrorTemplate,
GraniteSystemPromptTemplate,
GraniteToolInputErrorTemplate,
GraniteToolNotFoundErrorTemplate,
Expand Down Expand Up @@ -86,6 +87,7 @@ def default_templates(self) -> BeeAgentTemplates:
user=GraniteUserPromptTemplate,
tool_not_found_error=GraniteToolNotFoundErrorTemplate,
tool_input_error=GraniteToolInputErrorTemplate,
schema_error=GraniteSchemaErrorTemplate,
)

async def init_memory(self, input: BeeRunInput) -> BaseMemory:
Expand Down
2 changes: 1 addition & 1 deletion python/beeai_framework/agents/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class BeeAgentTemplates(BaseModel):
tool_input_error: InstanceOf[PromptTemplate]
# tool_no_result_error: InstanceOf[PromptTemplate]
tool_not_found_error: InstanceOf[PromptTemplate]
# schema_error: InstanceOf[PromptTemplate]
schema_error: InstanceOf[PromptTemplate]


class AgentMeta(BaseModel):
Expand Down
9 changes: 5 additions & 4 deletions python/beeai_framework/retryable.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,12 +209,13 @@ async def _on_failed_attempt(e: FrameworkError, meta: Meta) -> None:
async def get(self, config: RetryableRunConfig | None = None) -> Awaitable[T]:
if self.is_resolved():
return self._retry_state.value
if self.is_rejected():
elif self.is_rejected():
raise self._retry_state.value
if (self._retry_state.state not in ["resolved", "rejected"] if self._retry_state else False) and not config:
elif (self._retry_state.state not in ["resolved", "rejected"] if self._retry_state else False) and not config:
return self._retry_state
else:
self._retry_state = await self._run(config)
return self._retry_state
self._retry_state = await self._run(config)
return self._retry_state

def reset(self) -> None:
self._retry_state = None
Expand Down

0 comments on commit 67ea141

Please sign in to comment.