Jax DDIM

JAX/Flax implementation of Denoising Diffusion Implicit Models (DDIM), following the Keras example of the same name.

Setup

Main dependencies

  • jax==0.3.14
  • flax==0.5.2
  • tensorflow==2.9.1
  • tensorflow-datasets==4.6.0
  • tensorboard==2.9.1

For instance, I recommend using GCP Vertex Workbench (a managed JupyterLab environment) with a GPU accelerator. Vertex Workbench offers a GPU environment with popular deep learning libraries.
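To confirm that an environment matches the versions pinned under Main dependencies, a small check along the following lines may help. This is only a sketch: the helper name `check_pins` and the `PINNED` table are illustrative and not part of this repository; the lookup itself uses the standard-library `importlib.metadata`.

```python
from importlib import metadata

# Pins copied from the "Main dependencies" list in this README.
PINNED = {
    "jax": "0.3.14",
    "flax": "0.5.2",
    "tensorflow": "2.9.1",
    "tensorflow-datasets": "4.6.0",
    "tensorboard": "2.9.1",
}


def check_pins(pins):
    """Return {package: (expected, installed)}; installed is None if missing.

    Hypothetical helper for illustration, not part of this repository.
    """
    report = {}
    for name, expected in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        report[name] = (expected, installed)
    return report


if __name__ == "__main__":
    for name, (expected, installed) in check_pins(PINNED).items():
        status = "ok" if installed == expected else "MISMATCH"
        print(f"{name}: expected {expected}, found {installed} [{status}]")
```

A mismatch does not necessarily mean the code will fail, but these are the versions listed above.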

Run experiment

Run train.py or train.ipynb. The trained model and TensorBoard logs are saved under the outputs directory by default.

According to the Keras example, training for at least 50 epochs is recommended for good results.

python train.py \
--epoch 50 \
<other arguments ...>

Results

Training loss and generated images for 50 epochs:

[Figure: losses]

[Figure: images]

Notes

This implementation follows the Keras example. See that example for detailed tips and discussion.
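For orientation, the deterministic sampling update that the Keras example is built around can be sketched as follows. This is a minimal illustration written from that example's formulation, not code taken from this repository's train.py: the function names and the schedule bounds (0.02 and 0.95, the Keras example's defaults) are assumptions, and NumPy is used only to keep the sketch dependency-light; the same arithmetic carries over to jax.numpy.

```python
import numpy as np

# Assumed schedule bounds (the Keras example's defaults, not read from this repo).
MIN_SIGNAL_RATE = 0.02
MAX_SIGNAL_RATE = 0.95


def diffusion_schedule(diffusion_times):
    """Cosine schedule: map times in [0, 1] to (noise_rate, signal_rate)."""
    start_angle = np.arccos(MAX_SIGNAL_RATE)
    end_angle = np.arccos(MIN_SIGNAL_RATE)
    angles = start_angle + diffusion_times * (end_angle - start_angle)
    # sin^2 + cos^2 = 1, so the squared rates always sum to one.
    return np.sin(angles), np.cos(angles)


def ddim_step(noisy_images, pred_noise, time, step_size):
    """One deterministic DDIM update from `time` to `time - step_size`.

    `pred_noise` stands in for the network's noise prediction.
    """
    noise_rate, signal_rate = diffusion_schedule(time)
    # Estimate the clean image by removing the predicted noise component.
    pred_images = (noisy_images - noise_rate * pred_noise) / signal_rate
    # Re-mix the estimate and the predicted noise at the next (lower) noise level.
    next_noise_rate, next_signal_rate = diffusion_schedule(time - step_size)
    next_noisy_images = next_signal_rate * pred_images + next_noise_rate * pred_noise
    return next_noisy_images, pred_images
```

In a full sampler this step would be repeated from time 1 down to 0, with `pred_noise` produced by the trained network at each step.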