-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
51 lines (40 loc) · 1.68 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import argparse
import nltk
from utils import measure
from config import cfg
from trainer import RWTrainer
def get_parser(argv=None):
    """Define the command-line options and return the parsed arguments.

    Note that, despite its name, this function returns the parsed
    ``argparse.Namespace`` and not the ``ArgumentParser`` itself.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so a
            plain ``get_parser()`` call behaves exactly as before. Passing
            an explicit list makes the function usable from other code
            and from tests.

    Returns:
        The ``argparse.Namespace`` holding the options ``z``, ``k``, ``e``,
        ``t``, ``m``, ``n``, ``r``, ``j`` and ``f``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-z', type=str, default='train', choices=['train', 'test'],
                        help='action, train or test. The default is train.')
    parser.add_argument('-k', type=int, default=2, help='Maximum length of word spans')
    parser.add_argument('-e', type=int, default=5, help='# of emission epochs')
    parser.add_argument('-t', type=int, default=5, help='# of transition epochs')
    parser.add_argument('-m', type=str, default='skip', choices=['without', 'with', 'skip'],
                        help='Null mode. See HMM.py for details.')
    # Integer flag: the caller treats any value greater than zero as "on".
    parser.add_argument('-n', type=int, default=0, help='Prevent numerics-to-string alignment')
    parser.add_argument('-r', type=float, default=0.4, help='Null ratio for posterior regularization')
    parser.add_argument('-j', type=int, default=4, help='Number of workers (multiprocess)')
    parser.add_argument('-f', type=str, default='LL',
                        help='Filter sentences. "L" for line-scores and "B" for box-scores')
    # argv=None makes argparse read sys.argv[1:], preserving the old behaviour.
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Script entry point: parse the CLI, mirror the options onto the shared
    # ``cfg`` object, fetch the NLTK resources, then run the chosen action.
    args = get_parser()

    # (cfg attribute, value) pairs, applied in the order listed.
    overrides = (
        ('k', args.k),
        ('emit_epoch', args.e),
        ('trans_epoch', args.t),
        ('null_mode', args.m),
        ('no_num', args.n > 0),  # integer CLI flag turned into a boolean
        ('null_ratio', args.r),
        ('jobs', args.j),
        ('filter', args.f),
    )
    for option, value in overrides:
        setattr(cfg, option, value)

    # NLTK resources are requested on every run, before the trainer is built.
    for resource in ('punkt', 'wordnet'):
        nltk.download(resource)

    # cfg.print_info()
    trainer = RWTrainer(cfg)

    # Map the CLI action onto the trainer method that implements it; an
    # action with no entry does nothing, like the if/elif chain without else.
    method_for_action = {'train': 'train', 'test': 'random_test'}
    method_name = method_for_action.get(args.z)
    if method_name is not None:
        getattr(trainer, method_name)()

    # NOTE(review): the flattened source does not show whether this call was
    # inside the __main__ guard or at module level; kept inside -- confirm.
    measure('END')