Commit 79a5ede (root commit, 0 parents), committed by sandesh on Apr 4, 2022.
Showing 78 changed files with 11,445 additions and 0 deletions.
@@ -0,0 +1,133 @@ (new file: likelihood evaluation script)
from dataclasses import dataclass, field
import matplotlib.pyplot as plt
import io
import csv
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib
import importlib
import os
import functools
import itertools
import torch
# from losses import get_optimizer
# from models.ema import ExponentialMovingAverage

import torch.nn as nn
# import tensorflow as tf
# import tensorflow_datasets as tfds
# import tensorflow_gan as tfgan
# import tqdm
# import likelihood
# import controllable_generation
from utils import restore_checkpoint

sns.set(font_scale=2)
sns.set(style="whitegrid")
import models
from models import utils as mutils
# from models import ncsnv2
from models import ncsnpp
# from models import ddpm as ddpm_model
# from models import layerspp
# from models import layers
# from models import normalization
# import sampling
from likelihood import get_likelihood_fn
from sde_lib import VESDE, VPSDE, subVPSDE
from sampling import (ReverseDiffusionPredictor,
                      LangevinCorrector,
                      EulerMaruyamaPredictor,
                      AncestralSamplingPredictor,
                      NoneCorrector,
                      NonePredictor,
                      AnnealedLangevinDynamics)
import datasets
#@title Choose the SDE and the matching config/checkpoint
sde = 'VESDE'  #@param ['VESDE', 'VPSDE', 'subVPSDE'] {"type": "string"}
if sde.lower() == 'vesde':
  from configs.ve import cifar10_ncsnpp_continuous as configs
  ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
  config = configs.get_config()
  sde = VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)
  sampling_eps = 1e-5
elif sde.lower() == 'vpsde':
  from configs.vp import cifar10_ddpmpp_continuous as configs
  ckpt_filename = "exp/vp/cifar10_ddpmpp_continuous/checkpoint_8.pth"
  config = configs.get_config()
  sde = VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
  sampling_eps = 1e-3
elif sde.lower() == 'subvpsde':
  from configs.subvp import cifar10_ddpmpp_continuous as configs
  ckpt_filename = "exp/subvp/cifar10_ddpmpp_continuous/checkpoint_26.pth"
  config = configs.get_config()
  sde = subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
  sampling_eps = 1e-3
batch_size = 64  #@param {"type":"integer"}
config.training.batch_size = batch_size
config.eval.batch_size = batch_size

random_seed = 0  #@param {"type": "integer"}

sigmas = mutils.get_sigmas(config)
scaler = datasets.get_data_scaler(config)
inverse_scaler = datasets.get_data_inverse_scaler(config)
score_model = mutils.create_model(config)
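# NOTE: `ckpt_filename` is set above but never loaded, so the bpd numbers below
# would come from a randomly initialized model. A minimal restore sketch, kept
# commented out like the imports it relies on: it assumes the repo's
# `get_optimizer`/`ExponentialMovingAverage` helpers and `restore_checkpoint`
# utility, and that `config.model.ema_rate` is set by the dataset-specific config.
# optimizer = get_optimizer(config, score_model.parameters())
# ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
# state = dict(step=0, optimizer=optimizer, model=score_model, ema=ema)
# state = restore_checkpoint(ckpt_filename, state, config.device)
# ema.copy_to(score_model.parameters())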
#@title Likelihood computation
train_ds, eval_ds, _ = datasets.get_dataset(config, uniform_dequantization=True, evaluation=True)
train_iter = iter(train_ds)
eval_iter = iter(eval_ds)
likelihood_fn = get_likelihood_fn(sde, inverse_scaler, eps=1e-5)

# bpd = bits/dim, the negative log-likelihood per pixel channel in base 2 (lower is better).
bpds = []
for i in range(5):
  batch = next(train_iter)  # advance the iterator; next(iter(train_ds)) would re-read the first batch every pass
  img = batch['image']._numpy()
  img = torch.tensor(img).permute(0, 3, 1, 2).to(config.device)  # NHWC -> NCHW
  img = scaler(img)
  bpd, z, nfe = likelihood_fn(score_model, img)
  bpds.extend(bpd)
  print(f"average bpd cifar train: {torch.tensor(bpds).mean().item()}, NFE: {nfe}")

bpds = []  # reset so each reported average covers a single split
for i in range(5):
  batch = next(eval_iter)
  img = batch['image']._numpy()
  img = torch.tensor(img).permute(0, 3, 1, 2).to(config.device)
  img = scaler(img)
  bpd, z, nfe = likelihood_fn(score_model, img)
  bpds.extend(bpd)
  print(f"average bpd cifar eval: {torch.tensor(bpds).mean().item()}, NFE: {nfe}")

# The model above was trained on CIFAR-10 and evaluated on CIFAR-10.
# Next, load CelebA and evaluate the same CIFAR-10 model on it (out-of-distribution bpd).
from configs.ve import celeba_ncsnpp as configs_celeba
config_celeba = configs_celeba.get_config()
train_celeba, eval_celeba, _ = datasets.get_dataset(config_celeba, uniform_dequantization=True, evaluation=True)

bpds = []
batch = next(iter(train_celeba))
img = batch['image']._numpy()
img = torch.tensor(img).permute(0, 3, 1, 2).to(config.device)
img = scaler(img)
bpd, z, nfe = likelihood_fn(score_model, img)
bpds.extend(bpd)
print(f"average bpd celeba: {torch.tensor(bpds).mean().item()}, NFE: {nfe}")
# Same out-of-distribution check for LSUN bedrooms, left disabled in this commit:
# from configs.ve import bedroom_ncsnpp_continuous as configs_bedr
# config_bedr = configs_bedr.get_config()
# train_bedr, eval_bedr, _ = datasets.get_dataset(config_bedr, uniform_dequantization=True, evaluation=True)
#
# bedr_iter = iter(train_bedr)
# for r in range(3):
#   batch = next(bedr_iter)
#   img = batch['image']._numpy()
#   img = torch.tensor(img).permute(0, 3, 1, 2).to(config.device)
#   img = scaler(img)
#   bpd, z, nfe = likelihood_fn(score_model, img)
#   bpds.extend(bpd)
#   print(f"average bpd bedr: {torch.tensor(bpds).mean().item()}, NFE: {nfe}")
@@ -0,0 +1,72 @@ (new file: default CELEBA config)
import ml_collections
import torch


def get_default_configs():
  config = ml_collections.ConfigDict()
  # training
  config.training = training = ml_collections.ConfigDict()
  config.training.batch_size = 128
  training.n_iters = 1300001
  training.snapshot_freq = 50000
  training.log_freq = 50
  training.eval_freq = 100
  ## store additional checkpoints for preemption in cloud computing environments
  training.snapshot_freq_for_preemption = 10000
  ## produce samples at each snapshot.
  training.snapshot_sampling = True
  training.likelihood_weighting = False
  training.continuous = True
  training.reduce_mean = False

  # sampling
  config.sampling = sampling = ml_collections.ConfigDict()
  sampling.n_steps_each = 1
  sampling.noise_removal = True
  sampling.probability_flow = False
  sampling.snr = 0.17

  # evaluation
  config.eval = evaluate = ml_collections.ConfigDict()
  evaluate.begin_ckpt = 1
  evaluate.end_ckpt = 26
  evaluate.batch_size = 1024
  evaluate.enable_sampling = True
  evaluate.num_samples = 50000
  evaluate.enable_loss = True
  evaluate.enable_bpd = False
  evaluate.bpd_dataset = 'test'

  # data
  config.data = data = ml_collections.ConfigDict()
  data.dataset = 'CELEBA'
  data.image_size = 64
  data.random_flip = True
  data.uniform_dequantization = False
  data.centered = False
  data.num_channels = 3

  # model
  config.model = model = ml_collections.ConfigDict()
  model.sigma_max = 90.
  model.sigma_min = 0.01
  model.num_scales = 1000
  model.beta_min = 0.1
  model.beta_max = 20.
  model.dropout = 0.1
  model.embedding_type = 'fourier'

  # optimization
  config.optim = optim = ml_collections.ConfigDict()
  optim.weight_decay = 0
  optim.optimizer = 'Adam'
  optim.lr = 2e-4
  optim.beta1 = 0.9
  optim.eps = 1e-8
  optim.warmup = 5000
  optim.grad_clip = 1.

  config.seed = 42
  config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

  return config
@@ -0,0 +1,72 @@ (new file: default CIFAR-10 config)
import ml_collections
import torch


def get_default_configs():
  config = ml_collections.ConfigDict()
  # training
  config.training = training = ml_collections.ConfigDict()
  config.training.batch_size = 128
  training.n_iters = 1300001
  training.snapshot_freq = 5000
  training.log_freq = 50
  training.eval_freq = 100
  ## store additional checkpoints for preemption in cloud computing environments
  training.snapshot_freq_for_preemption = 10000
  ## produce samples at each snapshot.
  training.snapshot_sampling = True
  training.likelihood_weighting = False
  training.continuous = True
  training.reduce_mean = False

  # sampling
  config.sampling = sampling = ml_collections.ConfigDict()
  sampling.n_steps_each = 1
  sampling.noise_removal = True
  sampling.probability_flow = False
  sampling.snr = 0.16

  # evaluation
  config.eval = evaluate = ml_collections.ConfigDict()
  evaluate.begin_ckpt = 10
  evaluate.end_ckpt = 26
  evaluate.batch_size = 1024
  evaluate.enable_sampling = False
  evaluate.num_samples = 50000
  evaluate.enable_loss = True
  evaluate.enable_bpd = False
  evaluate.bpd_dataset = 'test'

  # data
  config.data = data = ml_collections.ConfigDict()
  data.dataset = 'CIFAR10'
  data.image_size = 32
  data.random_flip = True
  data.centered = False
  data.uniform_dequantization = False
  data.num_channels = 3

  # model
  config.model = model = ml_collections.ConfigDict()
  model.sigma_min = 0.01
  model.sigma_max = 50
  model.num_scales = 1000
  model.beta_min = 0.1
  model.beta_max = 20.
  model.dropout = 0.1
  model.embedding_type = 'fourier'

  # optimization
  config.optim = optim = ml_collections.ConfigDict()
  optim.weight_decay = 0
  optim.optimizer = 'Adam'
  optim.lr = 1e-4
  optim.beta1 = 0.9
  optim.eps = 1e-8
  optim.warmup = 5000
  optim.grad_clip = 1.

  config.seed = 42
  config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

  return config
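The main script's `configs.ve.cifar10_ncsnpp_continuous.get_config()` presumably layers SDE- and model-specific settings on top of these defaults. A minimal sketch of that pattern (an assumption: the commit does not show this file, so the module path and fields below are illustrative):

from configs.default_cifar10_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # SDE-specific training flags
  training = config.training
  training.sde = 'vesde'
  training.continuous = True
  # model selection; ema_rate is what the restore sketch in the main script reads
  model = config.model
  model.name = 'ncsnpp'
  model.ema_rate = 0.999
  return config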
@@ -0,0 +1,72 @@ (new file: default LSUN config)
import ml_collections
import torch


def get_default_configs():
  config = ml_collections.ConfigDict()
  # training
  config.training = training = ml_collections.ConfigDict()
  config.training.batch_size = 64
  training.n_iters = 2400001
  training.snapshot_freq = 50000
  training.log_freq = 50
  training.eval_freq = 100
  ## store additional checkpoints for preemption in cloud computing environments
  training.snapshot_freq_for_preemption = 5000
  ## produce samples at each snapshot.
  training.snapshot_sampling = True
  training.likelihood_weighting = False
  training.continuous = True
  training.reduce_mean = False

  # sampling
  config.sampling = sampling = ml_collections.ConfigDict()
  sampling.n_steps_each = 1
  sampling.noise_removal = True
  sampling.probability_flow = False
  sampling.snr = 0.075

  # evaluation
  config.eval = evaluate = ml_collections.ConfigDict()
  evaluate.begin_ckpt = 50
  evaluate.end_ckpt = 96
  evaluate.batch_size = 512
  evaluate.enable_sampling = True
  evaluate.num_samples = 50000
  evaluate.enable_loss = True
  evaluate.enable_bpd = False
  evaluate.bpd_dataset = 'test'

  # data
  config.data = data = ml_collections.ConfigDict()
  data.dataset = 'LSUN'
  data.image_size = 256
  data.random_flip = True
  data.uniform_dequantization = False
  data.centered = False
  data.num_channels = 3

  # model
  config.model = model = ml_collections.ConfigDict()
  model.sigma_max = 378
  model.sigma_min = 0.01
  model.num_scales = 2000
  model.beta_min = 0.1
  model.beta_max = 20.
  model.dropout = 0.
  model.embedding_type = 'fourier'

  # optimization
  config.optim = optim = ml_collections.ConfigDict()
  optim.weight_decay = 0
  optim.optimizer = 'Adam'
  optim.lr = 2e-4
  optim.beta1 = 0.9
  optim.eps = 1e-8
  optim.warmup = 5000
  optim.grad_clip = 1.

  config.seed = 42
  config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

  return config