Skip to content

Commit

Permalink
cleanup and added load_weights and save_weights to gan itself
Browse files Browse the repository at this point in the history
  • Loading branch information
Akatuoro committed Oct 16, 2018
1 parent ab0ba88 commit 5f75069
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 39 deletions.
88 changes: 49 additions & 39 deletions IconGenerator.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,30 @@
"#gan = WGANGP(desired_shape, architecture='resnet')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: Loading weights\n",
"# For this to work, the architecture of the GAN and weights have to match.\n",
"# The image shape ('desired_shape') also has to match.\n",
"\n",
"#gan.load_weights('weights_g.h5', 'weights_d.h5')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training\n",
"\n",
"During training, samples created by the current generator get saved in the given interval. The default path is `images/`.\n",
"\n",
"(If weights are loaded, this is optional)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -131,6 +155,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Plot losses\n",
"fig, axs = plt.subplots(3,1)\n",
"axs[0].plot(d_loss)\n",
"axs[1].plot(d_acc)\n",
Expand All @@ -143,31 +168,15 @@
"metadata": {},
"outputs": [],
"source": [
"r, c = 4, 6\n",
"noise = np.random.normal(0, 1, (r * c+1, 100))\n",
"\n",
"gen_imgs = gan.generator.predict(noise)\n",
"gen_imgs = 0.5 * gen_imgs + 0.5\n",
"\n",
"plot(gen_imgs, mode, r, c)"
"# Save the weights (optional if no good results ;-) )\n",
"gan.save_weights('weights_g.h5', 'weights_d.h5')"
]
},
{
"cell_type": "code",
"execution_count": null,
"cell_type": "markdown",
"metadata": {},
"outputs": [],
"source": [
"r,c,steps = 4,6,32\n",
"interpol = interpolation(gan,r*c,steps)\n",
"anim = animation(interpol,mode,r,c,steps)\n",
"HTML(anim.to_jshtml())\n",
"\n",
"## For saving the animation:\n",
"#anim.save('line.gif', dpi=80, writer='imagemagick')\n",
"# with reflection:\n",
"#animation(interpol + interpol[::-1], mode,r,c,steps*2).save('line.gif', dpi=80, writer='imagemagick')\n",
"# if it's not looping, use 'convert line.gif -loop 0 anim.gif' (using imagemagick)"
"## Generate Icons"
]
},
{
Expand All @@ -176,38 +185,39 @@
"metadata": {},
"outputs": [],
"source": [
"animation(interpol + interpol[::-1], mode,r,c,steps*2).save('anim.gif', dpi=80, writer='pillow')\n",
"gan.generator_model.save_weights('weights_g.h5')\n",
"gan.critic_model.save_weights('weights_d.h5')"
"# Generate some random samples\n",
"r, c = 4, 6 # number of rows and columns\n",
"noise = np.random.normal(0, 1, (r * c+1, 100))\n",
"\n",
"gen_imgs = gan.generator.predict(noise)\n",
"gen_imgs = 0.5 * gen_imgs + 0.5 # generator output is between [-1,1], but it needs to be between [0,1]\n",
"\n",
"plot(gen_imgs, mode, r, c)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"# Walk from a random point in latent space in a random direction\n",
"r,c,steps = 4,6,32 # number of rows, columns and steps to walk in latent space\n",
"interpol = interpolation(gan,r*c,steps)\n",
"anim = animation(interpol,mode,r,c,steps)\n",
"HTML(anim.to_jshtml())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"# The animation is saved forwards and backwards.\n",
"# For looping, use 'convert anim.gif -loop 0 anim_looping.gif' (using imagemagick)\n",
"animation(interpol + interpol[::-1], mode,r,c,steps*2).save('anim.gif', dpi=80, writer='pillow')"
]
},
{
"cell_type": "code",
Expand Down
8 changes: 8 additions & 0 deletions gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,14 @@ def sample_images(self, epoch):
fig.savefig(self.save_path + "%d.png" % epoch)
plt.close()

def load_weights(self, g_weights, d_weights):
self.generator.load_weights(g_weights)
self.discriminator.load_weights(d_weights)

def save_weights(self, g_weights, d_weights):
self.generator.save_weights(g_weights)
self.discriminator.save_weights(d_weights)


if __name__ == '__main__':
(X_train,_), (_,_) = mnist.load_data()
Expand Down
8 changes: 8 additions & 0 deletions wgangp.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,14 @@ def train(self, X_train, epochs, batch_size, sample_interval=50):

return d_losses, d_acc, g_losses

def load_weights(self, g_weights, d_weights):
self.generator_model.load_weights(g_weights)
self.critic_model.load_weights(d_weights)

def save_weights(self, g_weights, d_weights):
self.generator_model.save_weights(g_weights)
self.critic_model.save_weights(d_weights)


if __name__ == '__main__':
wgan = WGANGP()
Expand Down

0 comments on commit 5f75069

Please sign in to comment.