-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathass_train.py
152 lines (136 loc) · 6.07 KB
/
ass_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from __future__ import (division, absolute_import,
print_function, unicode_literals)
import os
import sys
import argparse
import logging
from ass_theano import train
def main():
    """Command-line entry point for the abstractive sentence summariser.

    Parses and validates the CLI arguments, then hands every option to
    ``ass_theano.train`` as keyword arguments (argument names map to the
    ``train`` parameter names via ``vars(args)``).

    Invalid argument values abort through ``parser.error`` (usage message +
    exit status 2).  The previous ``assert``-based checks were removed
    because asserts are stripped under ``python -O``, which would have let
    bad values through silently.
    """
    parser = argparse.ArgumentParser(
        description='Abstractive sentence summariser'
    )
    parser.add_argument('--context-encoder',
                        default='baseline',
                        choices=['baseline', 'attention'],
                        help='context encoder name')
    parser.add_argument('--corpus', required=True,
                        help='directory of the corpus e.g data/wikipedia/')
    # optimiser
    parser.add_argument('--optimizer',
                        default='adam',
                        choices=['adam', 'adadelta', 'rmsprop', 'sgd'],
                        help='optimizing algorithm')
    parser.add_argument('--learning-rate', type=float,
                        default=0.001,
                        help='learning rate for the optimizer')
    # model params
    parser.add_argument('--embed-full-text-by',
                        choices=['word', 'sentence'],
                        default='word',
                        help='embed full text by word or sentence')
    parser.add_argument('--seq-maxlen', type=int,
                        default=500,
                        help='max length of input full text')
    parser.add_argument('--summary-maxlen', type=int,
                        default=200,
                        help='max length of each summary')
    parser.add_argument('--summary-context-length', type=int,
                        default=5,
                        help='summary context length used for training')
    parser.add_argument('--internal-representation-dim', type=int,
                        default=2000,
                        help='internal representation dimension')
    parser.add_argument('--attention-weight-max-roll', type=int,
                        default=1,
                        help='max roll for the attention weight vector in attention encoder')
    # training params
    parser.add_argument('--l2-penalty-coeff', type=float,
                        default=0.00,
                        help='penalty coefficient to the L2-norms of the model params')
    parser.add_argument('--train-split', type=float,
                        default=0.75,
                        help='weight of training corpus in the entire corpus, the rest for validation')
    parser.add_argument('--epochs', type=int,
                        default=10000,
                        help='number of epochs for training')
    parser.add_argument('--minibatch-size', type=int,
                        default=20,
                        help='mini batch size')
    parser.add_argument('--seed', type=int,
                        default=None,
                        help='seed for the random stream')
    parser.add_argument('--dropout-rate', type=float,
                        default=None,
                        help='dropout rate in (0,1)')
    # model load/save
    parser.add_argument('--save-params',
                        default='ass_params.pkl',
                        help='file for saving params')
    parser.add_argument('--save-params-every', type=int,
                        default=5,
                        help='save params every <k> epochs')
    parser.add_argument('--validate-every', type=int,
                        default=5,
                        help='validate every <k> epochs')
    parser.add_argument('--print-every', type=int,
                        default=5,
                        help='print info every <k> batches')
    # summary generation on the validation set
    parser.add_argument('--generate-summary',
                        action='store_true',
                        default=False,
                        help='whether to generate summaries when validating')
    parser.add_argument('--summary-search-beam-size', type=int,
                        default=2,
                        help='beam size for the summary search')
    args = parser.parse_args()
    _validate_args(parser, args)
    args_dict = vars(args)
    print('Args', args_dict)
    train(**args_dict)


def _validate_args(parser, args):
    """Abort via ``parser.error`` if any argument value is out of range.

    Ranges enforced (mirrors what ``train`` expects):
    strictly positive counts/rates, non-negative roll/penalty/epochs,
    ``train_split`` in (0, 1], optional ``seed`` >= 0 and optional
    ``dropout_rate`` in the open interval (0, 1).
    """
    strictly_positive = (
        'learning_rate', 'seq_maxlen', 'summary_maxlen',
        'summary_context_length', 'internal_representation_dim',
        'minibatch_size', 'save_params_every', 'validate_every',
        'print_every', 'summary_search_beam_size',
    )
    for name in strictly_positive:
        if getattr(args, name) <= 0:
            parser.error('--%s must be positive' % name.replace('_', '-'))
    non_negative = ('attention_weight_max_roll', 'l2_penalty_coeff', 'epochs')
    for name in non_negative:
        if getattr(args, name) < 0:
            parser.error('--%s must be non-negative' % name.replace('_', '-'))
    if not (0 < args.train_split <= 1):
        parser.error('--train-split must be in (0, 1]')
    if args.seed is not None and args.seed < 0:
        parser.error('--seed must be non-negative')
    if args.dropout_rate is not None and not (0 < args.dropout_rate < 1):
        parser.error('--dropout-rate must be in (0, 1)')
# Run the CLI entry point only when executed as a script, not on import.
if __name__ == '__main__':
    main()