diff --git a/train.py b/train.py index 41db4b40..66ebf38f 100755 --- a/train.py +++ b/train.py @@ -73,7 +73,7 @@ def _map_fn_train(img): def train(): G = get_G((batch_size, 96, 96, 3)) D = get_D((batch_size, 384, 384, 3)) - VGG = tl.models.vgg19(pretrained=False, end_with='pool4', mode='static') + VGG = tl.models.vgg19(pretrained=True, end_with='pool4', mode='static') lr_v = tf.Variable(lr_init) g_optimizer_init = tf.optimizers.Adam(lr_v, beta_1=beta1)