-
Notifications
You must be signed in to change notification settings - Fork 128
/
Copy pathmain.py
55 lines (49 loc) · 2.18 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# -*- coding: utf-8 -*-
from __future__ import print_function
import argparse
import ast
import logging
from keras_wrapper.extra.read_write import pkl2dict
from config import load_parameters
from utils.utils import update_parameters
from nmt_keras import check_params
from nmt_keras.training import train_model
# Configure the root logger once for the whole script: INFO and above,
# each record prefixed with a day/month/year timestamp.
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s] %(message)s',
    datefmt='%d/%m/%Y %H:%M:%S',
)

# Module-level logger used by the entry point below.
logger = logging.getLogger(__name__)
def parse_args(argv=None):
    """Parse the command-line arguments of the launcher.

    :param argv: Optional list of argument strings to parse. When ``None``
                 (the default) argparse reads ``sys.argv[1:]``, which is what
                 the script entry point relies on.
    :return: ``argparse.Namespace`` with the attributes ``config`` (str or
             None), ``dataset`` (str or None) and ``changes`` (list of the
             positional ``Key=Value`` strings; empty list when none are given).
    """
    # The text must be passed as `description`: the first positional parameter
    # of ArgumentParser is `prog`, which would put it in the usage line as the
    # program name.
    parser = argparse.ArgumentParser(description="Train or sample NMT models")
    parser.add_argument("-c", "--config", required=False,
                        help="Config pkl for loading the model configuration. "
                             "If not specified, hyperparameters "
                             "are read from config.py")
    parser.add_argument("-ds", "--dataset", required=False,
                        help="Dataset instance with data")
    # Default to an empty list so that `changes` always has the same type,
    # whether or not overrides were supplied.
    parser.add_argument("changes", nargs="*", default=[],
                        help="Changes to config. "
                             "Following the syntax Key=Value")
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: build the parameter dictionary, apply the
    # command-line overrides, validate it and dispatch on parameters['MODE'].
    #
    # Exit codes: 1 for a malformed override, 2 when the deprecated
    # 'sampling' mode is requested.
    args = parse_args()

    # Start from the defaults returned by load_parameters(); when a pickled
    # configuration is given, merge it in through update_parameters().
    parameters = load_parameters()
    if args.config is not None:
        parameters = update_parameters(parameters, pkl2dict(args.config))

    # Apply the positional "Key=Value" overrides last.
    for arg in args.changes:
        # Split on the first '=' only, so that values may contain '='.
        key, separator, raw_value = arg.partition('=')
        if not separator or not key:
            print('Overwritten arguments must have the form key=Value. \n Currently are: %s' % str(args.changes))
            raise SystemExit(1)
        try:
            # Interpret the value as a Python literal (number, bool, list...).
            parameters[key] = ast.literal_eval(raw_value)
        except (ValueError, SyntaxError):
            # Not a literal (bare word, file path, empty string, ...):
            # literal_eval reports these as either ValueError or SyntaxError.
            # Keep the raw string in both cases.
            parameters[key] = raw_value

    parameters = check_params(parameters)

    if parameters['MODE'] == 'training':
        logger.info('Running training.')
        train_model(parameters, args.dataset)
    elif parameters['MODE'] == 'sampling':
        logger.error('Depecrated function. For sampling from a trained model, please run sample_ensemble.py.')
        raise SystemExit(2)
    logger.info('Done!')