Commit ca959d4

so many changes
dhruvahuja19 committed Dec 17, 2024
1 parent 082ab98 commit ca959d4
Showing 13 changed files with 498 additions and 312 deletions.
17 changes: 17 additions & 0 deletions analyze_results.py
@@ -0,0 +1,17 @@
import json
from pathlib import Path

# Read results file
results_file = Path('results/benchmark_results.json/results.json')
with open(results_file) as f:
results = json.load(f)

# Calculate success percentage
total_tasks = len(results)
successful_tasks = sum(1 for result in results if result.get('success', False))
success_percentage = (successful_tasks / total_tasks) * 100 if total_tasks > 0 else 0

print(f"\nResults Analysis:")
print(f"Total Tasks: {total_tasks}")
print(f"Successful Tasks: {successful_tasks}")
print(f"Success Rate: {success_percentage:.2f}%")
92 changes: 1 addition & 91 deletions data/dom_tasks.jsonl

Large diffs are not rendered by default.

29 changes: 17 additions & 12 deletions evaluation/auto_eval.py
@@ -13,7 +13,7 @@ def run_serial_evaluation(
results_dir: Path,
output_file: Path,
openai_key: str
) -> None:
) -> Dict[str, Any]:
"""Run evaluation on task results serially"""
# Initialize OpenAI client
client = OpenAI(api_key=openai_key)
@@ -59,7 +59,7 @@ def run_serial_evaluation(
"success": result["success"],
"visual_score": visual_score,
"html_score": html_score,
"final_score": (visual_score + html_score) / 2,
"final_score": (0.8 * visual_score + 0.2 * html_score),
"visual_reasoning": visual_reasoning,
"html_reasoning": html_reasoning
}
@@ -76,23 +76,28 @@ def run_serial_evaluation(
"error": str(e)
})

# Save evaluations to output file
with output_file.open('w') as f:
json.dump({
"total_tasks": len(tasks),
"successful_tasks": sum(1 for e in evaluations if e.get("success", False)),
"evaluations": evaluations
}, f, indent=2)
evaluation_results = {
"total_tasks": len(tasks),
"successful_tasks": sum(1 for e in evaluations if e.get("success", False)),
"evaluations": evaluations
}

# Save evaluations if output file is provided
if output_file:
with output_file.open('w') as f:
json.dump(evaluation_results, f, indent=2)

return evaluation_results

def run_evaluation(
tasks_file: Path,
results_dir: Path,
output_file: Path,
openai_key: str,
max_workers: int = None
) -> None:
) -> Dict[str, Any]:
"""Run evaluation on task results using either serial or parallel mode"""
if max_workers:
run_parallel_evaluation(tasks_file, results_dir, output_file, openai_key, max_workers)
return run_parallel_evaluation(tasks_file, results_dir, output_file, openai_key, max_workers)
else:
run_serial_evaluation(tasks_file, results_dir, output_file, openai_key)
return run_serial_evaluation(tasks_file, results_dir, output_file, openai_key)
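
With run_evaluation now returning the evaluation dict (and only writing it when an output file is given), callers can work with the scores directly. A minimal usage sketch, assuming the keys shown above and placeholder paths:

# Hypothetical caller; the paths and OPENAI_API_KEY lookup are assumptions.
import os
from pathlib import Path

from evaluation.auto_eval import run_evaluation

results = run_evaluation(
    tasks_file=Path("data/dom_tasks.jsonl"),
    results_dir=Path("results"),
    output_file=Path("results/evaluation.json"),
    openai_key=os.environ["OPENAI_API_KEY"],
)

# final_score is weighted 80% visual, 20% HTML per task.
scores = [e.get("final_score", 0.0) for e in results["evaluations"]]
mean_score = sum(scores) / len(scores) if scores else 0.0
print(f"{results['successful_tasks']}/{results['total_tasks']} tasks succeeded, "
      f"mean final score {mean_score:.2f}")
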
16 changes: 16 additions & 0 deletions evaluation/fuzzy_match.py
@@ -35,6 +35,22 @@ def fuzzy_match_html(

client = openai_client

# Truncate inputs if too long
max_html_length = 2000 # Characters per HTML string
max_task_length = 500 # Characters for task description

if len(actual_html) > max_html_length:
actual_html = actual_html[:max_html_length] + "..."
logger.warning("Actual HTML was truncated due to length")

if len(expected_html) > max_html_length:
expected_html = expected_html[:max_html_length] + "..."
logger.warning("Expected HTML was truncated due to length")

if len(task_description) > max_task_length:
task_description = task_description[:max_task_length] + "..."
logger.warning("Task description was truncated due to length")

user_prompt = f"""You are evaluating if an HTML element matches the expected element for the following task: {task_description}
Expected HTML: {expected_html}
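
The three truncation blocks repeat the same pattern, so they could be folded into one helper; a sketch of that refactor (the helper name and its placement are assumptions, not part of this changeset):

def _truncate(text: str, max_length: int, label: str) -> str:
    # Hypothetical helper mirroring the inline truncation logic above.
    if len(text) > max_length:
        logger.warning(f"{label} was truncated due to length")
        return text[:max_length] + "..."
    return text

# Possible usage inside fuzzy_match_html:
#   actual_html = _truncate(actual_html, 2000, "Actual HTML")
#   expected_html = _truncate(expected_html, 2000, "Expected HTML")
#   task_description = _truncate(task_description, 500, "Task description")
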
37 changes: 32 additions & 5 deletions evaluation/image_match.py
@@ -21,7 +21,11 @@

def encode_image(image_path):
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode('utf-8')
image_data = image_file.read()
# Check file size (max 20MB)
if len(image_data) > 20 * 1024 * 1024:
raise ValueError(f"Image {image_path} is too large (>20MB)")
return base64.b64encode(image_data).decode('utf-8')

def compare_images(prompt, ground_truth_path, agent_image_path, note = None, openai_client = None):
if openai_client is None:
@@ -42,17 +46,40 @@ def compare_images(prompt, ground_truth_path, agent_image_path, note = None, ope
logger.debug("Using provided OpenAI client")
client = openai_client

image1 = encode_image(ground_truth_path)
image2 = encode_image(agent_image_path)
try:
image1 = encode_image(ground_truth_path)
image2 = encode_image(agent_image_path)
except ValueError as e:
logger.error(f"Image encoding error: {str(e)}")
return False, f"Image processing error: {str(e)}"

# Truncate prompt if too long
max_prompt_length = 500
if len(prompt) > max_prompt_length:
prompt = prompt[:max_prompt_length] + "..."

user_prompt = f"The agent was trying to accomplish the following task: {prompt} The first image is the expected image and the second image is the agent's output. Does the image answer the question correctly as the expected image? Don't focus on unnecessary details, like axes titles or colors or image size or labels unless specified in the task."
if note:
# Truncate note if too long
if len(note) > 200:
note = note[:200] + "..."
user_prompt += f"Here are some notes to help you evaluate the images: {note}"
messages = [
{"role": "system", "content": system_prompt},
{
"role": "system",
"content": """You are evaluating if a web automation task was completed successfully. Compare the screenshots and determine if the task's goal was achieved, focusing on the relevant UI changes that indicate success.
Return a JSON object with:
- correctness (boolean): Whether the task was completed successfully
- reason (string): Clear explanation of your evaluation"""
},
{
"role": "user",
"content": [
{"type": "text", "text": user_prompt},
{
"type": "text",
"text": user_prompt
},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{image1}"}
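
A quick way to exercise the new 20MB guard in encode_image, assuming the module is importable as evaluation.image_match; the temporary file is purely illustrative:

# Hypothetical check of the size guard; not part of the changeset.
import os
import tempfile

from evaluation.image_match import encode_image

with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
    tmp.write(b"\xff" * (21 * 1024 * 1024))  # 21MB of filler, over the limit
    oversized_path = tmp.name

try:
    encode_image(oversized_path)
except ValueError as e:
    print(f"Rejected as expected: {e}")
finally:
    os.remove(oversized_path)
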
83 changes: 52 additions & 31 deletions evaluation/parallel_eval.py
@@ -4,6 +4,7 @@
from typing import Dict, Any, List, Tuple
from openai import OpenAI
from concurrent.futures import ThreadPoolExecutor, as_completed
import time

from evaluation.image_match import compare_images
from evaluation.fuzzy_match import fuzzy_match_html
@@ -32,16 +33,19 @@ def evaluate_task(task: Dict[str, Any], result: Dict[str, Any], client: OpenAI)
visual_score = 1.0 if visual_correctness else 0.0
html_score = 1.0 if html_correctness else 0.0

# Calculate final score: 80% visual, 20% HTML
final_score = (0.8 * visual_score) + (0.2 * html_score)

evaluation = {
"task_id": task_id,
"success": result["success"],
"visual_score": visual_score,
"html_score": html_score,
"final_score": (visual_score + html_score) / 2,
"final_score": final_score,
"visual_reasoning": visual_reasoning,
"html_reasoning": html_reasoning
}
logging.info(f"Evaluated task {task_id}: score={evaluation.get('final_score', 0.0):.2f}")
logging.info(f"Evaluated task {task_id}: score={final_score:.2f}")
return evaluation
except Exception as e:
logging.error(f"Error evaluating task {task_id}: {str(e)}")
Expand All @@ -60,7 +64,7 @@ def run_parallel_evaluation(
output_file: Path,
openai_key: str,
max_workers: int = 4
) -> None:
) -> Dict[str, Any]:
"""Run evaluation on task results in parallel"""
# Initialize OpenAI client
client = OpenAI(api_key=openai_key)
@@ -84,33 +88,50 @@ def run_parallel_evaluation(
if result:
task_pairs.append((task, result))

# Run evaluations in parallel
with ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_task = {
executor.submit(evaluate_task, task, result, client): task_id
for task, result in task_pairs
}
# Process tasks in smaller batches to avoid rate limits
batch_size = min(max_workers, 3) # Process at most 3 tasks at a time
for i in range(0, len(task_pairs), batch_size):
batch = task_pairs[i:i + batch_size]
logging.info(f"Processing evaluation batch {i//batch_size + 1}/{(len(task_pairs) + batch_size - 1)//batch_size}")

for future in as_completed(future_to_task):
try:
evaluation = future.result()
evaluations.append(evaluation)
except Exception as e:
task_id = future_to_task[future]
logging.error(f"Error in evaluation future for task {task_id}: {str(e)}")
evaluations.append({
"task_id": task_id,
"success": False,
"visual_score": 0.0,
"html_score": 0.0,
"final_score": 0.0,
"error": str(e)
})
# Run evaluations in parallel for this batch
with ThreadPoolExecutor(max_workers=batch_size) as executor:
future_to_task = {
executor.submit(evaluate_task, task, result, client): task['id']
for task, result in batch
}

for future in as_completed(future_to_task):
try:
evaluation = future.result(timeout=60) # 60 second timeout per evaluation
evaluations.append(evaluation)
logging.info(f"Completed evaluation for task {future_to_task[future]}")
except Exception as e:
task_id = future_to_task[future]
error_msg = f"Error in evaluation future for task {task_id}: {str(e)}"
logging.error(error_msg)
evaluations.append({
"task_id": task_id,
"success": False,
"visual_score": 0.0,
"html_score": 0.0,
"final_score": 0.0,
"error": error_msg
})

# Add a small delay between batches to avoid rate limits
if i + batch_size < len(task_pairs):
time.sleep(1)

evaluation_results = {
"total_tasks": len(tasks),
"successful_tasks": sum(1 for e in evaluations if e.get("success", False)),
"evaluations": evaluations
}

# Save evaluations to output file
with output_file.open('w') as f:
json.dump({
"total_tasks": len(tasks),
"successful_tasks": sum(1 for e in evaluations if e.get("success", False)),
"evaluations": evaluations
}, f, indent=2)
# Save evaluations if output file is provided
if output_file:
with output_file.open('w') as f:
json.dump(evaluation_results, f, indent=2)

return evaluation_results
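
The batching scheme above (at most three concurrent requests, a per-future timeout, and a one-second pause between batches) is what keeps the evaluator under API rate limits. A stripped-down sketch of that pattern in isolation, with placeholder work in place of the OpenAI calls:

# Minimal illustration of the batch-and-pause pattern; slow_call stands in
# for evaluate_task and its API round trips.
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

def slow_call(item):
    time.sleep(0.1)  # placeholder for an API request
    return item * 2

items = list(range(10))
batch_size = 3
results = []

for i in range(0, len(items), batch_size):
    batch = items[i:i + batch_size]
    with ThreadPoolExecutor(max_workers=batch_size) as executor:
        futures = {executor.submit(slow_call, item): item for item in batch}
        for future in as_completed(futures):
            results.append(future.result(timeout=60))
    if i + batch_size < len(items):
        time.sleep(1)  # back off between batches to respect rate limits

print(sorted(results))
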
14 changes: 6 additions & 8 deletions models/base.py
@@ -14,26 +14,24 @@ class WebInteraction:

@dataclass
class TaskResult:
"""Represents the result of executing a task."""
"""Class to store task execution results"""
task_id: str
success: bool
before_screenshot: Optional[str] = None
after_screenshot: Optional[str] = None
error: Optional[str] = None
html_element: Optional[str] = None
after_screenshot: Optional[str] = None
accessibility_tree: Optional[Dict[str, Any]] = None
error: Optional[str] = None
metadata: Optional[Dict[str, Any]] = None

def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for serialization."""
"""Convert to dictionary format"""
return {
"task_id": self.task_id,
"success": self.success,
"before_screenshot": self.before_screenshot,
"after_screenshot": self.after_screenshot,
"error": self.error,
"html_element": self.html_element,
"after_screenshot": self.after_screenshot,
"accessibility_tree": self.accessibility_tree,
"error": self.error,
"metadata": self.metadata
}

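A short usage sketch of the reordered TaskResult dataclass, assuming models is importable as a package; the field values are placeholders:

# Hypothetical construction and serialization; values are illustrative.
from models.base import TaskResult

result = TaskResult(
    task_id="task_001",
    success=True,
    before_screenshot="results/task_001_before.png",
    after_screenshot="results/task_001_after.png",
    html_element="<button id='submit'>Submit</button>",
)

record = result.to_dict()
assert record["error"] is None               # optional fields default to None
assert record["accessibility_tree"] is None
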
10 changes: 2 additions & 8 deletions models/gemini.py
@@ -110,10 +110,7 @@ def parse_task(self, task: Dict[str, Any]) -> WebInteraction:
selector_type=interaction_data.get('selector_type', task['target_element']['type']),
selector_value=interaction_data.get('selector_value', task['target_element']['value']),
input_text=interaction_data.get('input_text'),
description=task['task'],
wait_time=interaction_data.get('wait_time', 0),
hover_duration=interaction_data.get('hover_duration', 0),
validation=interaction_data.get('validation', {})
description=task['task']
)
except Exception as e:
print(f"Error parsing Gemini response: {str(e)}")
@@ -167,10 +164,7 @@ def handle_error(self, task: Dict[str, Any], error: str) -> Optional[WebInteract
selector_type=interaction_data['selector_type'],
selector_value=interaction_data['selector_value'],
input_text=interaction_data.get('input_text'),
description=f"Error recovery: {task['task']}",
wait_time=interaction_data.get('wait_time', 0),
hover_duration=interaction_data.get('hover_duration', 0),
validation=interaction_data.get('validation', {})
description=f"Error recovery: {task['task']}"
)
except Exception as e:
print(f"Error in error handling: {str(e)}")