"""
Provides :class:`EngineBase`.
"""

from __future__ import print_function

import os
import sys
import typing

from Util import BackendEngine, model_epoch_from_filename, get_model_filename_postfix
from Log import log
from Pretrain import Pretrain


class EngineBase(object):
"""
Base class for a backend engine, such as :class:`TFEngine.Engine`.
"""
_epoch_model = None # type: typing.Optional[typing.Tuple[typing.Optional[int],typing.Optional[str]]] # get_epoch_model() # nopep8
def __init__(self):
self.epoch = 0
self.pretrain = None # type: typing.Optional[Pretrain]
self.model_filename = None # type: typing.Optional[str]

  @classmethod
  def config_get_final_epoch(cls, config):
    """
    :param Config.Config config:
    :rtype: int
    """
    num_epochs = config.int('num_epochs', 5)
    if config.has("load_epoch"):
      num_epochs = max(num_epochs, config.int("load_epoch", 0))
    return num_epochs
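
  # For illustration (a hedged sketch, not executed here; assumes a loaded
  # Config instance): with num_epochs = 100 and load_epoch = 120 in the config,
  # config_get_final_epoch(config) returns 120, i.e. the final epoch is
  # extended so that the epoch we load from is still covered.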

  @classmethod
  def get_existing_models(cls, config):
    """
    :param Config.Config config:
    :return: existing models, epoch -> model filename
    :rtype: dict[int,str]
    """
    model_filename = config.value('model', '')
    if not model_filename:
      return {}  # empty dict, consistent with the documented return type
    # Automatically search the filesystem for existing models.
    file_list = {}
    for epoch in range(1, cls.config_get_final_epoch(config) + 1):
      for is_pretrain in [False, True]:
        fn = cls.epoch_model_filename(model_filename, epoch, is_pretrain)
        if os.path.exists(fn):
          file_list[epoch] = fn
          break
        if BackendEngine.is_tensorflow_selected():
          # A TF checkpoint consists of multiple files; check for the ".index" file.
          if os.path.exists(fn + ".index"):
            file_list[epoch] = fn
            break
    return file_list
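
  # For illustration (hypothetical paths): with model = "net-model/network" and
  # checkpoint files "net-model/network.001" .. "net-model/network.003" on disk,
  # get_existing_models(config) would return
  # {1: "net-model/network.001", 2: "net-model/network.002", 3: "net-model/network.003"}.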

  @classmethod
  def get_epoch_model(cls, config):
    """
    :type config: Config.Config
    :returns: (epoch, model_filename)
    :rtype: (int|None, str|None)
    """
    # XXX: We cache the result, although this is wrong if the config has changed since.
    if cls._epoch_model:
      return cls._epoch_model

    start_epoch_mode = config.value('start_epoch', 'auto')
    if start_epoch_mode == 'auto':
      start_epoch = None
    else:
      start_epoch = int(start_epoch_mode)
      assert start_epoch >= 1

    load_model_epoch_filename = config.value('load', '')
    if load_model_epoch_filename:
      assert os.path.exists(load_model_epoch_filename + get_model_filename_postfix())

    import_model_train_epoch1 = config.value('import_model_train_epoch1', '')
    if import_model_train_epoch1:
      assert os.path.exists(import_model_train_epoch1 + get_model_filename_postfix())

    existing_models = cls.get_existing_models(config)
    if not load_model_epoch_filename:
      load_epoch = config.int("load_epoch", -1)
      if load_epoch > 0:  # ignore if load_epoch == 0
        assert load_epoch in existing_models
        load_model_epoch_filename = existing_models[load_epoch]
        assert model_epoch_from_filename(load_model_epoch_filename) == load_epoch

    # Use the 'load' model only when we don't train, or when an explicit start_epoch is given.
    # For training in auto start-epoch mode, we first consider existing models
    # before we take 'load' into account.
    if load_model_epoch_filename and (config.value('task', 'train') != 'train' or start_epoch is not None):
      epoch = model_epoch_from_filename(load_model_epoch_filename)
      if config.value('task', 'train') == 'train' and start_epoch is not None:
        # Ignore the epoch, to keep it consistent with the cases below.
        epoch = None
      epoch_model = (epoch, load_model_epoch_filename)

    # In case of training, always consider existing models first.
    # When we rerun RETURNN training, we usually don't want to train from scratch
    # but resume where we stopped last time.
    elif existing_models:
      epoch_model = sorted(existing_models.items())[-1]
      if load_model_epoch_filename:
        print("note: there is a 'load' which we ignore because of an existing model", file=log.v4)

    elif config.value('task', 'train') == 'train' and import_model_train_epoch1 and start_epoch in [None, 1]:
      epoch_model = (0, import_model_train_epoch1)

    # Also consider 'load' when we train, as an initial model import.
    elif load_model_epoch_filename:
      # Don't use the model epoch as the start epoch in training.
      # We use this model only as an import for training.
      epoch_model = (model_epoch_from_filename(load_model_epoch_filename), load_model_epoch_filename)

    else:
      epoch_model = (None, None)

    if start_epoch == 1:
      if epoch_model[0]:  # existing model
        print("warning: there is an existing model: %s" % (epoch_model,), file=log.v4)
        epoch_model = (None, None)
    elif (start_epoch or 0) > 1:
      if epoch_model[0]:
        if epoch_model[0] != start_epoch - 1:
          print("warning: start_epoch %i but there is %s" % (start_epoch, epoch_model), file=log.v4)
        epoch_model = start_epoch - 1, existing_models[start_epoch - 1]

    cls._epoch_model = epoch_model
    return epoch_model
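
  # Example resolution (a sketch under assumed config values): with task="train",
  # start_epoch="auto", no 'load', and existing models up to epoch 9,
  # get_epoch_model(config) returns (9, "<model>.009"),
  # i.e. training resumes from the latest checkpoint rather than from scratch.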

  @classmethod
  def get_train_start_epoch_batch(cls, config):
    """
    We will always automatically determine the best start (epoch, batch) tuple
    based on existing model files.
    This ensures that the files are present and enforces that there are
    no old outdated files which should be ignored.
    Note that epochs start at index 1 and batches at index 0.

    :type config: Config.Config
    :returns: (epoch, batch)
    :rtype: (int,int)
    """
    start_batch_mode = config.value('start_batch', 'auto')
    if start_batch_mode == 'auto':
      start_batch_config = None
    else:
      start_batch_config = int(start_batch_mode)
    last_epoch, _ = cls.get_epoch_model(config)
    if last_epoch is None:
      # There is no existing model. Start from scratch.
      start_epoch = 1
      start_batch = start_batch_config or 0
    elif start_batch_config is not None:
      # We specified a start batch. Stay in the same epoch, use that start batch.
      start_epoch = last_epoch
      start_batch = start_batch_config
    else:
      # Start with the next epoch.
      start_epoch = last_epoch + 1
      start_batch = 0
    return start_epoch, start_batch
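
  # Continuing the sketch above: with get_epoch_model(config) == (9, ...) and
  # start_batch = "auto", this returns (start_epoch, start_batch) == (10, 0),
  # i.e. we continue with the next epoch from its first batch.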

  @classmethod
  def epoch_model_filename(cls, model_filename, epoch, is_pretrain):
    """
    :type model_filename: str
    :type epoch: int
    :type is_pretrain: bool
    :rtype: str
    """
    if sys.platform == "win32" and model_filename.startswith("/tmp/"):
      # On Windows, map /tmp/ paths to the local temp directory.
      import tempfile
      model_filename = tempfile.gettempdir() + model_filename[len("/tmp"):]
    return model_filename + (".pretrain" if is_pretrain else "") + ".%03d" % epoch
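
  # For example (hypothetical filename):
  #   epoch_model_filename("net-model/network", 7, is_pretrain=False)
  #   -> "net-model/network.007"
  #   epoch_model_filename("net-model/network", 7, is_pretrain=True)
  #   -> "net-model/network.pretrain.007"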

  def get_epoch_model_filename(self, epoch=None):
    """
    :param int|None epoch: if not given, uses self.epoch
    :return: filename, excluding the TF-specific postfix
    :rtype: str
    """
    if not epoch:
      epoch = self.epoch
    return self.epoch_model_filename(self.model_filename, epoch, self.is_pretrain_epoch(epoch=epoch))
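
  # E.g. (sketch): with self.model_filename = "net-model/network", self.epoch = 3,
  # and no pretraining, this returns "net-model/network.003". For TF, the actual
  # checkpoint consists of several files with further postfixes on top of this.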

  def get_epoch_str(self):
    """
    :return: e.g. "epoch 3", or "pretrain epoch 5"
    :rtype: str
    """
    return ("pretrain " if self.is_pretrain_epoch() else "") + "epoch %s" % self.epoch

  def is_pretrain_epoch(self, epoch=None):
    """
    :param int|None epoch: if not given, uses self.epoch
    :return: whether this epoch is covered by the pretrain logic
    :rtype: bool
    """
    if not epoch:
      epoch = self.epoch
    return bool(self.pretrain) and epoch <= self.pretrain.get_train_num_epochs()

  def is_first_epoch_after_pretrain(self):
    """
    :return: whether the current epoch is the first epoch right after pretraining
    :rtype: bool
    """
    return bool(self.pretrain) and self.epoch == self.pretrain.get_train_num_epochs() + 1
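

# A minimal usage sketch (hedged; assumes a working RETURNN setup with an
# experiment config file, and that a concrete subclass such as TFEngine.Engine
# is used for actual training):
#
#   from Config import Config
#   config = Config()
#   config.load_file("returnn.config")  # hypothetical config filename
#   epoch, model_filename = EngineBase.get_epoch_model(config)
#   start_epoch, start_batch = EngineBase.get_train_start_epoch_batch(config)
#   print("would resume training at epoch %i, batch %i" % (start_epoch, start_batch))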