Skip to content

Commit

Permalink
fix keras_vggface comaptibility with keras 2.2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
shaoanlu committed Nov 8, 2018
1 parent aa4fb3f commit 1f2df8f
Show file tree
Hide file tree
Showing 2 changed files with 533 additions and 3 deletions.
25 changes: 22 additions & 3 deletions colab_demo/faceswap-GAN_colab_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,16 @@
"model = FaceswapGANModel(**arch_config)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"!wget https://github.com/rcmalli/keras-vggface/releases/download/v2.0/rcmalli_vggface_tf_notop_resnet50.h5"
]
},
{
"cell_type": "code",
"execution_count": 57,
Expand All @@ -574,10 +584,17 @@
}
],
"source": [
"from keras_vggface.vggface import VGGFace\n",
"#from keras_vggface.vggface import VGGFace\n",
"\n",
"# VGGFace ResNet50\n",
"vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))\n",
"#vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))'\n",
"\n",
"from .vggface_models import RESNET50\n",
"vggface = RESNET50(include_top=False, weights=None, input_shape=(224, 224, 3))\n",
"vggface.load_weights(\"rcmalli_vggface_tf_notop_resnet50.h5\")\n",
"\n",
"#from keras.applications.resnet50 import ResNet50\n",
"#vggface = ResNet50(include_top=False, input_shape=(224, 224, 3))\n",
"\n",
"#vggface.summary()\n",
"\n",
Expand Down Expand Up @@ -667,7 +684,9 @@
" K.clear_session()\n",
" model = FaceswapGANModel(**arch_config)\n",
" model.load_weights(path=save_path)\n",
" vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))\n",
" #vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))\n",
" vggface = RESNET50(include_top=False, weights=None, input_shape=(224, 224, 3))\n",
" vggface.load_weights(\"rcmalli_vggface_tf_notop_resnet50.h5\")\n",
" model.build_pl_model(vggface_model=vggface, before_activ=loss_config[\"PL_before_activ\"])\n",
" train_batchA = DataLoader(train_A, train_AnB, batchSize, img_dirA_bm_eyes,\n",
" RESOLUTION, num_cpus, K.get_session(), **da_config)\n",
Expand Down
Loading

0 comments on commit 1f2df8f

Please sign in to comment.