-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathaugment.py
54 lines (43 loc) · 1.85 KB
/
augment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# Easy data augmentation techniques for text classification
# Jason Wei and Kai Zou
from eda import *
# Arguments to be parsed from the command line
import argparse

ap = argparse.ArgumentParser()
ap.add_argument("--input", required=True, type=str, help="input file of unaugmented data")
ap.add_argument("--output", required=False, type=str, help="output file of augmented data")
ap.add_argument("--num_aug", required=False, type=int, help="number of augmented sentences per original sentence")
ap.add_argument("--alpha", required=False, type=float, help="percent of words in each sentence to be changed")
args = ap.parse_args()

# The output file: when --output is not given (or is empty), write next to
# the input file, using the input's base name prefixed with 'eda_'.
if args.output:
    output = args.output
else:
    from os.path import dirname, basename, join
    output = join(dirname(args.input), 'eda_' + basename(args.input))

# Number of augmented sentences to generate per original sentence.
# Compare against None rather than relying on truthiness, so that an
# explicitly supplied 0 is passed through instead of being replaced by
# the default.
num_aug = 9  # default
if args.num_aug is not None:
    num_aug = args.num_aug

# How much to change each sentence. Same None check as above, so that
# `--alpha 0` is honoured rather than silently becoming 0.1.
alpha = 0.1  # default
if args.alpha is not None:
    alpha = args.alpha
# Generate more data with standard augmentation
def gen_eda(train_orig, output_file, alpha, num_aug=9):
    """Augment every labelled sentence in ``train_orig`` and write the result.

    Each non-blank input line is read as ``label,sentence``. The line is
    split at the first comma only, so commas inside the sentence are kept.
    For every line, ``eda`` is called on the sentence and each sentence it
    returns is written to ``output_file`` as ``label,augmented_sentence``.

    Args:
        train_orig: Path of the input file, one ``label,sentence`` per line.
        output_file: Path of the file to write; it is overwritten.
        alpha: Value passed to ``eda`` as ``alpha_sr``, ``alpha_ri``,
            ``alpha_rs`` and ``p_rd``.
        num_aug: Value passed to ``eda`` as ``num_aug``.

    Raises:
        IndexError: If a non-blank line contains no comma.
    """
    # Read the whole input before opening the output, so a missing or
    # unreadable input file does not leave behind a truncated output file.
    with open(train_orig, 'r') as reader:
        lines = reader.readlines()

    with open(output_file, 'w') as writer:
        for line in lines:
            # Strip only the line terminator; slicing off the last character
            # would eat real text when the final line has no trailing newline.
            line = line.rstrip('\n')
            if not line.strip():
                # Blank lines (e.g. at the end of the file) carry no data.
                continue
            parts = line.split(',', 1)
            label = parts[0]
            sentence = parts[1]
            aug_sentences = eda(sentence, alpha_sr=alpha, alpha_ri=alpha, alpha_rs=alpha, p_rd=alpha, num_aug=num_aug)
            writer.writelines(label + "," + aug_sentence + '\n' for aug_sentence in aug_sentences)

    print("generated augmented sentences with eda for " + train_orig + " to " + output_file + " with num_aug=" + str(num_aug))
# Script entry point
if __name__ == "__main__":
    # Augment the input file and write the result to the resolved output
    # path, using the alpha and num_aug values settled on above.
    gen_eda(
        train_orig=args.input,
        output_file=output,
        alpha=alpha,
        num_aug=num_aug,
    )