diff --git a/.env.example b/.env.example
index a86798c..d2b4c69 100644
--- a/.env.example
+++ b/.env.example
@@ -4,9 +4,13 @@ OPENAI_API_KEY=your-openai-api-key-here
 # Anthropic API Key for Claude model
 ANTHROPIC_API_KEY=your-anthropic-api-key-here
 
+# Google API Key for Gemini model
+GOOGLE_API_KEY=your-google-api-key-here
+
 # Optional: Model configurations
 GPT4_MODEL=gpt-4-turbo-preview # or gpt-4
 CLAUDE_MODEL=claude-3-opus-20240229
+GEMINI_MODEL=gemini-pro
 
 # Optional: Execution settings
 MAX_WORKERS=4
diff --git a/models/__init__.py b/models/__init__.py
index 3ee30b0..0412afc 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -1,5 +1,6 @@
 from .base import BaseModel, WebInteraction, TaskResult
 from .gpt4 import GPT4Model
 from .claude import ClaudeModel
+from .gemini import GeminiModel
 
-__all__ = ['BaseModel', 'WebInteraction', 'TaskResult', 'GPT4Model', 'ClaudeModel']
+__all__ = ['BaseModel', 'WebInteraction', 'TaskResult', 'GPT4Model', 'ClaudeModel', 'GeminiModel']
diff --git a/models/gemini.py b/models/gemini.py
new file mode 100644
index 0000000..6ac5790
--- /dev/null
+++ b/models/gemini.py
@@ -0,0 +1,200 @@
+import json
+import time
+from typing import Dict, Any, Optional, Tuple
+import google.generativeai as genai
+from .base import BaseModel, WebInteraction, TaskResult
+
+class GeminiModel(BaseModel):
+    """Gemini model implementation for the DOM benchmark."""
+
+    def __init__(self, api_key: str, model_config: Dict[str, Any] = None):
+        super().__init__("gemini-pro", model_config or {})
+        genai.configure(api_key=api_key)
+        self.model = genai.GenerativeModel('gemini-pro')
+        self.max_retries = 10
+        self.temperature = (model_config or {}).get("temperature", 0)
+
+        # Enhanced system prompt based on WebVoyager's approach
+        self.system_prompt = """You are an AI assistant that helps users interact with web elements.
+Your task is to understand the user's intent and generate precise web element interactions.
+
+For each task, analyze:
+1. The user's goal and required interaction (click, type, scroll, wait)
+2. The target element's properties and accessibility
+3. Any constraints or special conditions
+
+Key Guidelines:
+1. Prefer stable selectors (id, unique class names) over dynamic ones
+2. Consider element visibility and interactability
+3. Handle dynamic content and loading states
+4. Pay attention to timing and wait states
+5. Validate success criteria for each interaction
+
+Generate interactions in this JSON format:
+{
+    "action": "click|type|scroll|wait",
+    "selector_type": "css|xpath|id",
+    "selector_value": "string",
+    "input_text": "string",          # For type actions
+    "wait_time": integer,            # For wait actions in seconds
+    "scroll_direction": "up|down",   # For scroll actions
+    "validation": {
+        "expected_state": "visible|hidden|text_present|text_absent",
+        "validation_selector": "string",  # Element to validate
+        "expected_text": "string"         # For text validation
+    }
+}"""
+
+    def _call_api(self, prompt: str, retry_count: int = 0) -> Tuple[Optional[str], bool]:
+        """Helper method to call the Gemini API with retry logic."""
+        try:
+            response = self.model.generate_content(
+                prompt,
+                generation_config={"temperature": self.temperature}
+            )
+            return response.text, False
+        except Exception as e:
+            if retry_count >= self.max_retries:
+                print(f"Max retries ({self.max_retries}) exceeded. Error: {str(e)}")
+                return None, True
+
+            wait_time = min(2 ** retry_count, 60)  # Exponential backoff, capped at 60s
+            error_name = e.__class__.__name__
+            if error_name == "RateLimitError":
+                wait_time = max(wait_time, 10)
+            elif error_name == "ServerError":
+                wait_time = max(wait_time, 15)
+
+            print(f"API call failed, retrying in {wait_time}s. Error: {str(e)}")
+            time.sleep(wait_time)
+            return self._call_api(prompt, retry_count + 1)
+
+    def parse_task(self, task: Dict[str, Any]) -> WebInteraction:
+        """Parse a task using Gemini to understand the interaction."""
+        prompt = f"""System: {self.system_prompt}
+
+Task Description: {task['task']}
+Target Element: {json.dumps(task['target_element'], indent=2)}
+Page Context: {task.get('page_context', '')}
+Previous Actions: {task.get('previous_actions', [])}
+
+Consider:
+1. Is this a multi-step interaction?
+2. Are there any loading or dynamic states to handle?
+3. What validation should be performed?
+
+Generate the optimal web interaction instruction as a JSON object."""
+
+        response_text, error = self._call_api(prompt)
+        if error or not response_text:
+            return self._create_fallback_interaction(task)
+
+        try:
+            # Extract the first JSON object from the response
+            start_idx = response_text.find('{')
+            end_idx = response_text.rfind('}') + 1
+            if start_idx == -1 or end_idx <= start_idx:
+                return self._create_fallback_interaction(task)
+            json_str = response_text[start_idx:end_idx]
+            interaction_data = json.loads(json_str)
+
+            return WebInteraction(
+                action=interaction_data.get('action', task.get('interaction', 'click')),
+                selector_type=interaction_data.get('selector_type', task['target_element']['type']),
+                selector_value=interaction_data.get('selector_value', task['target_element']['value']),
+                input_text=interaction_data.get('input_text'),
+                description=task['task'],
+                wait_time=interaction_data.get('wait_time', 0),
+                validation=interaction_data.get('validation', {})
+            )
+        except Exception as e:
+            print(f"Error parsing Gemini response: {str(e)}")
+            return self._create_fallback_interaction(task)
+
+    def _create_fallback_interaction(self, task: Dict[str, Any]) -> WebInteraction:
+        """Create a fallback interaction when API calls or parsing fail."""
+        return WebInteraction(
+            action=task.get('interaction', 'click'),
+            selector_type=task['target_element']['type'],
+            selector_value=task['target_element']['value'],
+            input_text=task.get('input_text'),
+            description=task['task']
+        )
+
+    def handle_error(self, task: Dict[str, Any], error: str) -> Optional[WebInteraction]:
+        """Use Gemini to understand and handle errors with enhanced error analysis."""
+        prompt = f"""System: {self.system_prompt}
+
+Task: {task['task']}
+Original Error: {error}
+Previous Interaction: {json.dumps(task.get('previous_interaction', {}), indent=2)}
+
+Analyze the error and suggest a solution considering:
+1. Is this a timing/loading issue?
+2. Is the selector still valid?
+3. Is the element interactive?
+4. Are there any prerequisite steps missing?
+
+Generate a modified interaction as a JSON object or respond with "GIVE UP" if unrecoverable."""
+
+        response_text, api_error = self._call_api(prompt)
+        if api_error or not response_text:
+            return self.parse_task(task)
+
+        suggestion = response_text.strip()
+        if suggestion == "GIVE UP":
+            return None
+
+        try:
+            # Extract the first JSON object from the response
+            start_idx = suggestion.find('{')
+            end_idx = suggestion.rfind('}') + 1
+            if start_idx == -1 or end_idx <= start_idx:
+                return self.parse_task(task)
+            json_str = suggestion[start_idx:end_idx]
+            interaction_data = json.loads(json_str)
+
+            return WebInteraction(
+                action=interaction_data['action'],
+                selector_type=interaction_data['selector_type'],
+                selector_value=interaction_data['selector_value'],
+                input_text=interaction_data.get('input_text'),
+                description=f"Error recovery: {task['task']}",
+                wait_time=interaction_data.get('wait_time', 0),
+                validation=interaction_data.get('validation', {})
+            )
+        except Exception as e:
+            print(f"Error in error handling: {str(e)}")
+            return self.parse_task(task)
+
+    def validate_result(self, task: Dict[str, Any], result: TaskResult) -> bool:
+        """Enhanced validation using Gemini with detailed success criteria."""
+        if result.error:
+            return False
+
+        prompt = f"""System: {self.system_prompt}
+
+Task: {task['task']}
+Target Element: {json.dumps(result.html_element, indent=2)}
+Before State: {result.before_screenshot}
+After State: {result.after_screenshot}
+Validation Rules: {json.dumps(task.get('validation_rules', {}), indent=2)}
+
+Evaluate the interaction success based on:
+1. Element state changes (visibility, content, attributes)
+2. Page state changes (URL, dynamic content)
+3. Error messages or warnings
+4. Expected outcomes from validation rules
+
+Respond with:
+- "YES" if all success criteria are met
+- "NO" with a brief explanation if any criteria fail"""
+
+        response_text, error = self._call_api(prompt)
+        if error or not response_text:
+            return False
+
+        validation_result = response_text.strip()
+
+        if validation_result.startswith("YES"):
+            return True
+        else:
+            failure_reason = validation_result.replace("NO", "").strip()
+            if failure_reason:
+                print(f"Validation failed: {failure_reason}")
+            return False
diff --git a/models/gpt4.py b/models/gpt4.py
index 0a591c1..beea252 100644
--- a/models/gpt4.py
+++ b/models/gpt4.py
@@ -1,5 +1,6 @@
 import json
-from typing import Dict, Any, Optional
+import time
+from typing import Dict, Any, Optional, Tuple
 from openai import OpenAI
 from .base import BaseModel, WebInteraction, TaskResult
 
@@ -9,40 +10,109 @@ class GPT4Model(BaseModel):
     def __init__(self, api_key: str, model_config: Dict[str, Any] = None):
         super().__init__("gpt-4", model_config or {})
         self.client = OpenAI(api_key=api_key)
+        self.max_retries = 10
+        self.model = (model_config or {}).get("model", "gpt-4")
+        self.temperature = (model_config or {}).get("temperature", 0)
+        self.max_tokens = (model_config or {}).get("max_tokens", 1000)
 
-        # Default system prompt
+        # Enhanced system prompt based on WebVoyager
         self.system_prompt = """You are an AI assistant that helps users interact with web elements.
 Your task is to understand the user's intent and generate precise web element interactions.
-You should focus on the specific interaction requested, using the provided element selectors.
 
-For each task, you will:
-1. Understand the required interaction (click, type, hover)
-2. Identify the correct element using the provided selector
-3. Generate the appropriate interaction instruction
+For each task, analyze:
+1. The user's goal and required interaction (click, type, scroll, wait)
+2. The target element's properties and accessibility
+3. Any constraints or special conditions
 
-Respond only with the exact interaction needed, no explanations or additional text."""
+Key Guidelines:
+1. Prefer stable selectors (id, unique class names) over dynamic ones
+2. Consider element visibility and interactability
+3. Handle dynamic content and loading states
+4. Pay attention to timing and wait states
+5. Validate success criteria for each interaction
+
+Respond with a JSON object in this format:
+{
+    "action": "click|type|scroll|wait",
+    "selector_type": "css|xpath|id",
+    "selector_value": "string",
+    "input_text": "string",          # For type actions
+    "wait_time": integer,            # For wait actions in seconds
+    "scroll_direction": "up|down",   # For scroll actions
+    "validation": {
+        "expected_state": "visible|hidden|text_present|text_absent",
+        "validation_selector": "string",  # Element to validate
+        "expected_text": "string"         # For text validation
+    }
+}"""
+
+    def _call_api(self, messages: list, retry_count: int = 0) -> Tuple[Optional[Any], bool]:
+        """Helper method to call the OpenAI API with retry logic."""
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model,
+                messages=messages,
+                temperature=self.temperature,
+                max_tokens=self.max_tokens
+            )
+            return response, False
+        except Exception as e:
+            if retry_count >= self.max_retries:
+                print(f"Max retries ({self.max_retries}) exceeded. Error: {str(e)}")
+                return None, True
+
+            wait_time = min(2 ** retry_count, 60)  # Exponential backoff, capped at 60s
+            error_name = e.__class__.__name__
+            if error_name == "RateLimitError":
+                wait_time = max(wait_time, 10)
+            elif error_name == "APIError":
+                wait_time = max(wait_time, 15)
+
+            print(f"API call failed, retrying in {wait_time}s. Error: {str(e)}")
+            time.sleep(wait_time)
+            return self._call_api(messages, retry_count + 1)
 
     def parse_task(self, task: Dict[str, Any]) -> WebInteraction:
         """Parse task using GPT-4 to understand the interaction."""
-        # Construct prompt
-        prompt = f"""Task: {task['task']}
-Target Element: {json.dumps(task['target_element'])}
-Interaction Type: {task.get('interaction', 'click')}
-Input Text: {task.get('input_text', '')}
+        prompt = f"""Task Description: {task['task']}
+Target Element: {json.dumps(task['target_element'], indent=2)}
+Page Context: {task.get('page_context', '')}
+Previous Actions: {task.get('previous_actions', [])}
 
-Generate the web interaction instruction."""
+Consider:
+1. Is this a multi-step interaction?
+2. Are there any loading or dynamic states to handle?
+3. What validation should be performed?
 
-        # Get GPT-4 completion
-        response = self.client.chat.completions.create(
-            model="gpt-4",
-            messages=[
-                {"role": "system", "content": self.system_prompt},
-                {"role": "user", "content": prompt}
-            ],
-            temperature=0
-        )
+Generate the optimal web interaction instruction as a JSON object."""
+
+        messages = [
+            {"role": "system", "content": self.system_prompt},
+            {"role": "user", "content": prompt}
+        ]
 
-        # Parse response into WebInteraction
+        response, error = self._call_api(messages)
+        if error or not response:
+            return self._create_fallback_interaction(task)
+
+        try:
+            content = response.choices[0].message.content
+            interaction_data = json.loads(content)
+
+            return WebInteraction(
+                action=interaction_data.get('action', task.get('interaction', 'click')),
+                selector_type=interaction_data.get('selector_type', task['target_element']['type']),
+                selector_value=interaction_data.get('selector_value', task['target_element']['value']),
+                input_text=interaction_data.get('input_text'),
+                description=task['task'],
+                wait_time=interaction_data.get('wait_time', 0),
+                validation=interaction_data.get('validation', {})
+            )
+        except Exception as e:
+            print(f"Error parsing GPT-4 response: {str(e)}")
+            return self._create_fallback_interaction(task)
+
+    def _create_fallback_interaction(self, task: Dict[str, Any]) -> WebInteraction:
+        """Create a fallback interaction when API calls or parsing fail."""
         return WebInteraction(
             action=task.get('interaction', 'click'),
             selector_type=task['target_element']['type'],
@@ -52,45 +122,83 @@ def parse_task(self, task: Dict[str, Any]) -> WebInteraction:
         )
 
     def handle_error(self, task: Dict[str, Any], error: str) -> Optional[WebInteraction]:
-        """Use GPT-4 to understand and handle errors."""
+        """Use GPT-4 to understand and handle errors with enhanced error analysis."""
         prompt = f"""Task: {task['task']}
-Error: {error}
+Original Error: {error}
+Previous Interaction: {json.dumps(task.get('previous_interaction', {}), indent=2)}
 
-How should we modify the interaction to handle this error?
-If the error is unrecoverable, respond with "GIVE UP"."""
+Analyze the error and suggest a solution considering:
+1. Is this a timing/loading issue?
+2. Is the selector still valid?
+3. Is the element interactive?
+4. Are there any prerequisite steps missing?
+
+Generate a modified interaction as a JSON object or respond with "GIVE UP" if unrecoverable."""
+
+        messages = [
+            {"role": "system", "content": self.system_prompt},
+            {"role": "user", "content": prompt}
+        ]
 
-        response = self.client.chat.completions.create(
-            model="gpt-4",
-            messages=[
-                {"role": "system", "content": self.system_prompt},
-                {"role": "user", "content": prompt}
-            ],
-            temperature=0
-        )
+        response, api_error = self._call_api(messages)
+        if api_error or not response:
+            return self.parse_task(task)
 
-        suggestion = response.choices[0].message.content
-        if suggestion == "GIVE UP":
+        content = response.choices[0].message.content
+        if content.strip() == "GIVE UP":
             return None
 
-        # Try to generate a new interaction based on GPT-4's suggestion
-        return self.parse_task(task)
+        try:
+            interaction_data = json.loads(content)
+            return WebInteraction(
+                action=interaction_data['action'],
+                selector_type=interaction_data['selector_type'],
+                selector_value=interaction_data['selector_value'],
+                input_text=interaction_data.get('input_text'),
+                description=f"Error recovery: {task['task']}",
+                wait_time=interaction_data.get('wait_time', 0),
+                validation=interaction_data.get('validation', {})
+            )
+        except Exception as e:
+            print(f"Error in error handling: {str(e)}")
+            return self.parse_task(task)
 
     def validate_result(self, task: Dict[str, Any], result: TaskResult) -> bool:
-        """Use GPT-4 to validate if the task was successful."""
+        """Enhanced validation using GPT-4 with detailed success criteria."""
         if result.error:
             return False
 
         prompt = f"""Task: {task['task']}
-Target Element HTML: {result.html_element}
-Was this interaction successful? Answer with just 'YES' or 'NO'."""
+Target Element: {json.dumps(result.html_element, indent=2)}
+Before State: {result.before_screenshot}
+After State: {result.after_screenshot}
+Validation Rules: {json.dumps(task.get('validation_rules', {}), indent=2)}
 
-        response = self.client.chat.completions.create(
-            model="gpt-4",
-            messages=[
-                {"role": "system", "content": self.system_prompt},
-                {"role": "user", "content": prompt}
-            ],
-            temperature=0
-        )
+Evaluate the interaction success based on:
+1. Element state changes (visibility, content, attributes)
+2. Page state changes (URL, dynamic content)
+3. Error messages or warnings
+4. Expected outcomes from validation rules
+
+Respond with:
+- "YES" if all success criteria are met
+- "NO" with a brief explanation if any criteria fail"""
+
+        messages = [
+            {"role": "system", "content": self.system_prompt},
+            {"role": "user", "content": prompt}
+        ]
 
-        return response.choices[0].message.content == "YES"
+        response, error = self._call_api(messages)
+        if error or not response:
+            return False
+
+        validation_result = response.choices[0].message.content.strip()
+
+        if validation_result.startswith("YES"):
+            return True
+        else:
+            failure_reason = validation_result.replace("NO", "").strip()
+            if failure_reason:
+                print(f"Validation failed: {failure_reason}")
+            return False
diff --git a/run.py b/run.py
index 4725ab7..efc56b2 100644
--- a/run.py
+++ b/run.py
@@ -4,8 +4,24 @@
 from parallel_runner import run_parallel_benchmark
 from serial_runner import run_serial_benchmark
 from evaluation.auto_eval import run_evaluation
+from models import GPT4Model, ClaudeModel, GeminiModel
 import os
 
+def get_model(model_name):
+    """Get the appropriate model based on the command line argument."""
+    load_dotenv()
+
+    models = {
+        'gpt4': lambda: GPT4Model(api_key=os.getenv("OPENAI_API_KEY")),
+        'claude': lambda: ClaudeModel(api_key=os.getenv("ANTHROPIC_API_KEY")),
+        'gemini': lambda: GeminiModel(api_key=os.getenv("GOOGLE_API_KEY"))
+    }
+
+    if model_name not in models:
+        raise ValueError(f"Model {model_name} not supported. Choose from: {', '.join(models.keys())}")
+
+    return models[model_name]()
+
 def main():
     parser = argparse.ArgumentParser(description='Run web automation tasks')
     parser.add_argument('--tasks', type=str, required=True, help='Path to tasks JSONL file')
@@ -18,9 +34,13 @@ def main():
     parser.add_argument('--evaluate', action='store_true', help='Run evaluation after benchmark')
     parser.add_argument('--evaluate-mode', type=str, choices=['serial', 'parallel'], default='parallel', help='Run evaluations serially or in parallel')
+    parser.add_argument('--model', choices=['gpt4', 'claude', 'gemini'], default='gpt4', help='Model to use for the benchmark')
     args = parser.parse_args()
 
+    # Initialize the selected model
+    model = get_model(args.model)
+
     # Create output directory
     output_dir = Path(args.output)
     output_dir.mkdir(parents=True, exist_ok=True)
@@ -30,16 +50,22 @@ def main():
         results = run_parallel_benchmark(
             tasks_file=args.tasks,
             output_dir=args.output,
+            model=model,
             max_workers=args.max_workers,
             save_accessibility_tree=args.save_accessibility_tree,
-            wait_time=args.wait_time
+            wait_time=args.wait_time,
+            evaluate=args.evaluate,
+            evaluate_mode=args.evaluate_mode
         )
     else:
         results = run_serial_benchmark(
             tasks_file=args.tasks,
             output_dir=args.output,
+            model=model,
             save_accessibility_tree=args.save_accessibility_tree,
-            wait_time=args.wait_time
+            wait_time=args.wait_time,
+            evaluate=args.evaluate,
+            evaluate_mode=args.evaluate_mode
        )
 
     # Save results
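
Notes (outside the diff): with the new --model flag wired into run.py, a Gemini run would look something like the line below. The tasks path is a placeholder; --tasks and --model are the only flags shown here that this change adds, and --evaluate/--evaluate-mode now propagate into the runners.

    python run.py --tasks tasks.jsonl --model gemini --evaluate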
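
Before a full benchmark run, the new model wiring can be sanity-checked in isolation. The snippet below is a minimal smoke-test sketch, not code from this change: it assumes the repo root is on PYTHONPATH, that GOOGLE_API_KEY is set (e.g. via the updated .env), and the task dict values are invented for illustration. It makes one real Gemini API call.

    # smoke_test_gemini.py -- illustrative sketch only
    import os
    from dotenv import load_dotenv
    from models import GeminiModel

    load_dotenv()
    model = GeminiModel(api_key=os.getenv("GOOGLE_API_KEY"))

    # Hypothetical task shaped like the dicts parse_task() expects
    task = {
        "task": "Click the search button",
        "target_element": {"type": "css", "value": "button#search"},
        "interaction": "click",
    }
    interaction = model.parse_task(task)
    print(interaction.action, interaction.selector_type, interaction.selector_value)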
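
The retry/backoff path in _call_api can also be exercised without network access by swapping in a stub client. This is one possible approach, with the stub classes and failure count invented for illustration; time.sleep is patched out so the backoff waits are skipped.

    # test_retry_sketch.py -- illustrative sketch only
    from unittest.mock import patch
    from models.gpt4 import GPT4Model

    class _Completions:
        """Stub that fails twice, then returns a canned chat completion."""
        def __init__(self):
            self.calls = 0

        def create(self, **kwargs):
            self.calls += 1
            if self.calls < 3:
                raise RuntimeError("transient failure")
            msg = type("Msg", (), {"content": '{"action": "click"}'})
            choice = type("Choice", (), {"message": msg})
            return type("Resp", (), {"choices": [choice]})

    class _Chat:
        def __init__(self):
            self.completions = _Completions()

    class _Client:
        def __init__(self):
            self.chat = _Chat()

    model = GPT4Model(api_key="not-a-real-key")
    model.client = _Client()  # bypass the real OpenAI client

    with patch("models.gpt4.time.sleep"):  # skip the real backoff waits
        response, failed = model._call_api([{"role": "user", "content": "hi"}])

    assert not failed
    assert model.client.chat.completions.calls == 3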