-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 219e399
Showing
48 changed files
with
10,929 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Compiled python modules. | ||
*.pyc | ||
|
||
# Byte-compiled | ||
_pycache__/ | ||
.cache/ | ||
.idea/ | ||
|
||
# Python egg metadata, regenerated from source files by setuptools. | ||
/*.egg-info | ||
.eggs/ | ||
|
||
# PyPI distribution artifacts. | ||
build/ | ||
dist/ | ||
|
||
# Tests | ||
.pytest_cache/ | ||
|
||
# Other | ||
*.DS_Store | ||
|
||
experiments |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
# Maximum Likelihood Training for Score-Based Diffusion ODEs by High-Order Denoising Score Matching (ICML 2022) | ||
|
||
The official code for the paper [Maximum Likelihood Training for Score-Based Diffusion ODEs by High-Order Denoising Score Matching](https://arxiv.org/abs/2206.08265) by Cheng Lu, Kaiwen Zheng, Fan Bao, Jianfei Chen, Chongxuan Li and Jun Zhu, published in ICML 2022. | ||
|
||
The code implementation is based on [score_flow](https://github.com/yang-song/score_flow) by Yang Song. | ||
|
||
-------------------- | ||
|
||
Score-based diffusion models include two types: ScoreSDE and ScoreODE. [Previous work](https://arxiv.org/abs/2101.09258) showed that the weighted combination of first-order score matching losses can upper bound the Kullback–Leibler divergence between the data distribution and the ScoreSDE model distribution. However, the relationship between score matching and ScoreODE is unclear. In this work, we prove that: | ||
|
||
- The model distributions of ScoreSDE and ScoreODE are **always different** if the data distribution is not a Gaussian distribution. | ||
- To upper bound the KL-divergence of ScoreODE, we need first-order, second-order and third-order score matching for the score model. | ||
- We further propose an error-bounded high-order denoising score matching method. The higher-order score matching error can be exactly upper bounded by the training error and the lower-order score matching errors, which enables high-order score matching. | ||
|
||
In short, The previous work [Maximum Likelihood Training of Score-Based Diffusion Models](https://arxiv.org/abs/2101.09258) is a method for maximum likelihood training of **ScoreSDE** (a.k.a. **diffusion SDE**), and our work is a method for maximum likelihood training of **ScoreODE** (a.k.a. **diffusion ODE**). | ||
|
||
## Code Structure | ||
The code implementation is based on [score_flow](https://github.com/yang-song/score_flow) by Yang Song. We further implement the proposed high-order denoising score matching losses in `losses.py`. | ||
|
||
## How to run the code | ||
|
||
### Dependencies | ||
|
||
We use the same denpendencies as [score_flow](https://github.com/yang-song/score_flow). To install the packages, we recommend the ``jaxlib==0.1.69``. You need to find a corresponding version for your python3 version and cuda version at: [https://storage.googleapis.com/jax-releases/jax_cuda_releases.html](https://storage.googleapis.com/jax-releases/jax_cuda_releases.html). For example, to install ``jaxlib==0.1.69`` for `python==3.7` and `cuda==11.1`, you need to firstly download the wheel file: | ||
```sh | ||
wget https://storage.googleapis.com/jax-releases/cuda111/jaxlib-0.1.69+cuda111-cp37-none-manylinux2010_x86_64.whl | ||
``` | ||
and then run the following command to install `jaxlib`: | ||
```sh | ||
pip3 install jaxlib-0.1.69+cuda111-cp37-none-manylinux2010_x86_64.whl | ||
``` | ||
After install `jaxlib`, you need to run to following command to install the other packages: | ||
```sh | ||
pip3 install -r requirements.txt | ||
``` | ||
|
||
### Stats files for quantitative evaluation | ||
|
||
We use the same stats files by [score_flow](https://github.com/yang-song/score_flow) for computing FID and Inception scores for CIFAR-10 and ImageNet 32x32. You can find `cifar10_stats.npz` and `imagenet32_stats.npz` under the directory `assets/stats` in Yang Song's [Google drive](https://drive.google.com/drive/folders/1gbDrVrFVSupFMRoK7HZo8aFgPvOtpmqB?usp=sharing). Download them and save to `assets/stats/` in the code repo. | ||
|
||
### Usage | ||
The running command is the same as [score_flow](https://github.com/yang-song/score_flow). Here are some common options: | ||
|
||
```sh | ||
main.py: | ||
--config: Training configuration. | ||
(default: 'None') | ||
--eval_folder: The folder name for storing evaluation results | ||
(default: 'eval') | ||
--mode: <train|eval>: Running mode: train or eval. We did not train our model by further variational dequantizations. | ||
--workdir: Working directory | ||
``` | ||
|
||
* `config` is the path to the config file. Our config files are provided in `configs/`. They are formatted according to [`ml_collections`](https://github.com/google/ml_collections) and should be quite self-explanatory. | ||
|
||
**Naming conventions of config files**: the name of a config file contains the following attributes: | ||
|
||
* dataset: Either `cifar10` or `imagenet32` | ||
* model: Either `ddpmpp_continuous` or `ddpmpp_deep_continuous` | ||
|
||
* `workdir` is the path that stores all artifacts of one experiment, like checkpoints, samples, and evaluation results. | ||
|
||
* `eval_folder` is the name of a subfolder in `workdir` that stores all artifacts of the evaluation process, like meta checkpoints for supporting pre-emption recovery, image samples, and numpy dumps of quantitative results. | ||
|
||
* `mode` is either "train" or "eval". When set to "train", it starts the training of a new model, or resumes the training of an old model if its meta-checkpoints (for resuming running after pre-emption in a cloud environment) exist in `workdir/checkpoints-meta` . When set to "eval", it can do the following: | ||
|
||
* Compute the log-likelihood on the training or test dataset. | ||
* Compute the lower bound of the log-likelihood on the training or test dataset. | ||
* Evaluate the loss function on the test / validation dataset. | ||
* Generate a fixed number of samples and compute its Inception score, FID, or KID. Prior to evaluation, stats files must have already been downloaded/computed and stored in `assets/stats`. | ||
|
||
These functionalities can be configured through config files, or more conveniently, through the command-line support of the `ml_collections` package. | ||
|
||
### Configurations for high-order denoising score matching | ||
To set the order of the score matching training losses, set `--config.training.score_matching_order` to be `1` (the previous first-order) or `2` or `3`. Note that for third-order score matching training, the batch size needs to turn smaller to avoid OOM. | ||
|
||
### Configurations for evaluation | ||
To generate samples and evaluate sample quality, use the `--config.eval.enable_sampling` flag; to compute log-likelihoods, use the `--config.eval.enable_bpd` flag, and specify `--config.eval.dataset=train/test` to indicate whether to compute the likelihoods on the training or test dataset. Turn on `--config.eval.bound` to evaluate the variational bound for the log-likelihood. Enable `--config.eval.dequantizer` to use variational dequantization for likelihood computation. `--config.eval.num_repeats` configures the number of repetitions across the dataset (more can reduce the variance of the likelihoods; default to 5). | ||
|
||
## Train high-order DSM by pretrained checkpoints | ||
For VESDE on CIFAR-10, we use the pretrained checkpoints by first-order DSM in [score_sde checkpoints](https://drive.google.com/drive/folders/1RAG8qpOTURkrqXKwdAR1d6cU9rwoQYnH?usp=sharing). | ||
|
||
For VESDE on ImageNet32, as score_sde did not provide the checkpoints, we train the first-order model by ourselves, and then train the model by the high-order DSM. | ||
|
||
For VPSDE, we use the pretrained checkpoints by first-order DSM in [score_flow checkpoints](https://drive.google.com/drive/folders/1gbDrVrFVSupFMRoK7HZo8aFgPvOtpmqB?usp=sharing). | ||
|
||
## References | ||
|
||
If you find the code useful for your research, please consider citing | ||
```bib | ||
@inproceedings{lu2022maximum, | ||
title={Maximum Likelihood Training for Score-Based Diffusion ODEs by High-Order Denoising Score Matching}, | ||
author={Lu, Cheng and Zheng, Kaiwen and Bao, Fan and Chen, Jianfei and Li, Chongxuan and Zhu, Jun}, | ||
booktitle={International Conference on Machine Learning}, | ||
year={2022} | ||
organization={PMLR} | ||
} | ||
``` | ||
|
||
This work is built upon some previous papers which might also interest you: | ||
|
||
* Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. "Score-Based Generative Modeling through Stochastic Differential Equations". *Proceedings of the 9th International Conference on Learning Representations*, 2021. | ||
* Yang Song, Conor Durkan, Iain Murray, and Stefano Ermon. "Maximum Likelihood Training of Score-Based Diffusion Models". *Advances in Neural Information Processing Systems*, 2021. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
# coding=utf-8 | ||
# Copyright 2020 The Google Research Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# pylint: skip-file | ||
# pytype: skip-file | ||
"""Various sampling methods.""" | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
|
||
import utils | ||
from utils import batch_mul | ||
from models import utils as mutils | ||
from utils import get_div_fn, get_value_div_fn | ||
|
||
|
||
def get_likelihood_bound_fn(sde, model, inverse_scaler, hutchinson_type='Rademacher', | ||
dsm=True, eps=1e-5, N=1000, importance_weighting=True, | ||
eps_offset=True): | ||
"""Create a function to compute the unbiased log-likelihood bound of a given data point. | ||
Args: | ||
sde: A `sde_lib.SDE` object that represents the forward SDE. | ||
model: A `flax.linen.Module` object that represents the architecture of the score-based model. | ||
inverse_scaler: The inverse data normalizer. | ||
hutchinson_type: "Rademacher" or "Gaussian". The type of noise for Hutchinson-Skilling trace estimator. | ||
dsm: bool. Use denoising score matching bound if enabled; otherwise use sliced score matching. | ||
eps: A `float` number. The probability flow ODE is integrated to `eps` for numerical stability. | ||
N: The number of time values to be sampled. | ||
importance_weighting: True if enable importance weighting for potential variance reduction. | ||
eps_offset: True if use Jensen's inequality to offset the likelihood bound due to non-zero starting time. | ||
Returns: | ||
A function that takes random states, replicated training states, and a batch of data points | ||
and returns the log-likelihoods in bits/dim, the latent code, and the number of function | ||
evaluations cost by computation. | ||
""" | ||
|
||
def value_div_score_fn(state, x, t, eps): | ||
"""Pmapped divergence of the drift function.""" | ||
score_fn = mutils.get_score_fn(sde, model, state.params_ema, state.model_state, train=False, continuous=True) | ||
value_div_fn = get_value_div_fn(lambda x, t: score_fn(x, t)) | ||
return value_div_fn(x, t, eps) | ||
|
||
def div_drift_fn(x, t, eps): | ||
div_fn = get_div_fn(lambda x, t: sde.sde(x, t)[0]) | ||
return div_fn(x, t, eps) | ||
|
||
def likelihood_bound_fn(prng, state, data): | ||
"""Compute an unbiased estimate to the log-likelihood in bits/dim. | ||
Args: | ||
prng: An array of random states. The list dimension equals the number of devices. | ||
pstate: Replicated training state for running on multiple devices. | ||
data: A JAX array of shape [#devices, batch size, ...]. | ||
Returns: | ||
bpd: A JAX array of shape [#devices, batch size]. The log-likelihoods on `data` in bits/dim. | ||
N: same as input | ||
""" | ||
rng, step_rng = jax.random.split(prng) | ||
if importance_weighting: | ||
time_samples = sde.sample_importance_weighted_time_for_likelihood(step_rng, (N, data.shape[0]), eps=eps) | ||
Z = sde.likelihood_importance_cum_weight(sde.T, eps=eps) | ||
else: | ||
time_samples = jax.random.uniform(step_rng, (N, data.shape[0]), minval=eps, maxval=sde.T) | ||
Z = 1 | ||
|
||
shape = data.shape | ||
if not dsm: | ||
def scan_fn(carry, vec_time): | ||
rng, value = carry | ||
rng, step_rng = jax.random.split(rng) | ||
if hutchinson_type == 'Gaussian': | ||
epsilon = jax.random.normal(step_rng, shape) | ||
elif hutchinson_type == 'Rademacher': | ||
epsilon = jax.random.rademacher(step_rng, shape, dtype=jnp.float32) | ||
else: | ||
raise NotImplementedError(f"Hutchinson type {hutchinson_type} unknown.") | ||
|
||
rng, step_rng = jax.random.split(rng) | ||
noise = jax.random.normal(step_rng, shape) | ||
mean, std = sde.marginal_prob(data, vec_time) | ||
noisy_data = mean + utils.batch_mul(std, noise) | ||
score_val, score_div = value_div_score_fn(state, noisy_data, vec_time, epsilon) | ||
score_norm = jnp.square(score_val.reshape((score_val.shape[0], -1))).sum(axis=-1) | ||
drift_div = div_drift_fn(noisy_data, vec_time, epsilon) | ||
f, g = sde.sde(noisy_data, vec_time) | ||
integrand = utils.batch_mul(g ** 2, 2 * score_div + score_norm) - 2 * drift_div | ||
if importance_weighting: | ||
integrand = utils.batch_mul(std ** 2 / g ** 2 * Z, integrand) | ||
return (rng, value + integrand), integrand | ||
else: | ||
score_fn = mutils.get_score_fn(sde, model, state.params_ema, state.model_state, train=False, continuous=True) | ||
|
||
def scan_fn(carry, vec_time): | ||
rng, value = carry | ||
rng, step_rng = jax.random.split(rng) | ||
if hutchinson_type == 'Gaussian': | ||
epsilon = jax.random.normal(step_rng, shape) | ||
elif hutchinson_type == 'Rademacher': | ||
epsilon = jax.random.rademacher(step_rng, shape, dtype=jnp.float32) | ||
else: | ||
raise NotImplementedError(f"Hutchinson type {hutchinson_type} unknown.") | ||
rng, step_rng = jax.random.split(rng) | ||
noise = jax.random.normal(step_rng, shape) | ||
mean, std = sde.marginal_prob(data, vec_time) | ||
noisy_data = mean + utils.batch_mul(std, noise) | ||
drift_div = div_drift_fn(noisy_data, vec_time, epsilon) | ||
score_val = score_fn(noisy_data, vec_time) | ||
grad = utils.batch_mul(-(noisy_data - mean), 1 / std ** 2) | ||
diff1 = score_val - grad | ||
diff1 = jnp.square(diff1.reshape((diff1.shape[0], -1))).sum(axis=-1) | ||
diff2 = jnp.square(grad.reshape((grad.shape[0], -1))).sum(axis=-1) | ||
f, g = sde.sde(noisy_data, vec_time) | ||
integrand = utils.batch_mul(g ** 2, diff1 - diff2) - 2 * drift_div | ||
if importance_weighting: | ||
integrand = utils.batch_mul(std ** 2 / g ** 2 * Z, integrand) | ||
return (rng, value + integrand), integrand | ||
|
||
(rng, integral), _ = jax.lax.scan(scan_fn, (rng, jnp.zeros((shape[0],))), time_samples) | ||
integral = integral / N | ||
mean, std = sde.marginal_prob(data, jnp.ones((data.shape[0],)) * sde.T) | ||
rng, step_rng = jax.random.split(rng) | ||
noise = jax.random.normal(step_rng, shape) | ||
neg_prior_logp = -sde.prior_logp(mean + utils.batch_mul(std, noise)) | ||
nlogp = neg_prior_logp + 0.5 * integral | ||
|
||
# whether to enable likelihood offset | ||
if eps_offset: | ||
score_fn = mutils.get_score_fn(sde, model, state.params_ema, state.model_state, train=False, continuous=True) | ||
offset_fn = get_likelihood_offset_fn(sde, score_fn, eps) | ||
rng, step_rng = jax.random.split(rng) | ||
nlogp = nlogp + offset_fn(step_rng, data) | ||
|
||
bpd = nlogp / np.log(2) | ||
dim = np.prod(shape[1:]) | ||
bpd = bpd / dim | ||
|
||
# A hack to convert log-likelihoods to bits/dim | ||
# based on the gradient of the inverse data normalizer. | ||
offset = jnp.log2(jax.grad(inverse_scaler)(0.)) + 8. | ||
bpd += offset | ||
|
||
return bpd, N | ||
|
||
return jax.pmap(likelihood_bound_fn, axis_name='batch') | ||
|
||
|
||
def get_likelihood_offset_fn(sde, score_fn, eps=1e-5): | ||
"""Create a function to compute the unbiased log-likelihood bound of a given data point. | ||
""" | ||
|
||
def likelihood_offset_fn(prng, data): | ||
"""Compute an unbiased estimate to the log-likelihood in bits/dim. | ||
Args: | ||
prng: An array of random states. The list dimension equals the number of devices. | ||
pstate: Replicated training state for running on multiple devices. | ||
data: A JAX array of shape [#devices, batch size, ...]. | ||
Returns: | ||
bpd: A JAX array of shape [#devices, batch size]. The log-likelihoods on `data` in bits/dim. | ||
N: same as input | ||
""" | ||
rng, step_rng = jax.random.split(prng) | ||
shape = data.shape | ||
|
||
eps_vec = jnp.full((shape[0],), eps) | ||
p_mean, p_std = sde.marginal_prob(data, eps_vec) | ||
rng, step_rng = jax.random.split(rng) | ||
noisy_data = p_mean + batch_mul(p_std, jax.random.normal(step_rng, shape)) | ||
score = score_fn(noisy_data, eps_vec) | ||
|
||
alpha, beta = sde.marginal_prob(jnp.ones_like(data), eps_vec) | ||
q_mean = noisy_data / alpha + batch_mul(beta ** 2, score / alpha) | ||
q_std = beta / jnp.mean(alpha, axis=(1, 2, 3)) | ||
|
||
n_dim = np.prod(data.shape[1:]) | ||
p_entropy = n_dim / 2. * (np.log(2 * np.pi) + 2 * jnp.log(p_std) + 1.) | ||
q_recon = n_dim / 2. * (np.log(2 * np.pi) + 2 * jnp.log(q_std)) + batch_mul(0.5 / (q_std ** 2), | ||
jnp.square(data - q_mean).sum( | ||
axis=(1, 2, 3))) | ||
offset = q_recon - p_entropy | ||
return offset | ||
|
||
return likelihood_offset_fn |
Oops, something went wrong.