From ba89bfe3aa67c7bc8915ff0241bfe5d0afa3eb09 Mon Sep 17 00:00:00 2001 From: Manoj Date: Wed, 8 Jan 2025 18:21:03 +0530 Subject: [PATCH] fix: improve state handling and JSON parsing with better error handling --- src/agent/custom_agent.py | 109 +++++++++++++++++++++++----- src/agent/custom_massage_manager.py | 65 +++++++++++------ 2 files changed, 132 insertions(+), 42 deletions(-) diff --git a/src/agent/custom_agent.py b/src/agent/custom_agent.py index 3bf5496..c5242a2 100644 --- a/src/agent/custom_agent.py +++ b/src/agent/custom_agent.py @@ -123,8 +123,7 @@ def _log_response(self, response: CustomAgentOutput) -> None: logger.info(f"🎯 Summary: {response.current_state.summary}") for i, action in enumerate(response.action): logger.info( - f"🛠️ Action {i + 1}/{len(response.action)}: {action.model_dump_json(exclude_unset=True)}" - ) + f"🛠️ Action {i + 1}/{len(response.action)}: {action.model_dump_json(exclude_unset=True)}") def update_step_info( self, model_output: CustomAgentOutput, step_info: CustomAgentStepInfo = None @@ -163,23 +162,51 @@ async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutpu return parsed except Exception as e: - # If something goes wrong, try to invoke the LLM again without structured output, - # and Manually parse the response. Temporarily solution for DeepSeek + # If structured output fails, try manual JSON parsing ret = self.llm.invoke(input_messages) - if isinstance(ret.content, list): - parsed_json = json.loads(ret.content[0].replace("```json", "").replace("```", "")) + content = ret.content + + # Clean up the content to ensure valid JSON + if "```json" in content: + # Extract JSON from code block + start = content.find("```json") + 7 + end = content.find("```", start) + if end == -1: + end = len(content) + content = content[start:end] else: - parsed_json = json.loads(ret.content.replace("```json", "").replace("```", "")) - parsed: AgentOutput = self.AgentOutput(**parsed_json) - if parsed is None: - raise ValueError(f'Could not parse response.') - - # cut the number of actions to max_actions_per_step - parsed.action = parsed.action[: self.max_actions_per_step] - self._log_response(parsed) - self.n_steps += 1 - - return parsed + # Try to find JSON object + start = content.find("{") + end = content.rfind("}") + 1 + if start >= 0 and end > start: + content = content[start:end] + + # Clean up any remaining whitespace or newlines + content = content.strip() + + try: + parsed_json = json.loads(content) + parsed: AgentOutput = self.AgentOutput(**parsed_json) + # cut the number of actions to max_actions_per_step + parsed.action = parsed.action[: self.max_actions_per_step] + self._log_response(parsed) + self.n_steps += 1 + return parsed + except json.JSONDecodeError as e: + logger.error(f"Failed to parse JSON: {str(e)}") + logger.error(f"Content was: {content}") + # Create a default response + from .custom_views import CustomAgentBrain + return CustomAgentOutput( + current_state=CustomAgentBrain( + prev_action_evaluation="Failed - Error parsing response", + important_contents="None", + completed_contents="", + thought="Failed to parse the response. Will retry with a simpler action.", + summary="Retry with simpler action" + ), + action=[{"go_to_url": {"url": "https://www.google.com"}}] + ) @time_execution_async("--step") async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None: @@ -191,11 +218,50 @@ async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None: try: state = await self.browser_context.get_state(use_vision=self.use_vision) - self.message_manager.add_state_message(state, self._last_result, step_info) + if state is None: + logger.error("Failed to get browser state") + return + + # Create a new state object with default values + from dataclasses import dataclass, field + from typing import List, Optional + + @dataclass + class ElementTree: + clickable_elements: List[str] = field(default_factory=list) + + def clickable_elements_to_string(self, include_attributes=None): + return "\n".join(self.clickable_elements) if self.clickable_elements else "" + + @dataclass + class BrowserState: + url: str = "" + tabs: List[str] = field(default_factory=list) + element_tree: ElementTree = field(default_factory=ElementTree) + screenshot: Optional[str] = None + + browser_state = BrowserState() + browser_state.url = getattr(state, 'url', '') + browser_state.tabs = getattr(state, 'tabs', []) + browser_state.screenshot = getattr(state, 'screenshot', None) + + # Extract clickable elements if available + if hasattr(state, 'element_tree') and hasattr(state.element_tree, 'clickable_elements'): + browser_state.element_tree.clickable_elements = state.element_tree.clickable_elements + + self.message_manager.add_state_message(browser_state, self._last_result, step_info) input_messages = self.message_manager.get_messages() + if not input_messages: + logger.error("Failed to get input messages") + return + model_output = await self.get_next_action(input_messages) + if model_output is None: + logger.error("Failed to get next action") + return + self.update_step_info(model_output, step_info) - logger.info(f"🧠 All Memory: {step_info.memory}") + logger.info(f"🧠 All Memory: {getattr(step_info, 'memory', '')}") self._save_conversation(input_messages, model_output) self.message_manager._remove_last_state_message() # we dont want the whole state in the chat history self.message_manager.add_model_output(model_output) @@ -203,6 +269,8 @@ async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None: result: list[ActionResult] = await self.controller.multi_act( model_output.action, self.browser_context ) + if result is None: + result = [] self._last_result = result if len(result) > 0 and result[-1].is_done: @@ -211,6 +279,7 @@ async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None: self.consecutive_failures = 0 except Exception as e: + logger.error(f"Error in step: {str(e)}") result = self._handle_step_error(e) self._last_result = result @@ -218,7 +287,7 @@ async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None: if not result: return for r in result: - if r.error: + if r and r.error: self.telemetry.capture( AgentStepErrorTelemetryEvent( agent_id=self.agent_id, diff --git a/src/agent/custom_massage_manager.py b/src/agent/custom_massage_manager.py index 8de2b06..748ece6 100644 --- a/src/agent/custom_massage_manager.py +++ b/src/agent/custom_massage_manager.py @@ -94,26 +94,47 @@ def add_state_message( ) -> None: """Add browser state as human message""" - # if keep in memory, add to directly to history and add state without result - if result: - for r in result: - if r.include_in_memory: - if r.extracted_content: - msg = HumanMessage(content=str(r.extracted_content)) - self._add_message_with_tokens(msg) - if r.error: - msg = HumanMessage( - content=str(r.error)[-self.max_error_length:] - ) - self._add_message_with_tokens(msg) - result = None # if result in history, we dont want to add it again + try: + # if keep in memory, add to directly to history and add state without result + if result: + for r in result: + if r and r.include_in_memory: + if r.extracted_content: + msg = HumanMessage(content=str(r.extracted_content)) + self._add_message_with_tokens(msg) + if r.error: + msg = HumanMessage( + content=str(r.error)[-self.max_error_length :] + ) + self._add_message_with_tokens(msg) + result = None # if result in history, we dont want to add it again - # otherwise add state message and result to next message (which will not stay in memory) - state_message = CustomAgentMessagePrompt( - state, - result, - include_attributes=self.include_attributes, - max_error_length=self.max_error_length, - step_info=step_info, - ).get_user_message() - self._add_message_with_tokens(state_message) + # Create state message with safe attribute access + state_message = CustomAgentMessagePrompt( + state, + result, + include_attributes=self.include_attributes, + max_error_length=self.max_error_length, + step_info=step_info, + ).get_user_message() + + if state_message and hasattr(state_message, 'content'): + if isinstance(state_message.content, str): + self._add_message_with_tokens(state_message) + elif isinstance(state_message.content, list): + # Handle multi-modal messages (text + image) + has_valid_content = False + for item in state_message.content: + if isinstance(item, dict): + if item.get('type') == 'text' and item.get('text'): + has_valid_content = True + elif item.get('type') == 'image_url' and item.get('image_url', {}).get('url'): + has_valid_content = True + if has_valid_content: + self._add_message_with_tokens(state_message) + + except Exception as e: + logger.error(f"Error in add_state_message: {str(e)}") + # Create a basic message if state processing fails + msg = HumanMessage(content="Error processing browser state") + self._add_message_with_tokens(msg)