# Copyright 2024 The T5 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for GLUE and SuperGLUE tasks."""
import collections
#import functools
#from t5.data import postprocessors
#from t5.data import preprocessors
import t5_metrics as metrics
# These weights are based on the number of examples in each dataset.
SUPER_GLUE_WEIGHT_MAPPING = {
    "dpr_v001_simple": 1_322.,
    "super_glue_wsc_v102_simple_train": 259.,
    "super_glue_wsc_v102_simple_eval": 0.,
    "super_glue_boolq_v102": 9_427.,
    "super_glue_cb_v102": 250.,
    "super_glue_copa_v102": 400.,
    "super_glue_multirc_v102": 27_243.,
    "super_glue_record_v102": 138_854.,
    "super_glue_rte_v102": 2_490.,
    "super_glue_wic_v102": 5_428.,
    "super_glue_axb_v102": 0.,
    "super_glue_axg_v102": 0.,
}
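
# Illustrative sketch (an assumption, not part of the original file): weights
# like these are conventionally used for examples-proportional mixing, where
# each task is sampled in proportion to its training-set size and zero-weight
# entries (the eval-only sets) are never sampled during training:
#
#   total = sum(SUPER_GLUE_WEIGHT_MAPPING.values())
#   rates = {task: w / total for task, w in SUPER_GLUE_WEIGHT_MAPPING.items()}
#   # rates["super_glue_record_v102"] is roughly 0.75, so ReCoRD dominates.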
# Mappings for the SuperGLUE tasks with sentinel tokens added.
SUPER_GLUE_WEIGHT_MAPPING_SENTINEL = {
    "dpr_v001_simple_1_sentinel": 1_322.,
    "super_glue_wsc_v102_simple_1_sentinel_train": 259.,
    "super_glue_wsc_v102_simple_1_sentinel_eval": 0.,
    "super_glue_boolq_v102_1_sentinel": 9_427.,
    "super_glue_cb_v102_1_sentinel": 250.,
    "super_glue_copa_v102_1_sentinel": 400.,
    "super_glue_multirc_v102_1_sentinel": 27_243.,
    "super_glue_record_v102_1_sentinel": 138_854.,
    "super_glue_rte_v102_1_sentinel": 2_490.,
    "super_glue_wic_v102_1_sentinel": 5_428.,
    "super_glue_axb_v102_1_sentinel": 0.,
    "super_glue_axg_v102_1_sentinel": 0.,
}
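
# Note: each key above is its SUPER_GLUE_WEIGHT_MAPPING counterpart with a
# "_1_sentinel" suffix inserted before any "_train"/"_eval" split suffix:
#
#   "super_glue_boolq_v102"            -> "super_glue_boolq_v102_1_sentinel"
#   "super_glue_wsc_v102_simple_train" -> "super_glue_wsc_v102_simple_1_sentinel_train"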
# These weights are based on the number of examples in each dataset.
# We omit WNLI because we train on WSC/DPR simple instead.
GLUE_WEIGHT_MAPPING = {
    "glue_cola_v002": 8_551.,
    "glue_sst2_v002": 67_349.,
    "glue_mrpc_v002": 3_668.,
    "glue_qqp_v002": 363_849.,
    "glue_stsb_v002": 5_749.,
    "glue_mnli_v002": 392_702.,
    "glue_qnli_v002": 104_743.,
    "glue_rte_v002": 2_490.,
    "glue_mnli_mismatched_v002": 0.,
    "glue_mnli_matched_v002": 0.,
    "glue_ax_v002": 0.,
}
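
# Hedged usage sketch: if these weights seed a training mixture (as in the T5
# recipe), the zero-weight evaluation-only entries are filtered out first; the
# mixture name below is hypothetical:
#
#   train_tasks = [(t, w) for t, w in GLUE_WEIGHT_MAPPING.items() if w > 0]
#   # seqio.MixtureRegistry.add("glue_v002_proportional", train_tasks)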


def get_glue_weight_mapping():
  return GLUE_WEIGHT_MAPPING


def get_super_glue_weight_mapping():
  return SUPER_GLUE_WEIGHT_MAPPING


def get_super_glue_weight_mapping_sentinel():
  return SUPER_GLUE_WEIGHT_MAPPING_SENTINEL


'''
def get_glue_text_preprocessor(builder_config):
  """Return the glue preprocessor.

  Args:
    builder_config: a BuilderConfig
  Returns:
    a preprocessor function
  """
  # stsb uses a floating point target, so use special preprocessor
  if builder_config.name == "stsb":
    return preprocessors.stsb
  elif builder_config.name == "wsc.fixed":
    return preprocessors.wsc
  elif builder_config.name == "record":
    return preprocessors.record
  else:
    if "mnli" in builder_config.name or builder_config.name == "ax":
      # Cast the GLUE diagnostic task as MNLI.
      benchmark_name = "mnli"
    elif builder_config.name in ["axb", "axg"]:
      # Cast the SuperGLUE diagnostic tasks as RTE.
      benchmark_name = "rte"
    else:
      benchmark_name = builder_config.name
    if builder_config.name == "multirc":
      feature_names = ("question", "answer", "paragraph")
    elif builder_config.name == "wic":
      # This ignores the start/end indices which show where in each sentence
      # the word appears.
      # TODO(craffel): Investigate using those indices.
      feature_names = ("sentence1", "sentence2", "word")
    else:
      feature_names = None
    return functools.partial(
        preprocessors.glue,
        benchmark_name=benchmark_name,
        label_names=builder_config.label_classes,
        feature_names=feature_names)
'''
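
# Example of the routing in the disabled helper above, were it re-enabled
# (ax_config / axb_config are hypothetical BuilderConfig instances):
#
#   get_glue_text_preprocessor(ax_config)   # GLUE diagnostic: cast as MNLI
#   get_glue_text_preprocessor(axb_config)  # SuperGLUE diagnostic: cast as RTE
#
# Both return a functools.partial over preprocessors.glue with the matching
# benchmark_name ("mnli" or "rte") baked in.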

'''
def get_glue_postprocess_fn(builder_config):
  if builder_config.name == "stsb":
    return postprocessors.string_to_float
  elif builder_config.name == "multirc":
    return postprocessors.multirc
  elif builder_config.name == "record":
    return postprocessors.record
  else:
    return functools.partial(
        postprocessors.string_label_to_class_id,
        label_classes=builder_config.label_classes,
    )
'''
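
# Behavior sketch of the t5.data postprocessors referenced above (based on the
# t5 library's conventions; verify against your installed version):
#
#   postprocessors.string_to_float("3.8")  # -> 3.8, for STS-B regression
#   postprocessors.string_label_to_class_id(
#       "entailment", label_classes=["entailment", "not_entailment"])  # -> 0
#   # Strings outside label_classes map to a default of -1 ("invalid").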

GLUE_METRICS = collections.OrderedDict([
    ("cola", [metrics.sklearn_metrics_wrapper(
        "matthews_corrcoef", metric_post_process_fn=lambda x: 100 * x)]),
    ("sst2", [metrics.accuracy]),
    ("mrpc", [metrics.f1_score_with_invalid, metrics.accuracy]),
    ("stsb", [metrics.pearson_corrcoef, metrics.spearman_corrcoef]),
    ("qqp", [metrics.f1_score_with_invalid, metrics.accuracy]),
    ("mnli", [metrics.accuracy]),
    ("mnli_matched", [metrics.accuracy]),
    ("mnli_mismatched", [metrics.accuracy]),
    ("qnli", [metrics.accuracy]),
    ("rte", [metrics.accuracy]),
    ("wnli", [metrics.accuracy]),
    ("ax", []),  # Only test set available.
])


def get_glue_metric(task_name):
  return GLUE_METRICS[task_name]
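
# Usage sketch, assuming t5_metrics follows the t5.evaluation.metrics
# convention of metric_fn(targets, predictions) -> {metric_name: score}:
#
#   for metric_fn in get_glue_metric("mrpc"):
#     print(metric_fn(targets, predictions))
#   # f1_score_with_invalid yields {"f1": ...}, then accuracy {"accuracy": ...}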

SUPERGLUE_METRICS = collections.OrderedDict([
    # Keyed by TFDS builder config name; upstream t5 uses "boolq" (not
    # "bool_q"), matching the task names above.
    ("boolq", [metrics.accuracy]),
    ("cb", [metrics.mean_multiclass_f1(num_classes=3), metrics.accuracy]),
    ("copa", [metrics.accuracy]),
    ("multirc", [
        metrics.multirc_f1_over_all_answers,
        metrics.mean_group_metric(metrics.all_match)
    ]),
    # ("record", [metrics.deduplicate_metric(metrics.squad)]),
    ("record", [metrics.squad]),
    ("rte", [metrics.accuracy]),
    ("wic", [metrics.accuracy]),
    ("axb", []),  # Only test set available.
    ("axg", []),  # Only test set available.
])


def get_super_glue_metric(task_name):
  return SUPERGLUE_METRICS[task_name]
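
# Usage sketch, mirroring get_glue_metric above. MultiRC is the notable case:
# it pairs answer-level F1 with a per-question exact-match score, and (if this
# follows the t5 convention) expects each target/prediction as a dict carrying
# the answer value together with its question group:
#
#   for metric_fn in get_super_glue_metric("multirc"):
#     print(metric_fn(targets, predictions))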