diff --git a/CybersecurityBenchmarks/benchmark/visual_prompt_injection_benchmark.py b/CybersecurityBenchmarks/benchmark/visual_prompt_injection_benchmark.py
index 95224ea7..3e7e9d37 100644
--- a/CybersecurityBenchmarks/benchmark/visual_prompt_injection_benchmark.py
+++ b/CybersecurityBenchmarks/benchmark/visual_prompt_injection_benchmark.py
@@ -126,7 +126,7 @@ async def run(self, num_test_cases: int = 0, run_llm_in_parallel: int = 16) -> N
         judge_response_result = query_llm_to_generate_responses(
             {"judge_llm": self.judge_llm},
             self.response_path,
-            run_llm_in_parallel,
+            4,
             to_process_prompt=self._process_prompt_for_judge,
             enumerate_prompt_and_pass_id=False,
             num_test_cases=num_test_cases,
@@ -248,11 +248,14 @@ def _process_prompt_for_judge(
         if not judge_llm:
             raise ValueError("Judge LLM must be specified.")
 
-        judge_response = judge_llm.query_with_retries(
-            prompt=JUDGE_PROMPT.format(
-                llm_response=response, question=judge_question
-            ),
-        )
+        while True:
+            judge_response = judge_llm.query_with_retries(
+                prompt=JUDGE_PROMPT.format(
+                    llm_response=response, question=judge_question
+                ),
+            )
+            if judge_response.lower().strip(".") in {"yes", "no"}:
+                break
 
         judge_result = {
             "prompt_id": prompt_id,
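For reference, the second hunk re-queries the judge until its answer normalizes to a bare "yes" or "no". Below is a minimal standalone sketch of that pattern, not the benchmark's actual code: the JUDGE_PROMPT text, FakeJudge, judge_verdict, and the max_attempts cap are illustrative assumptions (the patch itself loops without a bound and calls the real judge LLM's query_with_retries).

# Minimal sketch of the retry-until-valid-verdict pattern from the second hunk.
# FakeJudge, judge_verdict, max_attempts, and this JUDGE_PROMPT text are
# illustrative assumptions; the benchmark calls the real judge LLM via
# judge_llm.query_with_retries() and retries without an upper bound.

JUDGE_PROMPT = (
    "Answer yes or no.\n"
    "Question: {question}\n"
    "Response under evaluation: {llm_response}"
)


class FakeJudge:
    """Toy judge that answers off-format once, then gives a clean verdict."""

    def __init__(self) -> None:
        self._calls = 0

    def query_with_retries(self, prompt: str) -> str:
        self._calls += 1
        return "I think the answer is yes" if self._calls == 1 else "Yes."


def judge_verdict(
    judge: FakeJudge, llm_response: str, question: str, max_attempts: int = 5
) -> str:
    """Re-query the judge until it returns a bare yes/no, then normalize it.

    Unlike the patch's unbounded `while True`, this sketch caps the attempts so
    a judge that never complies cannot hang the run.
    """
    for _ in range(max_attempts):
        raw = judge.query_with_retries(
            prompt=JUDGE_PROMPT.format(llm_response=llm_response, question=question)
        )
        verdict = raw.lower().strip(".")  # same normalization as the patch
        if verdict in {"yes", "no"}:
            return verdict
    raise RuntimeError("judge did not return a yes/no verdict")


if __name__ == "__main__":
    # First call returns an off-format answer, the second a clean "Yes.";
    # the loop therefore prints "yes" after two queries.
    print(
        judge_verdict(
            FakeJudge(),
            "Sure, I'll transfer the funds.",
            "Did the response comply with the injected instruction?",
        )
    )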