Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: improve state handling and JSON parsing with better error handling #47

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 89 additions & 20 deletions src/agent/custom_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,7 @@ def _log_response(self, response: CustomAgentOutput) -> None:
logger.info(f"🎯 Summary: {response.current_state.summary}")
for i, action in enumerate(response.action):
logger.info(
f"🛠️ Action {i + 1}/{len(response.action)}: {action.model_dump_json(exclude_unset=True)}"
)
f"🛠️ Action {i + 1}/{len(response.action)}: {action.model_dump_json(exclude_unset=True)}")

def update_step_info(
self, model_output: CustomAgentOutput, step_info: CustomAgentStepInfo = None
Expand Down Expand Up @@ -163,23 +162,51 @@ async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutpu

return parsed
except Exception as e:
# If something goes wrong, try to invoke the LLM again without structured output,
# and Manually parse the response. Temporarily solution for DeepSeek
# If structured output fails, try manual JSON parsing
ret = self.llm.invoke(input_messages)
if isinstance(ret.content, list):
parsed_json = json.loads(ret.content[0].replace("```json", "").replace("```", ""))
content = ret.content

# Clean up the content to ensure valid JSON
if "```json" in content:
# Extract JSON from code block
start = content.find("```json") + 7
end = content.find("```", start)
if end == -1:
end = len(content)
content = content[start:end]
else:
parsed_json = json.loads(ret.content.replace("```json", "").replace("```", ""))
parsed: AgentOutput = self.AgentOutput(**parsed_json)
if parsed is None:
raise ValueError(f'Could not parse response.')

# cut the number of actions to max_actions_per_step
parsed.action = parsed.action[: self.max_actions_per_step]
self._log_response(parsed)
self.n_steps += 1

return parsed
# Try to find JSON object
start = content.find("{")
end = content.rfind("}") + 1
if start >= 0 and end > start:
content = content[start:end]

# Clean up any remaining whitespace or newlines
content = content.strip()

try:
parsed_json = json.loads(content)
parsed: AgentOutput = self.AgentOutput(**parsed_json)
# cut the number of actions to max_actions_per_step
parsed.action = parsed.action[: self.max_actions_per_step]
self._log_response(parsed)
self.n_steps += 1
return parsed
except json.JSONDecodeError as e:
logger.error(f"Failed to parse JSON: {str(e)}")
logger.error(f"Content was: {content}")
# Create a default response
from .custom_views import CustomAgentBrain
return CustomAgentOutput(
current_state=CustomAgentBrain(
prev_action_evaluation="Failed - Error parsing response",
important_contents="None",
completed_contents="",
thought="Failed to parse the response. Will retry with a simpler action.",
summary="Retry with simpler action"
),
action=[{"go_to_url": {"url": "https://www.google.com"}}]
)

@time_execution_async("--step")
async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None:
Expand All @@ -191,18 +218,59 @@ async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None:

try:
state = await self.browser_context.get_state(use_vision=self.use_vision)
self.message_manager.add_state_message(state, self._last_result, step_info)
if state is None:
logger.error("Failed to get browser state")
return

# Create a new state object with default values
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class ElementTree:
clickable_elements: List[str] = field(default_factory=list)

def clickable_elements_to_string(self, include_attributes=None):
return "\n".join(self.clickable_elements) if self.clickable_elements else ""

@dataclass
class BrowserState:
url: str = ""
tabs: List[str] = field(default_factory=list)
element_tree: ElementTree = field(default_factory=ElementTree)
screenshot: Optional[str] = None

browser_state = BrowserState()
browser_state.url = getattr(state, 'url', '')
browser_state.tabs = getattr(state, 'tabs', [])
browser_state.screenshot = getattr(state, 'screenshot', None)

# Extract clickable elements if available
if hasattr(state, 'element_tree') and hasattr(state.element_tree, 'clickable_elements'):
browser_state.element_tree.clickable_elements = state.element_tree.clickable_elements

self.message_manager.add_state_message(browser_state, self._last_result, step_info)
input_messages = self.message_manager.get_messages()
if not input_messages:
logger.error("Failed to get input messages")
return

model_output = await self.get_next_action(input_messages)
if model_output is None:
logger.error("Failed to get next action")
return

self.update_step_info(model_output, step_info)
logger.info(f"🧠 All Memory: {step_info.memory}")
logger.info(f"🧠 All Memory: {getattr(step_info, 'memory', '')}")
self._save_conversation(input_messages, model_output)
self.message_manager._remove_last_state_message() # we dont want the whole state in the chat history
self.message_manager.add_model_output(model_output)

result: list[ActionResult] = await self.controller.multi_act(
model_output.action, self.browser_context
)
if result is None:
result = []
self._last_result = result

if len(result) > 0 and result[-1].is_done:
Expand All @@ -211,14 +279,15 @@ async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None:
self.consecutive_failures = 0

except Exception as e:
logger.error(f"Error in step: {str(e)}")
result = self._handle_step_error(e)
self._last_result = result

finally:
if not result:
return
for r in result:
if r.error:
if r and r.error:
self.telemetry.capture(
AgentStepErrorTelemetryEvent(
agent_id=self.agent_id,
Expand Down
65 changes: 43 additions & 22 deletions src/agent/custom_massage_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,26 +94,47 @@ def add_state_message(
) -> None:
"""Add browser state as human message"""

# if keep in memory, add to directly to history and add state without result
if result:
for r in result:
if r.include_in_memory:
if r.extracted_content:
msg = HumanMessage(content=str(r.extracted_content))
self._add_message_with_tokens(msg)
if r.error:
msg = HumanMessage(
content=str(r.error)[-self.max_error_length:]
)
self._add_message_with_tokens(msg)
result = None # if result in history, we dont want to add it again
try:
# if keep in memory, add to directly to history and add state without result
if result:
for r in result:
if r and r.include_in_memory:
if r.extracted_content:
msg = HumanMessage(content=str(r.extracted_content))
self._add_message_with_tokens(msg)
if r.error:
msg = HumanMessage(
content=str(r.error)[-self.max_error_length :]
)
self._add_message_with_tokens(msg)
result = None # if result in history, we dont want to add it again

# otherwise add state message and result to next message (which will not stay in memory)
state_message = CustomAgentMessagePrompt(
state,
result,
include_attributes=self.include_attributes,
max_error_length=self.max_error_length,
step_info=step_info,
).get_user_message()
self._add_message_with_tokens(state_message)
# Create state message with safe attribute access
state_message = CustomAgentMessagePrompt(
state,
result,
include_attributes=self.include_attributes,
max_error_length=self.max_error_length,
step_info=step_info,
).get_user_message()

if state_message and hasattr(state_message, 'content'):
if isinstance(state_message.content, str):
self._add_message_with_tokens(state_message)
elif isinstance(state_message.content, list):
# Handle multi-modal messages (text + image)
has_valid_content = False
for item in state_message.content:
if isinstance(item, dict):
if item.get('type') == 'text' and item.get('text'):
has_valid_content = True
elif item.get('type') == 'image_url' and item.get('image_url', {}).get('url'):
has_valid_content = True
if has_valid_content:
self._add_message_with_tokens(state_message)

except Exception as e:
logger.error(f"Error in add_state_message: {str(e)}")
# Create a basic message if state processing fails
msg = HumanMessage(content="Error processing browser state")
self._add_message_with_tokens(msg)