Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
sandesh committed Apr 4, 2022
0 parents commit 79a5ede
Show file tree
Hide file tree
Showing 78 changed files with 11,445 additions and 0 deletions.
133 changes: 133 additions & 0 deletions compute_likelihood.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@

from dataclasses import dataclass, field
import matplotlib.pyplot as plt
import io
import csv
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib
import importlib
import os
import functools
import itertools
import torch
# from losses import get_optimizer
# from models.ema import ExponentialMovingAverage

import torch.nn as nn
import numpy as np
# import tensorflow as tf
# import tensorflow_datasets as tfds
# import tensorflow_gan as tfgan
# import tqdm
import io
# import likelihood
# import controllable_generation
from utils import restore_checkpoint
sns.set(font_scale=2)
sns.set(style="whitegrid")

import models
from models import utils as mutils
# from models import ncsnv2
from models import ncsnpp
# from models import ddpm as ddpm_model
# from models import layerspp
# from models import layers
# from models import normalization
# import sampling
from likelihood import get_likelihood_fn
from sde_lib import VESDE, VPSDE, subVPSDE
from sampling import (ReverseDiffusionPredictor,
LangevinCorrector,
EulerMaruyamaPredictor,
AncestralSamplingPredictor,
NoneCorrector,
NonePredictor,
AnnealedLangevinDynamics)
import datasets



sde = 'VESDE' #@param ['VESDE', 'VPSDE', 'subVPSDE'] {"type": "string"}
if sde.lower() == 'vesde':
from configs.ve import cifar10_ncsnpp_continuous as configs
ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
config = configs.get_config()
sde = VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)
sampling_eps = 1e-5
elif sde.lower() == 'vpsde':
from configs.vp import cifar10_ddpmpp_continuous as configs
ckpt_filename = "exp/vp/cifar10_ddpmpp_continuous/checkpoint_8.pth"
config = configs.get_config()
sde = VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
sampling_eps = 1e-3
elif sde.lower() == 'subvpsde':
from configs.subvp import cifar10_ddpmpp_continuous as configs
ckpt_filename = "exp/subvp/cifar10_ddpmpp_continuous/checkpoint_26.pth"
config = configs.get_config()
sde = subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
sampling_eps = 1e-3

batch_size = 64#@param {"type":"integer"}
config.training.batch_size = batch_size
config.eval.batch_size = batch_size

random_seed = 0 #@param {"type": "integer"}

sigmas = mutils.get_sigmas(config)
scaler = datasets.get_data_scaler(config)
inverse_scaler = datasets.get_data_inverse_scaler(config)
score_model = mutils.create_model(config)

#@title Likelihood computation
train_ds, eval_ds, _ = datasets.get_dataset(config, uniform_dequantization=True, evaluation=True)
eval_iter = iter(eval_ds)
bpds = []
likelihood_fn = get_likelihood_fn(sde, inverse_scaler, eps=1e-5)
for i in range(5):
batch = next(iter(train_ds))
img = batch['image']._numpy()
img = torch.tensor(img).permute(0, 3, 1, 2).to(config.device)
img = scaler(img)
bpd, z, nfe = likelihood_fn(score_model, img)
bpds.extend(bpd)
print(f"average bpd: {torch.tensor(bpds).mean().item()}, NFE: {nfe}")

for i in range(5):
batch = next(iter(eval_ds))
img = batch['image']._numpy()
img = torch.tensor(img).permute(0, 3, 1, 2).to(config.device)
img = scaler(img)
bpd, z, nfe = likelihood_fn(score_model, img)
bpds.extend(bpd)
print(f"average bpd cifar eval: {torch.tensor(bpds).mean().item()}, NFE: {nfe}")
#Previous model was cifar10 and tested on cifar10 also
#Now, we load celeba dataset and test the model trained on cifar10 to celeba

from configs.ve import celeba_ncsnpp as configs_celeba
config_celeba = configs_celeba.get_config()
train_celeba, eval_celeba, _ = datasets.get_dataset(config_celeba, uniform_dequantization=True, evaluation=True)

batch = next(iter(train_celeba))
img = batch['image']._numpy()
img = torch.tensor(img).permute(0, 3, 1, 2).to(config.device)
img = scaler(img)
bpd, z, nfe = likelihood_fn(score_model, img)
bpds.extend(bpd)
print(f"average bpd celeba: {torch.tensor(bpds).mean().item()}, NFE: {nfe}")


# from configs.ve import bedroom_ncsnpp_continuous as configs_bedr
# config_bedr = configs_bedr.get_config()
# train_bedr, eval_bedr, _ = datasets.get_dataset(config_bedr, uniform_dequantization=True, evaluation=True)
#
# for r in range(3):
# batch = next(iter(train_bedr))
# img = batch['image']._numpy()
# img = torch.tensor(img).permute(0, 3, 1, 2).to(config.device)
# img = scaler(img)
# bpd, z, nfe = likelihood_fn(score_model, img)
# bpds.extend(bpd)
# print(f"average bpd bedr: {torch.tensor(bpds).mean().item()}, NFE: {nfe}")
72 changes: 72 additions & 0 deletions configs/default_celeba_configs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import ml_collections
import torch


def get_default_configs():
config = ml_collections.ConfigDict()
# training
config.training = training = ml_collections.ConfigDict()
config.training.batch_size = 128
training.n_iters = 1300001
training.snapshot_freq = 50000
training.log_freq = 50
training.eval_freq = 100
## store additional checkpoints for preemption in cloud computing environments
training.snapshot_freq_for_preemption = 10000
## produce samples at each snapshot.
training.snapshot_sampling = True
training.likelihood_weighting = False
training.continuous = True
training.reduce_mean = False

