import os
import json
import string
from distutils.version import LooseVersion
from typing import List

from sacred import Experiment
from tqdm import trange
import tensorflow as tf

try:
    import better_exceptions  # optional: nicer tracebacks if installed
except ImportError:
    pass

from tf_attention.model import crnn_attention_fn
from tf_attention.data_loader import data_loader, serving_single_input
from tf_attention.config import Params, TrainingParams

ex = Experiment('CRNN_attention_experiment')
def distribution_gpus(num_gpus):
    """Pick a tf.distribute strategy for the given GPU count (None = default placement)."""
    if num_gpus == 1:
        return tf.contrib.distribute.OneDeviceStrategy(
            device=tf.DeviceSpec(device_type="GPU", device_index=0))
    elif num_gpus > 1:
        return tf.contrib.distribute.MirroredStrategy(num_gpus=num_gpus)
    else:
        return None
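
# Illustrative mapping (assuming TF 1.x with tf.contrib.distribute available):
#   distribution_gpus(0) -> None (default device placement)
#   distribution_gpus(1) -> OneDeviceStrategy pinned to /device:GPU:0
#   distribution_gpus(4) -> MirroredStrategy replicating across four GPUs
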
@ex.config
def default_config():
    csv_files_train = None
    csv_files_eval = None
    output_model_dir = None
    num_gpus = 1
    lookup_hangul_file = ''
    input_shape = (32, 100)
    training_params = TrainingParams().to_dict()
    restore_model = False
    csv_delimiter = ','
    string_split_delimiter = '|'
    data_augmentation = True
    data_augmentation_max_rotation = 0.05
    input_data_n_parallel_calls = 4
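
# Any of the config values above can be overridden via Sacred's CLI
# (the paths below are hypothetical placeholders):
#   python train.py with csv_files_train='["data/train.csv"]' \
#       csv_files_eval='["data/eval.csv"]' output_model_dir=models/run1 num_gpus=1
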
@ex.automain
def run(csv_files_train: List[str], csv_files_eval: List[str], output_model_dir: str,
        training_params: dict, _config):
    # Save the full experiment config next to the model
    if not os.path.isdir(output_model_dir):
        os.makedirs(output_model_dir)
    else:
        assert _config.get('restore_model'), \
            '{0} already exists, you cannot use it as output directory. ' \
            'Set "restore_model=True" to continue training, or delete the directory ' \
            '("rm -r {0}").'.format(output_model_dir)

    with open(os.path.join(output_model_dir, 'config.json'), 'w') as f:
        json.dump(_config, f, indent=4, sort_keys=True)

    parameters = Params(**_config)
    training_params = TrainingParams(**training_params)
    model_params = {
        'Params': parameters,
        'TrainingParams': training_params
    }

    # Create export directory
    export_dir = os.path.join(output_model_dir, 'export')
    if not os.path.isdir(export_dir):
        os.makedirs(export_dir)

    # Check that the hangul lookup covers all characters in the input CSV files
    discarded_chars = parameters.string_split_delimiter + parameters.csv_delimiter \
        + string.whitespace[1:]
    parameters.hangul.check_input_file_hangul(
        parameters.csv_files_train + parameters.csv_files_eval,
        discarded_chars=discarded_chars,
        csv_delimiter=parameters.csv_delimiter)
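
    # With the defaults above, discarded_chars is '|' + ',' + '\t\n\r\x0b\x0c':
    # string.whitespace[1:] drops the leading space, so the space character
    # remains a valid label symbol.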
    config_sess = tf.ConfigProto()
    # config_sess.gpu_options.per_process_gpu_memory_fraction = 0.9
    config_sess.gpu_options.allow_growth = True

    # Configure the estimator. RunConfig.replace() returns a new RunConfig
    # rather than mutating in place, so its result must be assigned.
    # train_distribute is not available before TF 1.8, hence the version check.
    est_config = tf.estimator.RunConfig()
    if LooseVersion(tf.__version__) < LooseVersion('1.8'):
        est_config = est_config.replace(keep_checkpoint_max=10,
                                        save_checkpoints_steps=training_params.save_interval,
                                        session_config=config_sess,
                                        save_checkpoints_secs=None,
                                        save_summary_steps=1000,
                                        model_dir=output_model_dir)
    else:
        est_config = est_config.replace(keep_checkpoint_max=10,
                                        save_checkpoints_steps=training_params.save_interval,
                                        session_config=config_sess,
                                        save_checkpoints_secs=None,
                                        save_summary_steps=1000,
                                        model_dir=output_model_dir,
                                        train_distribute=distribution_gpus(parameters.num_gpus))
    estimator = tf.estimator.Estimator(model_fn=crnn_attention_fn,
                                       params=model_params,
                                       model_dir=output_model_dir,
                                       config=est_config)

    # Alternate between training for `evaluate_every_epoch` epochs, exporting a
    # SavedModel, and running evaluation, until `n_epochs` is reached.
    for e in trange(0, training_params.n_epochs, training_params.evaluate_every_epoch):
        estimator.train(input_fn=data_loader(csv_filename=csv_files_train,
                                             params=parameters,
                                             batch_size=training_params.train_batch_size,
                                             num_epochs=training_params.evaluate_every_epoch,
                                             data_augmentation=parameters.data_augmentation,
                                             image_summaries=True))

        estimator.export_savedmodel(export_dir,
                                    serving_input_receiver_fn=serving_single_input(
                                        parameters,
                                        fixed_height=parameters.input_shape[0],
                                        min_width=10))

        estimator.evaluate(input_fn=data_loader(csv_filename=csv_files_eval,
                                                params=parameters,
                                                batch_size=training_params.eval_batch_size,
                                                num_epochs=1))
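
# A minimal sketch of loading one of the exported SavedModels for inference,
# assuming TF 1.x; the timestamped subdirectory and the 'images' feed key are
# hypothetical and depend on what serving_single_input defines:
#   predict_fn = tf.contrib.predictor.from_saved_model('models/run1/export/1550000000')
#   predictions = predict_fn({'images': image_batch})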