Skip to content

Commit

Permalink
Merge pull request #13 from yuanzhaoYZ/master
Browse files Browse the repository at this point in the history
Added tensorboard support & a few other minor improvements
  • Loading branch information
github-pengge authored Dec 25, 2017
2 parents 8f90b7e + 6aac40f commit 8337fc9
Show file tree
Hide file tree
Showing 12 changed files with 1,090 additions and 866 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
*.pyc
.DS_Store
.idea/
26 changes: 21 additions & 5 deletions README.md
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
# Progressive growing of GANs
# Face Aging with Progressive growing of GANs
PyTorch implementation of [Progressive Growing of GANs for Improved Quality, Stability, and Variation](http://arxiv.org/abs/1710.10196).

## How to create CelebA-HQ dataset
I borrowed `h5tool.py` from [official code](https://github.com/tkarras/progressive_growing_of_gans). To create CelebA-HQ dataset, we have to download the original CelebA dataset, and the additional deltas files from [here](https://drive.google.com/open?id=0B4qLcYyJmiz0TXY1NG02bzZVRGs). After that, run
```
python2 h5tool.py create_celeba_hq file_name_to_save /path/to/celeba_dataset/ /path/to/celeba_hq_deltas
```

This is what I used on my laptop
```
python2 h5tool.py create_celeba_hq /Users/yuan/Downloads/CelebA-HQ /Users/yuan/Downloads/CelebA/Original\ CelebA/ /Users/yuan/Downloads/CelebA/CelebA-HQ-Deltas
```
I found that MD5 checking were always failed, so I just commented out the MD5 checking part([LN 568](https://github.com/github-pengge/PyTorch-progressive_growing_of_gans/blob/master/h5tool#L568) and [LN 589](https://github.com/github-pengge/PyTorch-progressive_growing_of_gans/blob/master/h5tool#L589))

With default setting, it took 1 day on my server. You can specific `num_threads` and `num_tasks` for accleration.
Expand All @@ -16,7 +19,14 @@ You have to create CelebA-HQ dataset first, please follow the instructions above

To obtain the similar results in `samples` directory, see `train_no_tanh.py` or `train.py` scipt for details(with default options). Both should work well. For example, you could run
```
python train.py --gpu 0 --train_kimg 600 --transition_kimg 600 --lr 1e-3 --beta1 0 --beta2 0.99 --gan lsgan --first_resol 4 --target_resol 256 --no_tanh
conda create -n pytorch_p36 python=3.6 h5py matplotlib
source activate pytorch_p36
conda install pytorch torchvision -c pytorch
conda install scipy
pip install tensorflow
#0=first gpu, 1=2nd gpu ,2=3rd gpu etc...
python train.py --gpu 0,1,2 --train_kimg 600 --transition_kimg 600 --beta1 0 --beta2 0.99 --gan lsgan --first_resol 4 --target_resol 256 --no_tanh
```

`train_kimg`(`transition_kimg`) means after seeing `train_kimg * 1000`(`transition_kimg * 1000`) real images, switching to fade in(stabilize) phase. Currently only support LSGAN and GAN with `--no_noise` option, since WGAN-GP is unavailable, `--drift` option does not affect the result. `--no_tanh` means do not use `tanh` at generator's output layer.
Expand All @@ -26,6 +36,11 @@ If you are Python 2 user, You'd better add this to the top of `train.py` since I
from __future__ import print_function
```


Tensorboard
```
tensorboard --logdir='./logs'
```
## Update history

* **Update(20171213)**: Update `data.py`, now when fading in, real images are weighted combination of current resolution images and 0.5x resolution images. This weighting trick is similar to the one used in Generator's outputs or Discriminator's inputs. This helps stabilize when fading in.
Expand Down Expand Up @@ -73,6 +88,7 @@ from __future__ import print_function

* **Update(20171111)**: It's still under implementation. I did not care design the structure, and now I had to reimplement(phase='fade in' is hard to implement under current structure). I also fixed some bugs, since reimplementation is needed, I do not plan to pull requests at this moment.

# Official implementation
Official implementation using lasagne can ben found at [tkarras/progressive_growing_of_gans](https://github.com/tkarras/progressive_growing_of_gans).
# Reference implementation
* https://github.com/github-pengge/PyTorch-progressive_growing_of_gans


Empty file modified began.py
100644 → 100755
Empty file.
4 changes: 2 additions & 2 deletions debug.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import sys
sys.path.append('./models')
sys.path.append('./utils')
from model import *
from data import CelebA
from models.model import *
from utils.data import CelebA


G = Generator(num_channels=3, resolution=1024, fmap_max=512, fmap_base=8192, latent_size=512)
Expand Down
Empty file modified h5tool.py
100644 → 100755
Empty file.
Loading

0 comments on commit 8337fc9

Please sign in to comment.