This repository was archived by the owner on Nov 30, 2023. It is now read-only.

google-research/vdm


Variational Diffusion Models

JAX/Flax code for reproducing some key results of Variational Diffusion Models (https://arxiv.org/abs/2107.00630).

Standalone Colabs

At colab/SimpleDiffusionColab.ipynb you will find a self-contained Colab implementation of a Variational Diffusion Model (VDM), which serves as an easy-to-understand demonstration of the code and principles behind the paper. (Thanks a lot to Alex Alemi and Ben Poole for this implementation.)

At colab/2D_VDM_Example.ipynb you will find an even more basic implementation, which trains on a 2D swirl dataset using MLPs.
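For orientation before opening the Colabs, here is a minimal JAX sketch of the continuous-time diffusion loss from the paper, in the noise-prediction form 0.5 * γ'(t) * ||ε - ε̂(z_t; t)||², applied to 2D data. It assumes a fixed linear schedule for γ(t) = -log SNR(t) (the paper learns this schedule; the endpoint values below are illustrative assumptions) and covers only the diffusion term, not the prior or reconstruction terms of the full VDM loss. The names `gamma`, `alpha_sigma`, `diffusion_loss` and `zero_model` are hypothetical and do not come from this repository.

```python
import jax
import jax.numpy as jnp

# Hypothetical fixed linear schedule gamma(t) = -log SNR(t).
# The paper learns gamma(t); these endpoints are illustrative only.
GAMMA_MIN, GAMMA_MAX = -13.3, 5.0

def gamma(t):
    return GAMMA_MIN + (GAMMA_MAX - GAMMA_MIN) * t

def alpha_sigma(t):
    """Variance-preserving parameterization: alpha^2 = sigmoid(-gamma), sigma^2 = sigmoid(gamma)."""
    g = gamma(t)
    return jnp.sqrt(jax.nn.sigmoid(-g)), jnp.sqrt(jax.nn.sigmoid(g))

def diffusion_loss(eps_model, x, rng):
    """Per-example Monte Carlo estimate of 0.5 * gamma'(t) * ||eps - eps_hat(z_t, t)||^2.

    x has shape (batch, dim); t ~ U(0, 1) is drawn independently per example.
    """
    rng_t, rng_eps = jax.random.split(rng)
    t = jax.random.uniform(rng_t, (x.shape[0],))
    eps = jax.random.normal(rng_eps, x.shape)
    alpha, sigma = alpha_sigma(t)
    # Sample z_t ~ q(z_t | x) = N(alpha_t * x, sigma_t^2 * I).
    z_t = alpha[:, None] * x + sigma[:, None] * eps
    eps_hat = eps_model(z_t, t)
    # gamma'(t) via autodiff, so the same code works for other schedules.
    g_grad = jax.vmap(jax.grad(gamma))(t)
    return 0.5 * g_grad * jnp.sum(jnp.square(eps - eps_hat), axis=-1)

def zero_model(z_t, t):
    """Hypothetical placeholder for a trained MLP noise predictor."""
    return jnp.zeros_like(z_t)

# Usage on a toy batch of 16 two-dimensional points.
x = jax.random.normal(jax.random.PRNGKey(0), (16, 2))
loss = diffusion_loss(zero_model, x, jax.random.PRNGKey(1))
print(loss.mean())
```

In a real model `zero_model` would be replaced by a network trained to minimize the mean of this loss; differentiating through `diffusion_loss` with `jax.grad` gives the training gradient.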

Setup: Installing required libraries

This code was tested on a TPU-v3 machine. For instructions on how to launch such a machine, see 'Setting up a v3-8 machine' below.

To install the required libraries on a TPU machine:

pip3 install -U pip
sudo pip uninstall -y jax jaxlib
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip3 install --upgrade -r requirements.txt

Alternatively, on a GPU machine, this should work:

pip3 install -U pip
pip3 install --upgrade jax jaxlib
pip3 install --upgrade -r requirements.txt
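After installing, a quick check like the following (a sketch, not part of this repository) shows which backend JAX picked up. On a v3-8 TPU machine one would typically expect the backend "tpu" with 8 devices; if it reports "cpu", jaxlib was most likely installed without accelerator support.

```python
import jax

# Which backend did JAX select: "tpu", "gpu", or "cpu"?
backend = jax.default_backend()
devices = jax.devices()
print(f"backend={backend}, device_count={len(devices)}")
for d in devices:
    print(d)
```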

Train/evaluate: CIFAR-10 without data augmentation

The commands below assume that the code is checked out at the ./vdm directory, such that this README is located at './vdm/README'.

To evaluate from a pre-trained checkpoint:

python3 -m vdm.main --mode=eval --config=vdm/configs/cifar10.py --workdir=[workdir] --checkpoint='gs://gresearch/vdm/cifar10/checkpoint-final/checkpoints-0'

where [workdir] is a directory to write results to, such as '/tmp/vdm-workdir', or a Google Cloud Storage path (gs://[your-address]). The command prints a set of evaluation statistics, including eval_bpd=2.637, which matches the 2.64 bits per dimension reported in the paper.
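For reference, bits per dimension (bpd) is the negative log-likelihood of an image expressed in bits and divided by the number of dimensions, which is 32 × 32 × 3 = 3072 for CIFAR-10. The helpers below (hypothetical names, not from this repository) sketch that conversion and express the reported eval_bpd in nats per image.

```python
import math

# CIFAR-10 images are 32x32 pixels with 3 color channels.
NUM_DIMS = 32 * 32 * 3  # 3072

def nats_to_bpd(nll_nats_per_image, num_dims=NUM_DIMS):
    """Convert a negative log-likelihood in nats per image to bits per dimension."""
    return nll_nats_per_image / (num_dims * math.log(2))

def bpd_to_nats(bpd, num_dims=NUM_DIMS):
    """Inverse conversion: bits per dimension back to nats per image."""
    return bpd * num_dims * math.log(2)

nll = bpd_to_nats(2.637)
print(f"{nll:.1f} nats per image")
```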

To train:

python3 -m vdm.main --config=vdm/configs/cifar10.py --workdir=[workdir]

CIFAR-10 with data augmentation

We also provide code for training a continuous-time VDM on CIFAR-10 with data augmentation. The model differs slightly from the one described in the paper but achieves similar performance.

To evaluate:

python3 -m vdm.main --mode=eval --config=vdm/configs/cifar10_aug.py --workdir=[workdir] --checkpoint='gs://gresearch/vdm/cifar10/aug-checkpoint-final/checkpoints-0'

This reports a bpd slightly worse than the paper's: eval_bpd=2.522, versus 2.49 in the paper. We suspect this is because this open-source version of the model was trained with a smaller batch size, and because of a small difference in the score network implementation, namely how it is conditioned.

To train:

python3 -m vdm.main --config=vdm/configs/cifar10_aug.py --workdir=[workdir]

At the time of writing, training this model is too slow and memory-intensive to run on a single TPU or GPU machine. Therefore, we run on a v3-64 TPU pod, as explained below.

(Optional) Setting up a v3-8 machine on Google Compute Engine (GCE)

This code was tested on a v3-8 TPU machine on Google Compute Engine (GCE). The machine was created using:

gcloud alpha compute tpus tpu-vm create --project=[project] --zone=[zone] [machine-name] --accelerator-type=v3-8 --version=v2-alpha

SSH to the instance:

gcloud alpha compute tpus tpu-vm ssh --project=[project] --zone=[zone] [machine-name]

We then copied (or cloned through git) the code to a directory ~/vdm on the machine and installed the required libraries on the instance:

cd ~; bash vdm/sh/setup-vm-tpuv3.sh

Training on a TPU pod on GCE

Create the machine with:

gcloud alpha compute tpus tpu-vm create --project=[project] --zone=[zone] [machine-name] --accelerator-type=v3-64 --version=v2-alpha

Copy (or clone through git) the code to a directory ~/vdm on the machine. Then, install libraries:

gcloud alpha compute tpus tpu-vm ssh --project=[project] --zone=[zone] [machine-name] --worker=all --command="$(<~/vdm/sh/setup-tpu-machine-v3.sh)"

Start training with:

gcloud alpha compute tpus tpu-vm ssh --project=[project] --zone=[zone] [machine-name] --worker=all --command="~/.local/bin/ipython vdm/main.py -- --workdir=[dir_to_save_logs] --config=vdm/configs/cifar10_aug.py"

Stop training on all machines by running:

gcloud alpha compute tpus tpu-vm ssh --project=[project] --zone=[zone] [machine-name] --worker=all --command="killall main.py"
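The --worker=all commands above run the same program on every host of the pod, so each host runs its own JAX process. A small script like the one below (a sketch, not part of this repository), launched the same way, can confirm that every host sees the expected devices. On a v3-64 pod one would expect 8 processes, each with 8 local devices, for 64 global devices.

```python
import jax

# One JAX process per TPU-VM host; each host owns its local devices.
n_proc = jax.process_count()
proc_id = jax.process_index()
n_local = jax.local_device_count()
n_global = jax.device_count()
print(
    f"process {proc_id} of {n_proc}: "
    f"{n_local} local / {n_global} global devices"
)
```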

Acknowledgments

We thank Ruiqi Gao for substantial contributions to this public VDM codebase, and Alex Alemi and Ben Poole for implementing colab/SimpleDiffusionColab.ipynb.

Disclaimer

This is not an officially supported Google product.
