-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathevaluate.py
124 lines (108 loc) · 4.33 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import argparse
import json
import os
from muchomusic_eval.scoring import (
compare_answers,
extract_responses,
get_all_categories,
get_finegrained_genre_scores,
get_finegrained_knowledge_scores,
get_finegrained_reasoning_scores,
get_knowledge_scores,
get_reasoning_scores,
)
from muchomusic_eval.utils import format_dict
def genre_to_column_title(genre):
    """Return a comma-free CSV column title for a genre label.

    ``", "`` is replaced by ``"_"``, ``"& "`` is dropped and every remaining
    space becomes ``"_"`` (e.g. ``"Folk, World, & Country"`` ->
    ``"Folk_World_Country"``), so the label cannot split a CSV column.
    """
    return genre.replace(", ", "_").replace("& ", "").replace(" ", "_")


def build_csv_header(genre_titles, reasoning_categories, knowledge_categories):
    """Return the header line (newline-terminated) of ``results.csv``.

    Column order: run name, overall metrics, then the per-genre,
    per-reasoning-category and per-knowledge-category accuracies.
    """
    # NOTE(review): "reasonig_acc" is misspelled, but existing results.csv
    # files (and possibly the plotting code that reads them) already use this
    # name, so it is reproduced unchanged here -- confirm before renaming.
    return (
        ",".join(
            [
                "eval_name,accuracy,unanswered_rate,reasonig_acc,knowledge_acc",
                ",".join(genre_titles),
                ",".join(reasoning_categories),
                ",".join(knowledge_categories),
            ]
        )
        + "\n"
    )


def build_csv_row(
    eval_name,
    accuracy,
    unanswered_rate,
    reasoning_acc,
    knowledge_acc,
    genre_accs,
    reasoning_accs,
    knowledge_accs,
):
    """Return one data line (newline-terminated) matching ``build_csv_header``.

    The overall reasoning accuracy is written *before* the overall knowledge
    accuracy, which is the order the header declares.
    """
    overall = [eval_name, accuracy, unanswered_rate, reasoning_acc, knowledge_acc]
    return (
        ",".join(
            [
                ",".join(str(value) for value in overall),
                ",".join(str(value) for value in genre_accs),
                ",".join(str(value) for value in reasoning_accs),
                ",".join(str(value) for value in knowledge_accs),
            ]
        )
        + "\n"
    )


if __name__ == "__main__":
    # Command-line entry point: score a file of model outputs, print a JSON
    # summary, and (by default) append a row to <output_dir>/results.csv.
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", type=str, default="data/example_file.json")
    parser.add_argument("--output_dir", type=str, default="results")
    # NOTE(review): with action="store_false" the results ARE saved by default
    # and passing --save_results turns saving OFF. Kept as-is so existing
    # invocations behave the same.
    parser.add_argument("--save_results", action="store_false")
    parser.add_argument("--eval_name", type=str, default="default")
    args = parser.parse_args()

    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(args.input, encoding="utf-8") as f:
        eval_json = json.load(f)
    if not eval_json:
        parser.error(f"no evaluation entries found in {args.input}")

    # Pivot the list of per-question dicts into a dict of per-field lists,
    # using the keys of the first entry.
    eval_json = {k: [dic[k] for dic in eval_json] for k in eval_json[0]}

    model_responses = eval_json["model_output"]
    answers = eval_json["answers"]
    answer_orders = eval_json["answer_orders"]

    extracted_responses = extract_responses(
        model_responses,
        answers,
    )
    reasoning_scores = get_reasoning_scores(
        extracted_responses,
        eval_json,
    )
    knowledge_scores = get_knowledge_scores(
        extracted_responses,
        eval_json,
    )
    scores = compare_answers(
        extracted_responses,
        answer_orders,
    )

    ########################
    # Finegrained results
    ########################
    all_genres, knowledge_categories, reasoning_categories = get_all_categories()

    # Keep only the genres / categories that actually occur in this input,
    # preserving the order returned by get_all_categories().
    all_genres = [i for i in all_genres if i in eval_json["genre"]]
    contained_knowledge_cats = [i for j in eval_json["knowledge"] for i in j]
    contained_reasoning_cats = [i for j in eval_json["reasoning"] for i in j]
    knowledge_categories = [
        i for i in knowledge_categories if i in contained_knowledge_cats
    ]
    reasoning_categories = [
        i for i in reasoning_categories if i in contained_reasoning_cats
    ]

    genre_finegrained_scores = get_finegrained_genre_scores(
        extracted_responses, eval_json, all_genres
    )
    reasoning_finegrained_scores = get_finegrained_reasoning_scores(
        extracted_responses, eval_json, reasoning_categories
    )
    knowledge_finegrained_scores = get_finegrained_knowledge_scores(
        extracted_responses, eval_json, knowledge_categories
    )

    all_genre_titles = [genre_to_column_title(i) for i in all_genres]

    results = {
        "accuracy": scores["accuracy"],
        # Reported as the complement of the unanswered rate.
        "IFR": 1.0 - scores["unanswered_rate"],
        "knowledge": {
            "overall": knowledge_scores["accuracy"],
            "finegrained": {
                k: knowledge_finegrained_scores[k]["accuracy"]
                for k in knowledge_categories
            },
        },
        "reasoning": {
            "overall": reasoning_scores["accuracy"],
            "finegrained": {
                k: reasoning_finegrained_scores[k]["accuracy"]
                for k in reasoning_categories
            },
        },
    }
    # Return value is discarded -- presumably formats `results` in place.
    format_dict(results)
    print(json.dumps(results, indent=4))

    # Update results csv
    if args.save_results:
        # makedirs also creates missing parent directories and tolerates an
        # already existing directory.
        os.makedirs(args.output_dir, exist_ok=True)
        results_csv_path = os.path.join(args.output_dir, "results.csv")

        # The header is only written when the file is first created.
        write_header = not os.path.exists(results_csv_path)
        with open(results_csv_path, "a") as f:
            if write_header:
                f.write(
                    build_csv_header(
                        all_genre_titles,
                        reasoning_categories,
                        knowledge_categories,
                    )
                )
            # Reasoning comes before knowledge, matching the header. Rows
            # appended by earlier versions of this script have these two
            # values swapped.
            f.write(
                build_csv_row(
                    args.eval_name,
                    scores["accuracy"],
                    scores["unanswered_rate"],
                    reasoning_scores["accuracy"],
                    knowledge_scores["accuracy"],
                    [genre_finegrained_scores[genre]["accuracy"] for genre in all_genres],
                    [
                        reasoning_finegrained_scores[reasoning]["accuracy"]
                        for reasoning in reasoning_categories
                    ],
                    [
                        knowledge_finegrained_scores[knowledge]["accuracy"]
                        for knowledge in knowledge_categories
                    ],
                )
            )

        # Imported here so the plotting module is only loaded when results
        # are actually saved.
        from muchomusic_eval.plot_utils import spider_plot

        spider_plot(results_csv_path, args.output_dir)