# sampling
config.sampling = sampling = ml_collections.ConfigDict()
sampling.n_steps_each = 1
sampling.noise_removal = True
sampling.probability_flow = False
sampling.snr = 0.17

# evaluation
config.eval = evaluate = ml_collections.ConfigDict()
evaluate.begin_ckpt = 1
evaluate.end_ckpt = 26
evaluate.batch_size = 1024
evaluate.enable_sampling = True
evaluate.num_samples = 50000
evaluate.enable_loss = True
evaluate.enable_bpd = False
evaluate.bpd_dataset = 'test'

# data
config.data = data = ml_collections.ConfigDict()
data.dataset = 'CELEBA'
data.image_size = 64
data.random_flip = True
data.uniform_dequantization = False
data.centered = False
data.num_channels = 3

# model
config.model = model = ml_collections.ConfigDict()
model.sigma_max = 90.
model.sigma_min = 0.01
model.num_scales = 1000
model.beta_min = 0.1
model.beta_max = 20.
model.dropout = 0.1
model.embedding_type = 'fourier'

# optimization
config.optim = optim = ml_collections.ConfigDict()
optim.weight_decay = 0
optim.optimizer = 'Adam'
optim.lr = 2e-4
optim.beta1 = 0.9
optim.eps = 1e-8
optim.warmup = 5000
optim.grad_clip = 1.

config.seed = 42
config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

return config
72 changes: 72 additions & 0 deletions configs/default_cifar10_configs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import ml_collections
import torch


def get_default_configs():
config = ml_collections.ConfigDict()
# training
config.training = training = ml_collections.ConfigDict()
config.training.batch_size = 128
training.n_iters = 1300001
training.snapshot_freq = 5000
training.log_freq = 50
training.eval_freq = 100
## store additional checkpoints for preemption in cloud computing environments
training.snapshot_freq_for_preemption = 10000
## produce samples at each snapshot.
training.snapshot_sampling = True
training.likelihood_weighting = False
training.continuous = True
training.reduce_mean = False

# sampling
config.sampling = sampling = ml_collections.ConfigDict()
sampling.n_steps_each = 1
sampling.noise_removal = True
sampling.probability_flow = False
sampling.snr = 0.16

# evaluation
config.eval = evaluate = ml_collections.ConfigDict()
evaluate.begin_ckpt = 10
evaluate.end_ckpt = 26
evaluate.batch_size = 1024
evaluate.enable_sampling = False
evaluate.num_samples = 50000
evaluate.enable_loss = True
evaluate.enable_bpd = False
evaluate.bpd_dataset = 'test'

# data
config.data = data = ml_collections.ConfigDict()
data.dataset = 'CIFAR10'
data.image_size = 32
data.random_flip = True
data.centered = False
data.uniform_dequantization = False
data.num_channels = 3

# model
config.model = model = ml_collections.ConfigDict()
model.sigma_min = 0.01
model.sigma_max = 50
model.num_scales = 1000
model.beta_min = 0.1
model.beta_max = 20.
model.dropout = 0.1
model.embedding_type = 'fourier'

# optimization
config.optim = optim = ml_collections.ConfigDict()
optim.weight_decay = 0
optim.optimizer = 'Adam'
optim.lr = 1e-4
optim.beta1 = 0.9
optim.eps = 1e-8
optim.warmup = 5000
optim.grad_clip = 1.

config.seed = 42
config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

return config
72 changes: 72 additions & 0 deletions configs/default_lsun_configs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import ml_collections
import torch


def get_default_configs():
config = ml_collections.ConfigDict()
# training
config.training = training = ml_collections.ConfigDict()
config.training.batch_size = 64
training.n_iters = 2400001
training.snapshot_freq = 50000
training.log_freq = 50
training.eval_freq = 100
## store additional checkpoints for preemption in cloud computing environments
training.snapshot_freq_for_preemption = 5000
## produce samples at each snapshot.
training.snapshot_sampling = True
training.likelihood_weighting = False
training.continuous = True
training.reduce_mean = False

# sampling
config.sampling = sampling = ml_collections.ConfigDict()
sampling.n_steps_each = 1
sampling.noise_removal = True
sampling.probability_flow = False
sampling.snr = 0.075

# evaluation
config.eval = evaluate = ml_collections.ConfigDict()
evaluate.begin_ckpt = 50
evaluate.end_ckpt = 96
evaluate.batch_size = 512
evaluate.enable_sampling = True
evaluate.num_samples = 50000
evaluate.enable_loss = True
evaluate.enable_bpd = False
evaluate.bpd_dataset = 'test'

# data
config.data = data = ml_collections.ConfigDict()
data.dataset = 'LSUN'
data.image_size = 256
data.random_flip = True
data.uniform_dequantization = False
data.centered = False
data.num_channels = 3

# model
config.model = model = ml_collections.ConfigDict()
model.sigma_max = 378
model.sigma_min = 0.01
model.num_scales = 2000
model.beta_min = 0.1
model.beta_max = 20.
model.dropout = 0.
model.embedding_type = 'fourier'

# optimization
config.optim = optim = ml_collections.ConfigDict()
optim.weight_decay = 0
optim.optimizer = 'Adam'
optim.lr = 2e-4
optim.beta1 = 0.9
optim.eps = 1e-8
optim.warmup = 5000
optim.grad_clip = 1.

config.seed = 42
config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

return config
Loading

0 comments on commit 79a5ede

Please sign in to comment.