From 67ea141a1c7befa7cd4d53858fdc2794b3653a57 Mon Sep 17 00:00:00 2001 From: va Date: Fri, 21 Feb 2025 15:15:10 -0500 Subject: [PATCH] fix: retry on schema error Signed-off-by: va --- python/beeai_framework/agents/runners/base.py | 1 - .../agents/runners/default/prompts.py | 10 ++++++ .../agents/runners/default/runner.py | 32 ++++++++++++++----- .../agents/runners/granite/prompts.py | 7 ++++ .../agents/runners/granite/runner.py | 2 ++ python/beeai_framework/agents/types.py | 2 +- python/beeai_framework/retryable.py | 9 +++--- 7 files changed, 49 insertions(+), 14 deletions(-) diff --git a/python/beeai_framework/agents/runners/base.py b/python/beeai_framework/agents/runners/base.py index 14e3082b..0e6c5f3f 100644 --- a/python/beeai_framework/agents/runners/base.py +++ b/python/beeai_framework/agents/runners/base.py @@ -114,7 +114,6 @@ async def create_iteration(self) -> RunnerIteration: BeeRunnerLLMInput(emitter=emitter, signal=self._run.signal, meta=meta) ) self._iterations.append(iteration) - return RunnerIteration(emitter=emitter, state=iteration.state, meta=meta, signal=self._run.signal) async def init(self, input: BeeRunInput) -> None: diff --git a/python/beeai_framework/agents/runners/default/prompts.py b/python/beeai_framework/agents/runners/default/prompts.py index c0c1c9c0..d7069680 100644 --- a/python/beeai_framework/agents/runners/default/prompts.py +++ b/python/beeai_framework/agents/runners/default/prompts.py @@ -49,6 +49,10 @@ class ToolInputErrorTemplateInput(BaseModel): reason: str +class SchemaErrorTemplateInput(BaseModel): + pass + + UserPromptTemplate = PromptTemplate(schema=UserPromptTemplateInput, template="Message: {{input}}") AssistantPromptTemplate = PromptTemplate( @@ -150,3 +154,9 @@ class ToolInputErrorTemplateInput(BaseModel): schema=AssistantPromptTemplateInput, template="""{{#thought}}Thought: {{&.}}\n{{/thought}}{{#tool_name}}Function Name: {{&.}}\n{{/tool_name}}{{#tool_input}}Function Input: {{&.}}\n{{/tool_input}}{{#tool_output}}Function Output: {{&.}}\n{{/tool_output}}{{#final_answer}}Final Answer: {{&.}}{{/final_answer}}""", # noqa: E501 ) + +SchemaErrorTemplate = PromptTemplate( + schema=SchemaErrorTemplateInput, + template="""Error: The generated response does not adhere to the communication structure mentioned in the system prompt. +You communicate only in instruction lines. Valid instruction lines are 'Thought' followed by either 'Function Name' + 'Function Input' + 'Function Output' or 'Final Answer'.""", # noqa: E501 +) diff --git a/python/beeai_framework/agents/runners/default/runner.py b/python/beeai_framework/agents/runners/default/runner.py index 0da1fac1..18a3ed35 100644 --- a/python/beeai_framework/agents/runners/default/runner.py +++ b/python/beeai_framework/agents/runners/default/runner.py @@ -23,6 +23,8 @@ ) from beeai_framework.agents.runners.default.prompts import ( AssistantPromptTemplate, + SchemaErrorTemplate, + SchemaErrorTemplateInput, SystemPromptTemplate, SystemPromptTemplateInput, ToolDefinition, @@ -37,13 +39,18 @@ BeeRunInput, ) from beeai_framework.backend.chat import ChatModelInput, ChatModelOutput -from beeai_framework.backend.message import SystemMessage, UserMessage +from beeai_framework.backend.message import AssistantMessage, SystemMessage, UserMessage from beeai_framework.emitter.emitter import EventMeta from beeai_framework.errors import FrameworkError from beeai_framework.memory.base_memory import BaseMemory from beeai_framework.memory.token_memory import TokenMemory from beeai_framework.parsers.field import ParserField -from beeai_framework.parsers.line_prefix import LinePrefixParser, LinePrefixParserNode, LinePrefixParserUpdate +from beeai_framework.parsers.line_prefix import ( + LinePrefixParser, + LinePrefixParserError, + LinePrefixParserNode, + LinePrefixParserUpdate, +) from beeai_framework.retryable import Retryable, RetryableConfig, RetryableContext, RetryableInput from beeai_framework.tools import ToolError, ToolInputValidationError from beeai_framework.tools.tool import StringToolOutput, Tool, ToolOutput @@ -58,6 +65,7 @@ def default_templates(self) -> BeeAgentTemplates: user=UserPromptTemplate, tool_not_found_error=ToolNotFoundErrorTemplate, tool_input_error=ToolInputErrorTemplate, + schema_error=SchemaErrorTemplate, ) def create_parser(self) -> LinePrefixParser: @@ -91,15 +99,19 @@ def create_parser(self) -> LinePrefixParser: ) async def llm(self, input: BeeRunnerLLMInput) -> BeeAgentRunIteration: - def on_retry() -> None: - input.emitter.emit("retry", {"meta": input.meta}) + async def on_retry(ctx: RetryableContext, last_error: Exception) -> None: + await input.emitter.emit("retry", {"meta": input.meta}) async def on_error(error: Exception, _: RetryableContext) -> None: - input.emitter.emit("error", {"error": error, "meta": input.meta}) + await input.emitter.emit("error", {"error": error, "meta": input.meta}) self._failedAttemptsCounter.use(error) - # TODO: handle - # if isinstance(error, LinePrefixParserError) + if isinstance(error, LinePrefixParserError): + if error.reason == LinePrefixParserError.Reason.NoDataReceived: + await self.memory.add(AssistantMessage("\n", {"tempMessage": True})) + else: + schema_error_prompt: str = self.templates.schema_error.render(SchemaErrorTemplateInput()) + await self.memory.add(UserMessage(schema_error_prompt, {"tempMessage": True})) async def executor(_: RetryableContext) -> Awaitable[BeeAgentRunIteration]: await input.emitter.emit("start", {"meta": input.meta, "tools": self._input.tools, "memory": self.memory}) @@ -254,7 +266,11 @@ async def executor(_: RetryableContext) -> Awaitable[BeeRunnerToolResult]: max_retries = 0 retryable_state = await Retryable( - {"on_error": on_error, "executor": executor, "config": RetryableConfig(max_retries=max_retries)} + RetryableInput( + on_error=on_error, + executor=executor, + config=RetryableConfig(max_retries=max_retries), + ) ).get() return retryable_state.value diff --git a/python/beeai_framework/agents/runners/granite/prompts.py b/python/beeai_framework/agents/runners/granite/prompts.py index 16830ce9..bb810e6e 100644 --- a/python/beeai_framework/agents/runners/granite/prompts.py +++ b/python/beeai_framework/agents/runners/granite/prompts.py @@ -16,6 +16,7 @@ from beeai_framework.agents.runners.default.prompts import ( AssistantPromptTemplateInput, + SchemaErrorTemplateInput, SystemPromptTemplateInput, ToolInputErrorTemplateInput, ToolNotFoundErrorTemplateInput, @@ -92,3 +93,9 @@ HINT: If you're convinced that the input was correct but the tool cannot process it then use a different tool or say I don't know.""", # noqa: E501 ) + +GraniteSchemaErrorTemplate = PromptTemplate( + schema=SchemaErrorTemplateInput, + template="""Error: The generated response does not adhere to the communication structure mentioned in the system prompt. +You communicate only in instruction lines. Valid instruction lines are 'Thought' followed by either 'Function Name' + 'Function Input' + 'Function Output' or 'Final Answer'.""", # noqa: E501 +) diff --git a/python/beeai_framework/agents/runners/granite/runner.py b/python/beeai_framework/agents/runners/granite/runner.py index 39dc0c70..c89c4190 100644 --- a/python/beeai_framework/agents/runners/granite/runner.py +++ b/python/beeai_framework/agents/runners/granite/runner.py @@ -17,6 +17,7 @@ from beeai_framework.agents.runners.default.runner import DefaultRunner from beeai_framework.agents.runners.granite.prompts import ( GraniteAssistantPromptTemplate, + GraniteSchemaErrorTemplate, GraniteSystemPromptTemplate, GraniteToolInputErrorTemplate, GraniteToolNotFoundErrorTemplate, @@ -86,6 +87,7 @@ def default_templates(self) -> BeeAgentTemplates: user=GraniteUserPromptTemplate, tool_not_found_error=GraniteToolNotFoundErrorTemplate, tool_input_error=GraniteToolInputErrorTemplate, + schema_error=GraniteSchemaErrorTemplate, ) async def init_memory(self, input: BeeRunInput) -> BaseMemory: diff --git a/python/beeai_framework/agents/types.py b/python/beeai_framework/agents/types.py index 8eec0d1c..62be7441 100644 --- a/python/beeai_framework/agents/types.py +++ b/python/beeai_framework/agents/types.py @@ -80,7 +80,7 @@ class BeeAgentTemplates(BaseModel): tool_input_error: InstanceOf[PromptTemplate] # tool_no_result_error: InstanceOf[PromptTemplate] tool_not_found_error: InstanceOf[PromptTemplate] - # schema_error: InstanceOf[PromptTemplate] + schema_error: InstanceOf[PromptTemplate] class AgentMeta(BaseModel): diff --git a/python/beeai_framework/retryable.py b/python/beeai_framework/retryable.py index b10c525c..6bca1c28 100644 --- a/python/beeai_framework/retryable.py +++ b/python/beeai_framework/retryable.py @@ -209,12 +209,13 @@ async def _on_failed_attempt(e: FrameworkError, meta: Meta) -> None: async def get(self, config: RetryableRunConfig | None = None) -> Awaitable[T]: if self.is_resolved(): return self._retry_state.value - if self.is_rejected(): + elif self.is_rejected(): raise self._retry_state.value - if (self._retry_state.state not in ["resolved", "rejected"] if self._retry_state else False) and not config: + elif (self._retry_state.state not in ["resolved", "rejected"] if self._retry_state else False) and not config: + return self._retry_state + else: + self._retry_state = await self._run(config) return self._retry_state - self._retry_state = await self._run(config) - return self._retry_state def reset(self) -> None: self._retry_state = None