-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathstat_multi.py
27 lines (24 loc) · 1.03 KB
/
stat_multi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import pandas as pd
import ast
# Per-position safe/unsafe tally over the multi-turn AdvBench result files.
# For every shot count (and reward setting) the matching CSV is loaded, each
# row's "judgement" cell is parsed as a Python literal, and the labels found
# at the first MAX_POSITIONS list indices are counted and printed.
MAX_POSITIONS = 10


def _tally_judgements(judgement_cells, max_positions=MAX_POSITIONS):
    """Count "safe"/"unsafe" labels per list position.

    Args:
        judgement_cells: Iterable of raw cell values. Each one is handed to
            ``ast.literal_eval`` and is presumably the string form of a list
            such as ``"['safe', 'unsafe']"`` (confirm against the CSV writer).
        max_positions: Number of leading list positions to tally.

    Returns:
        Dict mapping 1-based position to ``{"safe": int, "unsafe": int}``.
        Cells that cannot be parsed, or that parse to something other than a
        list/tuple, are reported on stdout and skipped rather than aborting
        the whole run.
    """
    tally = {
        pos: {"safe": 0, "unsafe": 0} for pos in range(1, max_positions + 1)
    }
    for cell in judgement_cells:
        # Keep the try body minimal: only the parse itself is expected to fail.
        try:
            parsed = ast.literal_eval(cell)
        except (ValueError, SyntaxError, TypeError):
            print(f"Error parsing judgement string: {cell}")
            continue
        if not isinstance(parsed, (list, tuple)):
            # A valid literal that is not a sequence of labels (e.g. None or a
            # bare number) has no per-position entries; previously this raised
            # an uncaught TypeError from len().
            print(f"Error parsing judgement string: {cell}")
            continue
        for pos, label in enumerate(parsed[:max_positions], start=1):
            # Labels other than exactly "safe"/"unsafe" are ignored.
            if label == "safe":
                tally[pos]["safe"] += 1
            elif label == "unsafe":
                tally[pos]["unsafe"] += 1
    return tally


for shot in [4, 16, 64, 256]:
    for reward in [0]:
        # Files for reward == 1 carry an "rl_" prefix; only reward 0 is
        # iterated at the moment.
        fix = "rl_" if reward == 1 else ""
        df = pd.read_csv(
            f"result_chat/advbench/multiturn_{fix}Llama-3.1-70B-Instruct_{shot}shot.csv"
        )
        counts = _tally_judgements(df["judgement"])

        print("reward:", bool(reward), "shot:", shot)
        for i in range(1, MAX_POSITIONS + 1):
            print(
                f"Position {i}: safe:unsafe = {counts[i]['safe']}:{counts[i]['unsafe']}"
            )