This is the training code for the Jax/Flax implementation of StyleGAN2.
#### Table of Contents
- Getting Started
- Preparing Datasets for Training
- Training
- Checkpoints
- Generate Images
- Samples
- Original Checkpoints
- References
- License
## Getting Started
You will need Python 3.7 or later.
- Clone the repository:
> git clone https://github.com/matthias-wright/flaxmodels.git
- Go into the directory:
> cd flaxmodels/training/stylegan2
- Install Jax with CUDA support (see the note after this list).
- Install requirements:
> pip install -r requirements.txt
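Which JAX build you need depends on your CUDA setup, so follow the official JAX installation instructions. As one example, a CUDA 12 install currently looks like this (the exact extra varies across JAX versions, so verify against the JAX docs):
> pip install --upgrade "jax[cuda12]"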
## Preparing Datasets for Training
Before training, the images should be stored in a TFRecord dataset. The TFRecord format stores your data as a sequence of bytes, which allows for fast data loading.
Alternatively, you can use `tfds.folder_dataset.ImageFolder` on the image directory directly, but then you will have to replace the `tf.data.TFRecordDataset` in `data_pipeline.py` with `tfds.folder_dataset.ImageFolder` (see this thread for more info).
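A minimal sketch of that swap (the directory layout and feature keys here are assumptions, not taken from `data_pipeline.py`):

```python
# Sketch only: tfds.folder_dataset.ImageFolder expects a
# root/split/label/*.jpg layout, so 'train' and the label folders
# below are assumptions about how your images are organized.
import tensorflow_datasets as tfds

builder = tfds.folder_dataset.ImageFolder('/path/to/image_dir/')
ds = builder.as_dataset(split='train', shuffle_files=True)

for example in ds.take(1):
    # Each example is a dict with 'image' and 'label' entries, similar to
    # the decoded records consumed by the training input pipeline.
    print(example['image'].shape, example['label'].numpy())
```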
### FFHQ
- Download the cropped and aligned images. Alternatively, you can also download the thumbnails at 128x128 resolution.
- Unzip and `cd` into the extracted directory.
- Move the images from the subdirectories into the main directory (because there aren't any labels):
> find . -mindepth 2 -type f -print -exec mv {} . \;
- Remove empty subdirectories:
> rm -r */
- Create TFRecord dataset:
> python dataset_utils/images_to_tfrecords.py --image_dir /path/to/image_dir/ --data_dir /path/to/tfrecord
`--image_dir` is the path to the image directory. `--data_dir` is the path where the TFRecord dataset is stored.
### Danbooru2019 Portraits
- Download the images from here:
> rsync --verbose --recursive rsync://78.46.86.149:873/biggan/portraits/ ./portraits/
- Many of the images in this dataset have black borders. These can be mostly removed with this command:
> python dataset_utils/crop_image_borders.py --image_dir ./portraits/
- Create TFRecord dataset:
> python dataset_utils/images_to_tfrecords.py --image_dir ./portraits/ --data_dir /path/to/tfrecord
`--image_dir` is the path to the image directory. `--data_dir` is the path where the TFRecord dataset is stored.
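To sanity-check a freshly written dataset, you can parse one record back. A sketch under the assumption that the records use `image` and `label` features and that the output file is named as below (both may differ in `images_to_tfrecords.py`):

```python
# Read a single example back from the TFRecord dataset (feature names assumed).
import tensorflow as tf

ds = tf.data.TFRecordDataset('/path/to/tfrecord/dataset.tfrecords')
features = {'image': tf.io.FixedLenFeature([], tf.string),
            'label': tf.io.FixedLenFeature([], tf.int64)}

for record in ds.take(1):
    example = tf.io.parse_single_example(record, features)
    image = tf.io.decode_image(example['image'])
    print(image.shape, example['label'].numpy())
```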
### Custom Dataset
I am assuming that your dataset is an image folder containing JPEG or PNG files (with or without label subfolders). If you have labels, your image folder should have the following structure:
/path/to/image_dir/
label0/
0.jpg
1.jpg
...
label1/
a.jpg
b.jpg
c.jpg
...
...
If you don't have labels, your image folder should look like this:
/path/to/image_dir/
0.jpg
1.jpg
2.jpg
4.jpg
...
Create a TFRecord dataset from the image folder:
> python dataset_utils/images_to_tfrecords.py --image_dir /path/to/image_dir/ --data_dir /path/to/tfrecord
If you have labels, additionally use the `--has_labels` flag. The TFRecord dataset will be stored at `/path/to/tfrecord`.
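For intuition, here is a minimal sketch of what writing an image folder into a TFRecord can look like. The feature names, output file name, and single-file layout are assumptions; the actual `dataset_utils/images_to_tfrecords.py` may differ:

```python
# Minimal TFRecord writer sketch (not the repo's actual script).
import os
import tensorflow as tf

image_dir = '/path/to/image_dir/'
with tf.io.TFRecordWriter('/path/to/tfrecord/dataset.tfrecords') as writer:
    for fname in sorted(os.listdir(image_dir)):
        with open(os.path.join(image_dir, fname), 'rb') as f:
            image_bytes = f.read()  # raw JPEG/PNG bytes, decoded at load time
        example = tf.train.Example(features=tf.train.Features(feature={
            'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes])),
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[0])),  # constant without --has_labels
        }))
        writer.write(example.SerializeToString())
```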
### TensorFlow Datasets
TensorFlow Datasets also offers many datasets to choose from. To use one of them, you will have to replace the `tf.data.TFRecordDataset` in `data_pipeline.py` with that dataset.
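For example (the dataset name here is just an illustration):

```python
# Pull a dataset from the TFDS catalog instead of reading local TFRecords.
import tensorflow_datasets as tfds

ds = tfds.load('cifar10', split='train', shuffle_files=True)
for example in ds.take(1):
    print(example['image'].shape, example['label'].numpy())
```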
## Training
To start training with default arguments:
> CUDA_VISIBLE_DEVICES=a,b,c,d python main.py --data_dir /path/to/tfrecord
Here `a`, `b`, `c`, `d` are the GPU indices. Multi-GPU training (data parallelism) works by default and will automatically use all the devices that you make visible.
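The device count that JAX sees follows directly from `CUDA_VISIBLE_DEVICES`. A small sketch of the data-parallel pattern (illustrative only, not code from `training.py`):

```python
# With CUDA_VISIBLE_DEVICES=0,1,2,3 this reports 4 devices.
import jax
import jax.numpy as jnp

n = jax.local_device_count()

# jax.pmap replicates a function across devices; the leading axis of the
# input must match the device count, one shard per device.
step = jax.pmap(lambda x: x * 2.0)
out = step(jnp.arange(n * 3.0).reshape(n, 3))
print(out.shape)  # (n, 3)
```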
To reproduce the results on FFHQ and Danbooru2019 Portraits with resolution 512x512:
> CUDA_VISIBLE_DEVICES=a,b,c,d python main.py --data_dir /path/to/tfrecord --resolution 512 --batch_size 8 --learning_rate 0.0025 --r1_gamma 0.5 --mbstd_group_size 8 --fmap_base 32768 --wandb
Some guidelines for choosing the hyperparameters (taken from here):
|                  | 256x256 | 512x512 | 1024x1024   |
|------------------|---------|---------|-------------|
| batch_size       | 8       | 8       | 4           |
| fmap_base        | 16384   | 32768   | 32768       |
| learning_rate    | 0.0025  | 0.0025  | 0.002       |
| r1_gamma         | 1.0     | 0.5     | 2.0 or 10.0 |
| ema_kimg         | 20      | 20      | 10          |
| mbstd_group_size | 8       | 8       | 4           |
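Applied to a 1024x1024 run, these guidelines translate into a command along the following lines. Note that I am assuming `ema_kimg` is exposed as an `--ema_kimg` argument; drop it if `main.py` does not accept that flag:
> CUDA_VISIBLE_DEVICES=a,b,c,d python main.py --data_dir /path/to/tfrecord --resolution 1024 --batch_size 4 --learning_rate 0.002 --r1_gamma 2.0 --mbstd_group_size 4 --fmap_base 32768 --ema_kimg 10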
### Logging
I use Weights & Biases for logging, but you can simply replace it with the logging method of your choice. All of the logging happens in the training loop implemented in `training.py`. To enable logging with Weights & Biases, use the `--wandb` flag. The Weights & Biases logging can be configured in line 60 of `main.py`.
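The setup in `main.py` is presumably the standard `wandb.init` call; a minimal sketch with placeholder values (project and run names are not taken from the repo):

```python
# Placeholder Weights & Biases setup; adjust project/name to your workspace.
import wandb

wandb.init(project='stylegan2-flax', name='ffhq_512x512')

# Metrics are then logged from the training loop in training.py, e.g.:
wandb.log({'loss/generator': 0.0, 'loss/discriminator': 0.0})
```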
### Checkpointing
By default, the FID score is evaluated every 1000 training steps on 10,000 images, and the checkpoint with the lowest (best) FID score is saved. You can change the evaluation frequency using the `--eval_fid_every` argument and the number of images used for the FID computation using `--num_fid_images`.
You can disable the FID evaluation using `--disable_fid`. In that case, a checkpoint will be saved every 2000 steps (this can be changed using `--save_every`).
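For example, to evaluate the FID score every 5000 steps on 50,000 images (the values here are illustrative):
> CUDA_VISIBLE_DEVICES=a,b,c,d python main.py --data_dir /path/to/tfrecord --eval_fid_every 5000 --num_fid_images 50000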
### Mixed Precision
Mixed precision training is implemented and can be activated using `--mixed_precision`. However, at the moment it is not stable, so I don't recommend using it until further notice. I will look into it; if you figure it out, you are more than welcome to submit a PR.
I have implemented all the mixed precision tricks from the original StyleGAN2 implementation (casting to float32 for some operations, using pre-normalization in the modulated conv layer, only using float16 for the higher resolutions, clipping the output of the convolution layers, etc.). Dynamic loss scaling is also implemented with `dynamic_scale_lib.DynamicScale`.
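For reference, the `DynamicScale` pattern looks roughly like this (a minimal sketch, assuming `dynamic_scale_lib` is an alias for `flax.training.dynamic_scale`; the loss function is a placeholder):

```python
# Dynamic loss scaling: scale the loss up before differentiation, unscale
# the gradients, and skip the update when they are not finite.
import jax.numpy as jnp
from flax.training import dynamic_scale as dynamic_scale_lib

def loss_fn(params):
    return jnp.mean(params ** 2)  # placeholder loss

dyn_scale = dynamic_scale_lib.DynamicScale()
params = jnp.ones((4,))

dyn_scale, is_finite, loss, grads = dyn_scale.value_and_grad(loss_fn)(params)
# Apply `grads` only if `is_finite` is True; DynamicScale reduces its scale
# after an overflowing step and gradually increases it during stable training.
```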
## Checkpoints
I have trained StyleGAN2 from scratch on FFHQ and Danbooru2019 Portraits, both at resolution 512x512.
- FFHQ at 512x512 (922.1 MB)
- FFHQ at 256x256 (828.4 MB)
- Danbooru2019 Portraits at 512x512 (922.1 MB)
## Generate Images
- Download checkpoint:
> wget https://www.dropbox.com/s/qm7ot3k81wlhh6m/ffhq_512x512.pickle
- Generate images (one image will be generated for each seed):
> CUDA_VISIBLE_DEVICES=0 python generate_images.py --ckpt_path ffhq_512x512.pickle --seeds 1 2 3 4
### Style Mixing
- Download checkpoint:
> wget https://www.dropbox.com/s/qm7ot3k81wlhh6m/ffhq_512x512.pickle
- Generate style mixing grid:
> CUDA_VISIBLE_DEVICES=0 python style_mixing.py --ckpt_path ffhq_512x512.pickle --row_seeds 1 2 3 4 --col_seeds 5 6 7 8
## Original Checkpoints
The implementation is compatible with the pretrained weights from NVIDIA. To generate images in Jax/Flax using the original checkpoints from NVIDIA, go here.