Official code for "Shortest-path Diffusion", developed by MediaTek Research and accepted at the International Conference on Machine Learning (ICML) 2023.
Image generation with shortest path diffusion
Ayan Das*, Stathi Fotiadis*, Anil Batra, Farhang Nabiei, FengTing Liao, Sattar Vakili, Da-shan Shiu, Alberto Bernacchia
MediaTek Research, Cambourne UK
(* Equal Contributions)
The field of image generation has made significant progress thanks to the introduction of Diffusion Models, which learn to progressively reverse a given image corruption. Recently, a few studies introduced alternative ways of corrupting images in Diffusion Models, with an emphasis on blurring. However, these studies are purely empirical and it remains unclear what is the optimal procedure for corrupting an image. In this work, we hypothesize that the optimal procedure minimizes the length of the path taken when corrupting an image towards a given final state. We propose the Fisher metric for the path length, measured in the space of probability distributions. We compute the shortest path according to this metric, and we show that it corresponds to a combination of image sharpening, rather than blurring, and noise deblurring. While the corruption was chosen arbitrarily in previous work, our Shortest Path Diffusion (SPD) determines uniquely the entire spatiotemporal structure of the corruption. We show that SPD improves on strong baselines without any hyperparameter tuning, and outperforms all previous Diffusion Models based on image blurring. Furthermore, any small deviation from the shortest path leads to worse performance, suggesting that SPD provides the optimal procedure to corrupt images. Our work sheds new light on observations made in recent works and provides a new approach to improve diffusion models on images and other types of data.
Please note that this codebase is built on the publicly available implementation of OpenAI's "Guided Diffusion". Below we provide instructions for downloading data, training the model and sampling from it.
NOTE: Running this code requires at least one GPU available on the system.
Use the datasets/cifar10.py script to download the CIFAR10 dataset into a directory of your choice with the following command:
python datasets/cifar10.py </tmp/dir/>
Then pass the argument --data_dir /tmp/dir/cifar_train/ to all our scripts to use the dataset.
The entire codebase is written as a Python package, so you need to install it by running
pip install -e .
This also installs all necessary dependencies (requires an internet connection).
First, create a reference batch of 50000 images for FID computation by running the following script, pointing it at the data directory:
python evaluations/create_ref_batch.py </tmp/dir/cifar_train/>
This creates a file named cifar10_reference_50000x32x32x3.npz, which is required in the following steps.
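Before training, it can be useful to confirm that the reference batch has the expected shape. The following is a minimal sketch, assuming the file is a standard NumPy .npz archive (a zip of .npy members); it reads only the array headers, so the pixel data is never loaded. `describe_npz` is a hypothetical helper and not part of this codebase, and the name of the array stored inside the archive is not specified here.

```python
import ast
import struct
import zipfile


def describe_npz(path):
    """Return {array_name: (shape, dtype_descr)} by reading only the .npy headers."""
    summary = {}
    with zipfile.ZipFile(path) as archive:
        for member in archive.namelist():
            with archive.open(member) as fh:
                if fh.read(6) != b"\x93NUMPY":
                    continue  # skip anything that is not a .npy member
                major, _minor = fh.read(2)
                # Format 1.0 stores the header length in 2 bytes, later versions in 4.
                size_fmt, size_len = ("<H", 2) if major == 1 else ("<I", 4)
                (header_len,) = struct.unpack(size_fmt, fh.read(size_len))
                # The header is a Python dict literal with 'descr', 'fortran_order', 'shape'.
                header = ast.literal_eval(fh.read(header_len).decode("latin1"))
            name = member[:-4] if member.endswith(".npy") else member
            summary[name] = (tuple(header["shape"]), header["descr"])
    return summary


# Usage (path taken from the step above):
# print(describe_npz("./cifar10_reference_50000x32x32x3.npz"))
```

For the reference batch, one would expect a single array whose shape matches the 50000x32x32x3 in the file name.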
We provide the code for training both our SPD model and the original DDPM model. To train our proposed SPD model, run the following
python scripts/image_fourier_train.py --config ./configs/cifar10_fourier.yml --data_dir </tmp/dir/cifar_train/> --reference_batch_path ./cifar10_reference_50000x32x32x3.npz --output_dir ./logs --exp_name my_training --debug True --batch_size 1024 --num_samples 50000 --diffusion_steps 4000
- Make sure to use the --debug True flag when running in a non-distributed setting; otherwise, launch the script with torchrun. See below for an example.
- Choose an appropriate --batch_size xx depending on the GPU memory size and the number of GPUs used (if distributed).
- You may use on-the-fly FID computation with --num_samples xx, but we discourage doing so because it is time-consuming. We recommend a training-only run with --num_samples 0 followed by a separate sampling run.
An important argument for training SPD is --diffusion_steps xx, which sets T, the total number of diffusion steps. Use this argument with the training script as well as the sampling script (explained below) to reproduce the results in the paper.
The training process saves EMA checkpoints at regular intervals (configurable with --save_interval xx) inside the ./logs/my_training/rank_x folder. Choose a checkpoint, e.g. ./logs/my_training/rank_x/<checkpoint>.pt, and run sampling as follows:
python scripts/image_fourier_sample.py --config ./configs/cifar10_fourier.yml --output_dir ./logs --exp_name sampling --debug True --batch_size 128 --model_path ./logs/my_training/rank_0/ema_0.9999_000000.pt --num_samples 50000
This will create an .npz file containing samples from the model checkpoint given by --model_path ./logs/my_training/rank_0/<checkpoint>.pt. It will also compute the FID and display it. Please note that you need an internet connection to download the Inception weights necessary for FID computation.
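When several EMA checkpoints have accumulated in a rank folder, picking the most recent one by hand is error-prone. The sketch below selects the checkpoint with the highest step number. It assumes the naming scheme ema_<rate>_<step>.pt, inferred only from the example ema_0.9999_000000.pt in the sampling command; `latest_ema_checkpoint` is a hypothetical helper, not part of this codebase.

```python
import re
from pathlib import Path

# Assumed naming scheme, inferred from the example "ema_0.9999_000000.pt":
# ema_<decay rate>_<zero-padded training step>.pt
_EMA_NAME = re.compile(r"^ema_(?P<rate>[0-9.]+)_(?P<step>\d+)\.pt$")


def latest_ema_checkpoint(folder):
    """Return the Path of the EMA checkpoint with the highest step, or None if none match."""
    best_step, best_path = -1, None
    for path in Path(folder).glob("ema_*.pt"):
        match = _EMA_NAME.match(path.name)
        if match and int(match.group("step")) > best_step:
            best_step, best_path = int(match.group("step")), path
    return best_path


# Usage:
# ckpt = latest_ema_checkpoint("./logs/my_training/rank_0")
# then pass str(ckpt) as the value of --model_path
```

The latest checkpoint is not necessarily the one with the best FID, so this is only a convenient default.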
The original DDPM and DDIM implementations are also provided for the sake of completeness. The training and sampling process is exactly the same as explained above; only the names of the scripts change. Use the following script for DDPM training
python scripts/image_train.py --config ./configs/cifar10_fourier.yml ... <arguments>
and sampling
python scripts/image_sample.py --config ./configs/cifar10_fourier.yml ... <arguments>
Please note that every script reads all necessary hyperparameters from various sections of the ./configs/cifar10_fourier.yml config file.
You may also use the DDIM sampler in the sampling script in the same manner, as shown here
python scripts/image_sample.py ... <arguments> --use_ddim True --timestep_respacing ddimXXX
where XXX
is the number of desired steps.
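To get a feel for which timesteps a ddimXXX setting keeps, here is a sketch modeled on the fixed-stride rule used by the upstream Guided Diffusion code. It assumes this repository leaves that behavior unchanged, which is not confirmed here; `ddim_timesteps` is a hypothetical name used only for illustration.

```python
def ddim_timesteps(num_timesteps, spec):
    """Return the sorted timesteps kept for a spec such as "ddim100".

    Searches for an integer stride that yields exactly the requested number
    of steps out of `num_timesteps`, and fails if no such stride exists.
    """
    if not spec.startswith("ddim"):
        raise ValueError(f"expected a spec of the form 'ddimXXX', got {spec!r}")
    desired = int(spec[len("ddim"):])
    for stride in range(1, num_timesteps):
        kept = range(0, num_timesteps, stride)
        if len(kept) == desired:
            return list(kept)
    raise ValueError(
        f"cannot create exactly {desired} steps with an integer stride "
        f"over {num_timesteps} timesteps"
    )


# Example: with --diffusion_steps 4000, "ddim100" keeps every 40th timestep.
steps = ddim_timesteps(4000, "ddim100")
print(len(steps), steps[:3], steps[-1])
```

Under this rule, not every value of XXX is valid for a given T: the requested count must be reachable with an integer stride.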
In most realistic cases, you would need multiple GPUs and would run the above commands in a distributed setting. torchrun, which ships with recent versions of PyTorch, makes it easy to execute distributed training/sampling. To use torchrun, simply do the following on an N-GPU machine
torchrun --nnodes 1 --nproc_per_node <N> \
--no_python \
python <script.py> ... <arguments> ... \
--batch_size 256
If you have a node with N = 4 GPUs, this run will use an effective batch size of N * 256 = 1024 for both training and sampling.
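The batch-size arithmetic above can be written out as a small sketch. It assumes that every process handles its own --batch_size items per iteration, so the effective batch size is the per-GPU value times the number of GPUs; the function names are hypothetical and not part of this codebase.

```python
import math


def effective_batch_size(num_gpus, per_gpu_batch_size):
    """Total number of items processed per iteration across all GPUs."""
    return num_gpus * per_gpu_batch_size


def sampling_rounds(num_samples, num_gpus, per_gpu_batch_size):
    """Iterations needed to produce at least `num_samples` samples in total."""
    return math.ceil(num_samples / effective_batch_size(num_gpus, per_gpu_batch_size))


# Example from the text: N = 4 GPUs with --batch_size 256.
print(effective_batch_size(4, 256))    # 1024
print(sampling_rounds(50000, 4, 256))  # 49
```

To match a single-GPU run that used --batch_size 1024, one would therefore pass 1024 / N as the per-GPU value.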
@inproceedings{das2023spdiffusion,
title={Image generation with shortest path diffusion},
author={Ayan Das and
Stathi Fotiadis and
Anil Batra and
Farhang Nabiei and
FengTing Liao and
Sattar Vakili and
Da-shan Shiu and
Alberto Bernacchia
},
booktitle={International Conference on Machine Learning},
year={2023}
}