-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathtrain.py
49 lines (39 loc) · 1.45 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
from data import load_data, tf_dataset
from model import build_unet
if __name__ == "__main__":
    # --- Seeding: make runs reproducible across numpy and TF ---
    np.random.seed(42)
    tf.random.set_seed(42)

    # --- Dataset: load (image, mask) path splits from the pet dataset ---
    path = "oxford-iiit-pet/"
    (train_x, train_y), (valid_x, valid_y), (test_x, test_y) = load_data(path)
    print(f"Dataset: Train: {len(train_x)} - Valid: {len(valid_x)} - Test: {len(test_x)}")

    # --- Hyperparameters ---
    shape = (256, 256, 3)   # input image size expected by the U-Net
    num_classes = 3         # pet / background / boundary trimap classes
    lr = 1e-4
    batch_size = 8
    epochs = 10

    # --- Model: U-Net with softmax head, trained on one-hot masks ---
    model = build_unet(shape, num_classes)
    model.compile(
        loss="categorical_crossentropy",
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
    )

    train_dataset = tf_dataset(train_x, train_y, batch=batch_size)
    valid_dataset = tf_dataset(valid_x, valid_y, batch=batch_size)

    # Steps per epoch: floor-division drops the final partial batch,
    # matching the repeated/batched tf.data pipelines above.
    train_steps = len(train_x) // batch_size
    valid_steps = len(valid_x) // batch_size

    callbacks = [
        # BUG FIX: the keyword is `save_best_only`, not `save_best_model`;
        # the original either raised TypeError or saved every epoch.
        ModelCheckpoint("model.h5", verbose=1, save_best_only=True),
        ReduceLROnPlateau(monitor="val_loss", patience=3, factor=0.1, verbose=1, min_lr=1e-6),
        EarlyStopping(monitor="val_loss", patience=5, verbose=1)
    ]

    model.fit(train_dataset,
        steps_per_epoch=train_steps,
        validation_data=valid_dataset,
        validation_steps=valid_steps,
        epochs=epochs,
        callbacks=callbacks
    )