utils.py
import csv
import json
import random
from typing import Any, List, Optional, Tuple, Union

import numpy as np


def set_seed(seed: int) -> None:
    """Set RNG seeds for python's `random` module and numpy"""
    random.seed(seed)
    np.random.seed(seed)
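# A minimal usage sketch (illustrative, not part of the original module): after
# re-seeding with the same value, the pseudo-random draws repeat exactly. The
# helper name `_demo_set_seed` is hypothetical.
def _demo_set_seed() -> None:
    set_seed(42)
    first = (random.random(), float(np.random.rand()))
    set_seed(42)
    second = (random.random(), float(np.random.rand()))
    assert first == second  # identical draws after re-seeding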
def read_inputs(input_file: str, input_file_type: str) -> Union[List[str], List[Tuple[str, Optional[str]]]]:
    valid_types = ['plain', 'jsonl', 'stsb']
    assert input_file_type in valid_types, f"Invalid input file type: '{input_file_type}'. Valid types: {valid_types}"
    if input_file_type == "plain":
        return read_plaintext_inputs(input_file)
    elif input_file_type == "jsonl":
        return read_jsonl_inputs_ab(input_file)
    elif input_file_type == "stsb":
        return read_sts_inputs(input_file)
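# A minimal usage sketch (illustrative, not part of the original module): writing
# a temporary plain-text file and reading it back through `read_inputs` with the
# 'plain' file type. The helper name `_demo_read_inputs` is hypothetical.
def _demo_read_inputs() -> List[str]:
    import tempfile
    with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False, encoding='utf8') as tmp:
        tmp.write('first input\nsecond input\n')
        tmp_path = tmp.name
    return read_inputs(tmp_path, 'plain')  # -> ['first input', 'second input']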
def read_plaintext_inputs(path: str) -> List[str]:
    """Read input texts from a plain text file where each line corresponds to one input"""
    with open(path, 'r', encoding='utf8') as fh:
        inputs = fh.read().splitlines()
    print(f"Done loading {len(inputs)} inputs from file '{path}'")
    return inputs


def read_jsonl_inputs_ab(path: str) -> List[Tuple[str, Optional[str]]]:
    """Read input text pairs from a jsonl file, where each line is one json object and the texts are stored in the fields 'text_a' and 'text_b'"""
    ds_entries = DatasetEntry.read_list(path)
    print(f"Done loading {len(ds_entries)} inputs from file '{path}'")
    return [(entry.text_a, entry.text_b) for entry in ds_entries]


def read_jsonl_inputs(path: str) -> List[str]:
    """Read input texts from a jsonl file, where each line is one json object and input texts are stored in the field 'text_a'"""
    ds_entries = DatasetEntry.read_list(path)
    print(f"Done loading {len(ds_entries)} inputs from file '{path}'")
    return [entry.text_a for entry in ds_entries]


def read_sts_inputs(path: str) -> List[str]:
    """Read input texts from a tsv file, formatted like the official STS benchmark"""
    inputs = []
    with open(path, 'r', encoding='utf8') as fh:
        reader = csv.reader(fh, delimiter='\t', quoting=csv.QUOTE_NONE)
        for row in reader:
            try:
                sent_a, sent_b = row[5], row[6]
                inputs.append(sent_a)
                inputs.append(sent_b)
            except IndexError:
                print(f"Cannot parse row {row}")
    print(f"Done loading {len(inputs)} inputs from file '{path}'")
    return inputs
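# A minimal sketch (illustrative): `read_sts_inputs` expects the two sentences in
# columns 5 and 6 of each tab-separated row, as in the official STS benchmark
# files; the example row below is made up, and `_demo_read_sts_inputs` is a
# hypothetical helper, not part of the original module.
def _demo_read_sts_inputs() -> List[str]:
    import tempfile
    row = "main-news\theadlines\t2016\t0001\t4.60\tA man is playing a guitar.\tA man plays the guitar.\n"
    with tempfile.NamedTemporaryFile('w', suffix='.tsv', delete=False, encoding='utf8') as tmp:
        tmp.write(row)
        tmp_path = tmp.name
    return read_sts_inputs(tmp_path)  # -> both sentences, in input order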
def read_nli_inputs(path: str) -> List[str]:
    """Read input texts from an NLI dataset (placeholder, not implemented)"""
    pass


class DatasetEntry:
    """This class represents a dataset entry for text (pair) classification"""

    def __init__(self, text_a: str, text_b: Optional[str], label: Any):
        self.text_a = text_a
        self.text_b = text_b
        self.label = label

    def __repr__(self):
        if self.text_b is not None:
            return f'DatasetEntry(text_a="{self.text_a}", text_b="{self.text_b}", label={self.label})'
        else:
            return f'DatasetEntry(text_a="{self.text_a}", label={self.label})'

    def __key(self):
        return self.text_a, self.text_b, self.label

    def __hash__(self):
        return hash(self.__key())

    def __eq__(self, other):
        if isinstance(other, DatasetEntry):
            return self.__key() == other.__key()
        return False

    @staticmethod
    def save_list(entries: List['DatasetEntry'], path: str):
        with open(path, 'w', encoding='utf8') as fh:
            for entry in entries:
                fh.write(f'{json.dumps(entry.__dict__)}\n')

    @staticmethod
    def read_list(path: str) -> List['DatasetEntry']:
        pairs = []
        with open(path, 'r', encoding='utf8') as fh:
            for line in fh:
                pairs.append(DatasetEntry(**json.loads(line)))
        return pairs
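# A minimal sketch (illustrative): `save_list` writes one JSON object per line and
# `read_list` restores entries that compare equal, so a round trip through a
# temporary file preserves the data. `_demo_dataset_entry_round_trip` is a
# hypothetical helper, not part of the original module.
def _demo_dataset_entry_round_trip() -> None:
    import os
    import tempfile
    entries = [
        DatasetEntry(text_a="a premise", text_b="a hypothesis", label=1),
        DatasetEntry(text_a="a single text", text_b=None, label=0),
    ]
    path = os.path.join(tempfile.gettempdir(), "dataset_entries_demo.jsonl")
    DatasetEntry.save_list(entries, path)
    assert DatasetEntry.read_list(path) == entries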
class DatasetEntryWithExp:
    """This class represents a dataset entry for text (pair) classification with an additional explanation"""

    def __init__(self, text_a: str, text_b: Optional[str], label: Any, explanation: str):
        self.text_a = text_a
        self.text_b = text_b
        self.label = label
        self.explanation = explanation

    def __repr__(self):
        return f'DatasetEntryWithExp(text_a="{self.text_a}", text_b="{self.text_b}", label={self.label}, explanation={self.explanation})'

    def __key(self):
        return self.text_a, self.text_b, self.label, self.explanation

    def __hash__(self):
        return hash(self.__key())

    def __eq__(self, other):
        if isinstance(other, DatasetEntryWithExp):
            return self.__key() == other.__key()
        return False

    @staticmethod
    def save_list(entries: List['DatasetEntryWithExp'], path: str):
        with open(path, 'w', encoding='utf8') as fh:
            for entry in entries:
                fh.write(f'{json.dumps(entry.__dict__)}\n')

    @staticmethod
    def read_list(path: str) -> List['DatasetEntryWithExp']:
        pairs = []
        with open(path, 'r', encoding='utf8') as fh:
            for line in fh:
                pairs.append(DatasetEntryWithExp(**json.loads(line)))
        return pairs
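# A hypothetical entry point (illustrative, not part of the original module) that
# runs the small demo helpers defined above when the file is executed directly
# rather than imported.
if __name__ == "__main__":
    _demo_set_seed()
    print(_demo_read_inputs())
    print(_demo_read_sts_inputs())
    _demo_dataset_entry_round_trip()
    print("All demos finished")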