-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluate.py
117 lines (110 loc) · 3.66 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from pages.search import run, create_args
from tabulate import tabulate # type: ignore
from utils import fix_arxiv_link, eval_datasets
import numpy as np
from constants import *
if __name__ == "__main__":
    # Evaluate metadata-extraction models over either the Masader
    # valid/test split or an ad-hoc comma-separated keyword list, then
    # print a per-model score table averaged over all datasets.
    args = create_args()
    metric_results = {}  # model name -> list of per-dataset metric rows

    if args.masader_validate or args.masader_test:
        # Pull dataset metadata from the chosen evaluation split.
        use_split = "valid" if args.masader_validate else "test"
        dataset = eval_datasets[args.schema][use_split]
        titles = [str(x["Paper Title"]) for x in dataset]
        data_names = [str(x["Name"]) for x in dataset]
        paper_links = [str(x["Paper Link"]) for x in dataset]
        years = [str(x["Year"]) for x in dataset]
        links = [x["Link"] for x in dataset]
    else:
        # Ad-hoc mode: dataset names come from the command line; no
        # paper metadata is available, so search falls back to keywords.
        use_split = None
        data_names = args.keywords.split(",")
        titles = ["" for _ in data_names]
        paper_links = ["" for _ in data_names]
        years = ["" for _ in data_names]
        links = ["" for _ in data_names]

    models = args.models.split(",")
    len_data = len(data_names)
    # Mutable [current, total] progress counter shared with run().
    curr_idx = [0, len_data * len(models)]

    for data_name, title, paper_link, year, link in zip(
        data_names, titles, paper_links, years, links
    ):
        if title != "":
            # Normalize the title for use as a search query.
            args.keywords = title.replace("\r\n", " ").replace(":", "")
        else:
            args.keywords = data_name

        # Keyword arguments identical across both search modes.
        # models is re-split per call, matching the original behavior
        # in case run() mutates the list it receives.
        common_kwargs = dict(
            mode="api",
            month=None,
            models=args.models.split(","),
            browse_web=args.browse_web,
            overwrite=args.overwrite,
            use_split=use_split,
            repo_link=link,
            summarize=args.summarize,
            curr_idx=curr_idx,
            schema=args.schema,
            use_pdf=args.use_pdf,
            few_shot=args.few_shot,
        )
        if paper_link != "":
            # Prefer resolving by direct paper link when one exists.
            model_results = run(
                link=fix_arxiv_link(paper_link),
                year=year,
                **common_kwargs,
            )
        else:
            model_results = run(
                keywords=args.keywords,
                year=None,
                **common_kwargs,
            )

        # Accumulate each model's validation metrics for this dataset.
        for model_name, results in model_results.items():
            metric_results.setdefault(model_name, []).append(
                list(results["validation"].values())
            )

    # Only report models that produced a result for every dataset.
    results = []
    for model_name in metric_results:
        if len(metric_results[model_name]) == len_data:
            results.append(
                [model_name]
                + (np.mean(metric_results[model_name], axis=0) * 100).tolist()
            )

    headers = [
        "MODEL",
        "CONTENT",
        # NOTE(review): likely a typo for ACCESSIBILITY — kept byte-identical
        # in case downstream tooling matches this header; confirm before renaming.
        "ACCESSABILITY",
        "DIVERSITY",
        "EVALUATION",
        "AVERAGE",
    ]
    print(
        tabulate(
            sorted(results, key=lambda x: x[-1]),  # ascending by AVERAGE
            headers=headers,
            tablefmt="grid",
            floatfmt=".2f",
        )
    )