Skip to content

Commit

Permalink
[ITL-90] Saving method extracted
Browse files Browse the repository at this point in the history
  • Loading branch information
mtyrolski committed Apr 4, 2020
1 parent 24d96dc commit 0e8baf8
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions imagetolatex/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ def fit_on_layered(self, train_sequences, test_sequences, epochs):
layer_model.train_on_batch(x=train_features, y=train_labels)
layer_model.test_on_batch(x=test_features, y=test_labels)

def _fit_on_layered(self, train_sequences, test_sequences, save=None, epochs=10, verbose=True, **kwargs):
def save(self, path):
for layer_index, layer_model in enumerate(self._layer_models):
layer_model.save(path.format(layer_index))

def _fit_on_layered(self, train_sequences, test_sequences, epochs, verbose, **kwargs):
train_sequences = [_Flatten(train_sequences, i) for i in range(train_sequences.layer_count)]
test_sequences = [_Flatten(test_sequences, i) for i in range(test_sequences.layer_count)]

Expand All @@ -38,12 +42,6 @@ def _fit_on_layered(self, train_sequences, test_sequences, save=None, epochs=10,
**kwargs
)

if not save:
pass

for layer_index, layer_model in enumerate(self._layer_models):
layer_model.save(save.format(layer_index))


from keras.utils import Sequence

Expand Down Expand Up @@ -117,7 +115,8 @@ def complex_equation_layer(input_shape, num_classes, verbose=False):
model._fit_on_layered(
train_sequence,
test_sequence,
save='pretrained/itl{0}.h5',
epochs=5,
epochs=10,
verbose=True
)

model.save('pretrained/itl{0}.h5')

0 comments on commit 0e8baf8

Please sign in to comment.