-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathpredict.ipy
executable file
·62 lines (51 loc) · 2.33 KB
/
predict.ipy
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import argparse
import os
import json
# Command-line options for the checkpoint prediction/evaluation sweep.
parser = argparse.ArgumentParser(description='run.ipy')
# Prediction
parser.add_argument('-start', type=int, default=5,
                    help='Epoch to start prediction')
parser.add_argument('-end', type=int, default=30,
                    help='Epoch to end prediction')
parser.add_argument('-beam', type=int, default=3,
                    help='Beam size')
# Forwarded to translate.py as -max_sent_length.
parser.add_argument('-tgt_len', type=int, default=3,
                    help='Maximum length of a predicted target sequence')
parser.add_argument('-models_dir', type=str, default='',
                    help='Models directory.')
# Loaded with json below and also passed to translate.py as -src.
parser.add_argument('-test_file', type=str, default='',
                    help='Test data file (JSON).')
# Appended verbatim to the translate.py command line.
parser.add_argument('-additional', type=str, default='',
                    help='Any additional flags.')
opt = parser.parse_args()
# Predictions for every evaluated checkpoint are written under
# <models_dir>/preds/, so make sure that directory exists.
preds_dir = opt.models_dir + '/preds/'
try:
    os.makedirs(preds_dir)
except OSError:
    # An already-existing directory (e.g. from a previous run) is fine;
    # any other failure (permissions, bad path, ...) must not be hidden.
    if not os.path.isdir(preds_dir):
        raise
# We need a text file with the reference outputs to compute BLEU,
# so extract it out of the json file.
with open(opt.test_file, 'r') as test_json:
    test_dataset = json.load(test_json)
# One example per line: the 'code' tokens joined by single spaces, with the
# 'concodeclass_' and 'concodefunc_' markers stripped out.
with open('/tmp/test.code', 'w') as test_dataset_targets:
    for example in test_dataset:
        target_line = ' '.join(example['code'])
        target_line = target_line.replace('concodeclass_', '').replace('concodefunc_', '')
        test_dataset_targets.write(target_line + '\n')
best_bleu, best_exact = (0, 0, 0), (0, 0, 0)
for i in range(opt.start, opt.end + 1):
fname = !ls {opt.models_dir}/model_acc_*e{i}.pt
f = os.path.basename(fname[0])
print(f)
!rm {opt.models_dir}/preds/{f}.nl.prediction*
# Prod is just a dummy here
!python translate.py -beam_size {opt.beam} -gpu 0 -model {fname[0]} -src {opt.test_file} -output {opt.models_dir}/preds/{f}.nl.prediction -max_sent_length {opt.tgt_len} -replace_unk -batch_size 1 -trunc 2000 {opt.additional}
bleu = !perl tools/multi-bleu.perl -lc /tmp/test.code < {opt.models_dir}/preds/{f}.nl.prediction
print(bleu)
bleu_score = float(bleu[0].split(',')[0])
exact = !python tools/exact.py /tmp/test.code < {opt.models_dir}/preds/{f}.nl.prediction
print(exact)
exact_score = float(exact[0])
if bleu_score > best_bleu[0]:
best_bleu = (bleu_score, exact_score, i)
if exact_score > best_exact[1]:
best_exact = (bleu_score, exact_score, i)
print ('Best BLEU so far is: {} - Exact is {} - on epoch {}'.format(*best_bleu))
print ('BLEU is {} - Best Exact so far: is {} - on epoch {}'.format(*best_exact))