
Commit 31afa44

Initial Commit for LDM
0 parents · commit 31afa44

27 files changed: +2038 −0 lines changed

.gitignore

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
# Ignore all image files
*.jpg
*.png
*.jpeg

# Ignore pycharm and system files
.DS_Store
*.idea
__pycache__
*.zip

# Ignore dataset files
*.csv
*.json

# Ignore checkpoints
*.pth

# Ignore pickle files
*.pkl

README.md

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
Stable Diffusion Implementation in PyTorch
========

This repository implements Stable Diffusion.
As of now it only implements unconditional latent diffusion models and trains on the MNIST and CelebHQ datasets.
It will soon also have code for conditional LDM.

For the autoencoder, I provide code for a VAE as well as a VQVAE,
but both stages of training use the VQVAE only. One can easily change that to the VAE if needed.

For the diffusion part, as of now it only implements DDPM with a linear schedule.
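A minimal sketch of how a linear DDPM schedule is typically precomputed from the ```beta_start```/```beta_end``` values in the configs (illustrative names, not necessarily the repo's exact code):

```python
import torch

# Precompute the DDPM quantities for a linear beta schedule, using the
# default values from the configs in this repo.
def linear_beta_schedule(beta_start=0.0015, beta_end=0.0195, num_timesteps=1000):
    betas = torch.linspace(beta_start, beta_end, num_timesteps)
    alphas = 1.0 - betas
    alpha_cum_prod = torch.cumprod(alphas, dim=0)
    return betas, alphas, alpha_cum_prod

# Forward process: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * noise
```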
## Stable Diffusion Videos

## Sample Output for Autoencoder on CelebHQ
Image - Top, Reconstructions - Below

## Sample Output for LDM on CelebHQ

## Data preparation
For setting up the MNIST dataset:

Follow - https://github.com/explainingai-code/Pytorch-VAE#data-preparation

For setting up CelebHQ, simply download the images from the official site
and mention the right path in the configuration.

For training on your own dataset:
* Create your own config and have the path point to your images (look at celebhq.yaml for guidance)
* Create your own dataset class, similar to celeb_dataset.py
* Map the dataset name to the right class in the training code, as in the sketch below this list
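A hypothetical sketch of such a mapping (the actual training scripts may organize this differently):

```python
from dataset.celeb_dataset import CelebDataset
from dataset.mnist_dataset import MnistDataset

# Map the config's dataset name to its class; extend this dict
# with an entry for your own dataset class.
DATASET_CLASSES = {
    'mnist': MnistDataset,
    'celebhq': CelebDataset,
}

def get_dataset(name, **kwargs):
    assert name in DATASET_CLASSES, 'Unknown dataset {}'.format(name)
    return DATASET_CLASSES[name](**kwargs)
```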
# Quickstart
* Create a new conda environment with python 3.8, then run the commands below
* ```git clone https://github.com/explainingai-code/StableDiffusion-PyTorch.git```
* ```cd StableDiffusion-PyTorch```
* ```pip install -r requirements.txt```
* Download the lpips weights from https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/weights/v0.1/vgg.pth and put them at ```models/weights/v0.1/vgg.pth```
* For training the autoencoder
    * ```python -m tools.train_vqvae --config config/mnist.yaml``` for training the vqvae
    * ```python -m tools.infer_vqvae --config config/mnist.yaml``` for generating reconstructions
* For training the ldm
    * ```python -m tools.train_ddpm_vqvae --config config/mnist.yaml``` for training the ddpm
    * ```python -m tools.sample_ddpm_vqvae --config config/mnist.yaml``` for generating images
## Configuration
The configs allow you to play with different components of the ddpm and autoencoder training
* ```config/mnist.yaml``` - Small autoencoder and ldm that can even be trained on a CPU
* ```config/celebhq.yaml``` - Configuration used for the celebhq dataset

Relevant configuration parameters

Most parameters are self-explanatory, but below I mention a couple that are specific to this repo.
* ```autoencoder_acc_steps``` : For accumulating gradients when the image size is too large to train with larger batch sizes (see the sketch below this list)
* ```save_latents``` : Enable this to save the latents during autoencoder inference, so that ddpm training is faster
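A minimal sketch of what gradient accumulation looks like in PyTorch (```model```, ```compute_loss```, ```optimizer``` and ```data_loader``` are placeholders, not this repo's exact training loop):

```python
# Losses from acc_steps batches are accumulated before each optimizer
# step, emulating a batch acc_steps times larger.
acc_steps = train_config['autoencoder_acc_steps']
optimizer.zero_grad()
for step, im in enumerate(data_loader):
    loss = compute_loss(model, im)
    (loss / acc_steps).backward()  # scale so gradients match the larger batch
    if (step + 1) % acc_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```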
## Output
Outputs will be saved according to the configuration present in the yaml files.

For every run, a folder named after the ```task_name``` key in the config will be created.

During training of the autoencoder the following outputs will be saved
* Latest autoencoder and discriminator checkpoints in the ```task_name``` directory
* Sample reconstructions in ```task_name/vqvae_autoencoder_samples```

During inference of the autoencoder the following outputs will be saved
* Reconstructions for random images in ```task_name```
* Latents in ```task_name/vqvae_latent_dir_name``` if mentioned in the config

During training of the DDPM, the latest checkpoint is saved in the ```task_name``` directory.
During sampling, sampled image grids for all timesteps are saved in ```task_name/samples/*.png```

config/__init__.py

Whitespace-only changes.

config/celebhq.yaml

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
dataset_params:
  im_path: 'data/celeba_hq_256'
  im_channels: 3
  im_size: 256
  name: 'celebhq'

diffusion_params:
  num_timesteps: 1000
  beta_start: 0.0015
  beta_end: 0.0195

ldm_params:
  down_channels: [256, 384, 512, 768]
  mid_channels: [768, 512]
  down_sample: [True, True, True]
  attn_down: [True, True, True]
  time_emb_dim: 512
  norm_channels: 32
  num_heads: 16
  conv_out_channels: 128
  num_down_layers: 2
  num_mid_layers: 2
  num_up_layers: 2

autoencoder_params:
  z_channels: 3
  codebook_size: 8192
  down_channels: [64, 128, 256, 256]
  mid_channels: [256, 256]
  down_sample: [True, True, True]
  attn_down: [False, False, False]
  norm_channels: 32
  num_heads: 4
  num_down_layers: 2
  num_mid_layers: 2
  num_up_layers: 2

train_params:
  seed: 1111
  task_name: 'celebhq'
  ldm_batch_size: 16
  autoencoder_batch_size: 4
  disc_start: 15000
  disc_weight: 0.5
  codebook_weight: 1
  commitment_beta: 0.2
  perceptual_weight: 1
  kl_weight: 0.000005
  ldm_epochs: 100
  autoencoder_epochs: 20
  num_samples: 1
  num_grid_rows: 1
  ldm_lr: 0.000005
  autoencoder_lr: 0.00001
  autoencoder_acc_steps: 4
  autoencoder_img_save_steps: 64
  save_latents: False
  vae_latent_dir_name: 'vae_latents'
  vqvae_latent_dir_name: 'vqvae_latents'
  ldm_ckpt_name: 'ddpm_ckpt.pth'
  vqvae_autoencoder_ckpt_name: 'vqvae_autoencoder_ckpt.pth'
  vae_autoencoder_ckpt_name: 'vae_autoencoder_ckpt.pth'
  vqvae_discriminator_ckpt_name: 'vqvae_discriminator_ckpt.pth'
  vae_discriminator_ckpt_name: 'vae_discriminator_ckpt.pth'
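The training tools presumably read these files with PyYAML (requirements.txt is not shown here, so this is an assumption); a minimal loading sketch:

```python
import yaml

# Load a config and pull out a couple of sections.
with open('config/celebhq.yaml', 'r') as f:
    config = yaml.safe_load(f)

dataset_config = config['dataset_params']
train_config = config['train_params']
print(dataset_config['im_size'], train_config['ldm_batch_size'])  # 256 16
```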

config/mnist.yaml

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
dataset_params:
  im_path: '/Users/tusharkumar/PycharmProjects/explainingai-repos/StableDiffusion-Pytorch/data/train/images'
  im_channels: 1
  im_size: 28
  name: 'mnist'

diffusion_params:
  num_timesteps: 1000
  beta_start: 0.0015
  beta_end: 0.0195

ldm_params:
  down_channels: [128, 256, 256, 256]
  mid_channels: [256, 256]
  down_sample: [False, False, False]
  attn_down: [True, True, True]
  time_emb_dim: 256
  norm_channels: 32
  num_heads: 16
  conv_out_channels: 128
  num_down_layers: 2
  num_mid_layers: 2
  num_up_layers: 2

autoencoder_params:
  z_channels: 3
  codebook_size: 20
  down_channels: [32, 64, 128]
  mid_channels: [128, 128]
  down_sample: [True, True]
  attn_down: [False, False]
  norm_channels: 32
  num_heads: 16
  num_down_layers: 1
  num_mid_layers: 1
  num_up_layers: 1

train_params:
  seed: 1111
  task_name: 'mnist'
  ldm_batch_size: 64
  autoencoder_batch_size: 64
  disc_start: 1000
  disc_weight: 0.5
  codebook_weight: 1
  commitment_beta: 0.2
  perceptual_weight: 1
  kl_weight: 0.000005
  ldm_epochs: 100
  autoencoder_epochs: 10
  num_samples: 25
  num_grid_rows: 5
  ldm_lr: 0.00001
  autoencoder_lr: 0.0001
  autoencoder_acc_steps: 1
  autoencoder_img_save_steps: 8
  save_latents: False
  vae_latent_dir_name: 'vae_latents'
  vqvae_latent_dir_name: 'vqvae_latents'
  ldm_ckpt_name: 'ddpm_ckpt.pth'
  vqvae_autoencoder_ckpt_name: 'vqvae_autoencoder_ckpt.pth'
  vae_autoencoder_ckpt_name: 'vae_autoencoder_ckpt.pth'
  vqvae_discriminator_ckpt_name: 'vqvae_discriminator_ckpt.pth'
  vae_discriminator_ckpt_name: 'vae_discriminator_ckpt.pth'
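With `down_sample: [True, True]` in `autoencoder_params`, the 28x28 MNIST images are encoded into latents downsampled by a factor of 4; a quick sanity check, assuming each `True` entry halves the spatial resolution (the usual convention):

```python
# autoencoder_params from config/mnist.yaml
im_size = 28
down_sample = [True, True]
z_channels = 3

latent_size = im_size // (2 ** sum(down_sample))  # each True halves H and W
print((z_channels, latent_size, latent_size))     # (3, 7, 7)
```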

dataset/__init__.py

Whitespace-only changes.

dataset/celeb_dataset.py

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
import glob
import os

import torchvision
from PIL import Image
from torch.utils.data.dataset import Dataset

from utils.diffusion_utils import load_latents


class CelebDataset(Dataset):
    r"""
    Celeb dataset will by default resize the images.
    This can be replaced by any other dataset, as long as all the images
    are under one directory.
    """

    def __init__(self, split, im_path, im_size=256, im_channels=3, im_ext='jpg',
                 use_latents=False, latent_path=None):
        self.split = split
        self.im_size = im_size
        self.im_channels = im_channels
        self.im_ext = im_ext
        self.latent_maps = None
        self.use_latents = False
        self.images = self.load_images(im_path)

        # Whether to load images or to load latents
        if use_latents and latent_path is not None:
            latent_maps = load_latents(latent_path)
            if len(latent_maps) == len(self.images):
                self.use_latents = True
                self.latent_maps = latent_maps
                print('Found {} latents'.format(len(self.latent_maps)))
            else:
                print('Latents not found')

    def load_images(self, im_path):
        r"""
        Gets all image file paths from the directory specified.
        """
        assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
        ims = []
        for ext in ('png', 'jpg', 'jpeg'):
            ims += glob.glob(os.path.join(im_path, '*.{}'.format(ext)))
        print('Found {} images'.format(len(ims)))
        return ims

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        if self.use_latents:
            # Latents are keyed by the image path they were generated from
            latent = self.latent_maps[self.images[index]]
            return latent
        else:
            im = Image.open(self.images[index])
            im_tensor = torchvision.transforms.Compose([
                torchvision.transforms.Resize(self.im_size),
                torchvision.transforms.CenterCrop(self.im_size),
                torchvision.transforms.ToTensor(),
            ])(im)
            im.close()

            # Convert input to -1 to 1 range.
            im_tensor = (2 * im_tensor) - 1
            return im_tensor
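A usage sketch wiring the dataset into a DataLoader (paths and batch size are illustrative):

```python
from torch.utils.data import DataLoader

from dataset.celeb_dataset import CelebDataset

# Images come back as (C, H, W) tensors scaled to [-1, 1].
dataset = CelebDataset(split='train', im_path='data/celeba_hq_256',
                       im_size=256, im_channels=3)
loader = DataLoader(dataset, batch_size=4, shuffle=True)
im = next(iter(loader))
print(im.shape)  # torch.Size([4, 3, 256, 256])
```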

dataset/mnist_dataset.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
import glob
import os

import torchvision
from PIL import Image
from torch.utils.data.dataset import Dataset
from tqdm import tqdm

from utils.diffusion_utils import load_latents


class MnistDataset(Dataset):
    r"""
    Nothing special here. Just a simple dataset class for mnist images.
    Created a dataset class rather than using torchvision to allow
    replacement with any other image dataset.
    """

    def __init__(self, split, im_path, im_size, im_channels,
                 use_latents=False, latent_path=None):
        r"""
        Init method for initializing the dataset properties
        :param split: train/test to locate the image files
        :param im_path: root folder of images
        :param im_size: image size
        :param im_channels: number of image channels
        """
        self.split = split
        self.latent_maps = None
        self.use_latents = False
        self.images, self.labels = self.load_images(im_path)
        # Whether to load images or to load latents
        if use_latents and latent_path is not None:
            latent_maps = load_latents(latent_path)
            if len(latent_maps) == len(self.images):
                self.use_latents = True
                self.latent_maps = latent_maps
                print('Found {} latents'.format(len(self.latent_maps)))
            else:
                print('Latents not found')

    def load_images(self, im_path):
        r"""
        Gets all image file paths from the per-class
        sub-directories under the path specified.
        """
        assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
        ims = []
        labels = []
        for d_name in tqdm(os.listdir(im_path)):
            for ext in ('png', 'jpg', 'jpeg'):
                for fname in glob.glob(os.path.join(im_path, d_name, '*.{}'.format(ext))):
                    ims.append(fname)
                    # Labels are unused for unconditional generation
                    # labels.append(int(d_name))
        print('Found {} images for split {}'.format(len(ims), self.split))
        return ims, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        if self.use_latents:
            latent = self.latent_maps[self.images[index]]
            return latent
        else:
            im = Image.open(self.images[index])
            im_tensor = torchvision.transforms.ToTensor()(im)

            # Convert input to -1 to 1 range.
            im_tensor = (2 * im_tensor) - 1
            return im_tensor
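Both dataset classes depend on ```load_latents``` from ```utils/diffusion_utils.py```, which is not part of the files shown here. From how it is used, it must return a dict keyed by image path. A purely hypothetical sketch of a compatible implementation, assuming autoencoder inference saved latents as pickled path-to-tensor dicts:

```python
import glob
import os
import pickle

# Hypothetical sketch of the contract the datasets rely on: the real
# implementation lives in utils/diffusion_utils.py.
def load_latents(latent_path):
    latent_maps = {}
    for fname in glob.glob(os.path.join(latent_path, '*.pkl')):
        with open(fname, 'rb') as f:
            latent_maps.update(pickle.load(f))
    return latent_maps
```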

models/__init__.py

Whitespace-only changes.
