Skip to content

Commit

Permalink
added loss fn and trainable vae
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrewzh112 committed Dec 2, 2020
1 parent 15558ee commit 40967e1
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 1 deletion.
9 changes: 9 additions & 0 deletions networks/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,12 @@ def __init__(self, in_channels, activation, normalization):

def forward(self, x):
    """Apply the wrapped residual block to ``x`` and return the result."""
    # NOTE(review): the enclosing class's __init__ is outside this diff hunk;
    # presumably self.resblock is an nn.Module/Sequential built there — confirm.
    return self.resblock(x)


class Reshape(nn.Module):
    """Module that views its input as a fixed target shape.

    Wraps ``Tensor.view`` so a reshape step can sit inside an
    ``nn.Sequential`` pipeline. The target shape is given positionally
    at construction time and stored on ``self.shape``.
    """

    def __init__(self, *args):
        super().__init__()
        # Keep the target dimensions exactly as passed (a tuple of ints).
        self.shape = args

    def forward(self, x):
        # view() returns a no-copy reshaped view of the input tensor.
        target_shape = self.shape
        return x.view(target_shape)
19 changes: 19 additions & 0 deletions vae/loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from torch import nn
import torch


class VAELoss(nn.Module):
    """Standard VAE objective: reconstruction loss + KL divergence.

    Args:
        args: namespace whose ``recon`` attribute selects the
            reconstruction term — ``'l1'`` (``nn.L1Loss``) or
            ``'l2'`` (``nn.MSELoss``).

    Raises:
        ValueError: if ``args.recon`` is neither ``'l1'`` nor ``'l2'``.
    """

    def __init__(self, args):
        super().__init__()
        if args.recon == 'l1':
            self.recon = nn.L1Loss()
        elif args.recon == 'l2':
            self.recon = nn.MSELoss()
        else:
            # Fail fast: the original left self.recon unset here, which
            # surfaced later as an opaque AttributeError in forward().
            raise ValueError(f"Unknown reconstruction loss: {args.recon!r}")

    def _KL_Loss(self, mu, logvar):
        # Closed-form KL(q(z|x) || N(0, I)): summed over latent dims,
        # averaged over the batch (expects mu/logvar of shape (B, z_dim)).
        return torch.mean(-0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim=1), dim=0)

    def forward(self, x, x_hat, mu, logvar):
        """Return ``recon(x_hat, x) + KL(mu, logvar)`` as a scalar tensor."""
        reconstruction_loss = self.recon(x_hat, x)
        kl_loss = self._KL_Loss(mu, logvar)
        return reconstruction_loss + kl_loss
73 changes: 73 additions & 0 deletions vae/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import torch
import torchvision
import argparse
import os
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from pathlib import Path

from vae.data import get_loaders
from vae.model import VAE
from vae.loss import VAELoss


# Command-line configuration for VAE training.
parser = argparse.ArgumentParser()
parser.add_argument('--img_channels', type=int, default=3, help='Number of channels for images')
# model_dim / z_dim are used as channel and feature counts (nn.Conv2d /
# nn.Linear sizes), so they must be ints — type=float broke layer construction.
parser.add_argument('--model_dim', type=int, default=64, help='model dimensions multiplier')
parser.add_argument('--z_dim', type=int, default=100, help='dimension of random noise latent vector')
parser.add_argument('--img_size', type=int, default=64, help='H, W of the input images')
parser.add_argument('--crop_size', type=int, default=128, help='H, W of the crop applied to the input images')
parser.add_argument('--n_res_blocks', type=int, default=9, help='Number of ResNet Blocks for generators')
parser.add_argument('--lr', type=float, default=0.0002, help='Learning rate for generators')
# type=tuple would parse a CLI string character-by-character; take two floats
# instead. The default is unchanged, so args.betas still feeds Adam directly.
parser.add_argument('--betas', type=float, nargs=2, default=(0.5, 0.999), help='Betas for Adam optimizer')
parser.add_argument('--epochs', type=int, default=200, help='Number of epochs')
parser.add_argument('--batch_size', type=int, default=256, help='Batch size')
parser.add_argument('--sample_size', type=int, default=32, help='Size of sampled images')
# VAELoss(args) reads args.recon but no such flag existed — add it here.
parser.add_argument('--recon', type=str, default='l2', choices=['l1', 'l2'], help='Reconstruction loss type')
parser.add_argument('--log_dir', type=str, default='vae/logs', help='Path to where log files will be saved')
parser.add_argument('--data_path', type=str, default='data/img_align_celeba', help='Path to where image data is located')
parser.add_argument('--img_ext', type=str, default='.jpg', help='Image extensions')
parser.add_argument('--checkpoint_dir', type=str, default='vae/model_weights', help='Path to where model weights will be saved')
args = parser.parse_args()


if __name__ == '__main__':
    writer = SummaryWriter(args.log_dir)
    # Create the directories actually used for checkpoints and logs.
    # (args.check_point never existed, and split('/')[1] created the wrong
    # directory relative to cwd while saving went to the full path.)
    Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True)
    Path(args.log_dir).mkdir(parents=True, exist_ok=True)

    loader = get_loaders(args)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = VAE(args).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=args.betas)
    # Multiply the learning rate by 0.95 after every epoch.
    scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lambda epoch: 0.95)
    # Fixed latent vector so generated samples are comparable across epochs.
    fixed_z = torch.randn(args.sample_size, args.z_dim).to(device)
    criterion = VAELoss(args)

    # The parser defines --epochs; args.n_epochs was an AttributeError.
    for epoch in tqdm(range(args.epochs)):
        losses = []
        for img in loader:
            x = img.to(device)
            x_hat, mu, logvar = model(x)
            loss = criterion(x, x_hat, mu, logvar)
            losses.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        scheduler.step()

        # Logging & generating imgs from fixed vector.
        writer.add_scalar('Loss', sum(losses) / len(losses), global_step=epoch)
        with torch.no_grad():
            sampled_images = model.sample(fixed_z)
            # Map samples from [-1, 1] back to [0, 1] for visualization.
            sampled_images = ((sampled_images + 1) / 2).view(-1, args.img_channels, args.img_size, args.img_size)
            writer.add_image('Generated Images', torchvision.utils.make_grid(sampled_images), global_step=epoch)
        tqdm.write(
            f'Epoch {epoch + 1}/{args.epochs}, \
            Loss: {sum(losses) / len(losses):.3f}'
        )
        # `opt` was undefined here — the namespace is called args.
        torch.save({
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, f"{args.checkpoint_dir}/VAE.pth")
11 changes: 10 additions & 1 deletion vae/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from torch import nn
import torch
from networks.layers import ConvNormAct, ResBlock
from networks.layers import ConvNormAct, ResBlock, Reshape


class VAE(nn.Module):
Expand All @@ -17,6 +17,15 @@ def __init__(self, args):
nn.Linear((args.img_size // (2**4))**2 * args.model_dim * 8, args.z_dim * 2)
)
self.decoder = nn.Sequential(
nn.Linear(args.z_dim, (args.img_size // (2**4))**2 * args.model_dim * 8),
nn.BatchNorm1d((args.img_size // (2**4))**2 * args.model_dim * 8),
nn.ReLU(),
Reshape(
args.batch_size,
args.model_dim * 8,
args.img_size // (2**4),
args.img_size // (2**4)
),
ConvNormAct(args.model_dim * 8, args.model_dim * 4, 'up'),
ConvNormAct(args.model_dim * 4, args.model_dim * 2, 'up'),
ConvNormAct(args.model_dim * 2, args.model_dim, 'up'),
Expand Down

0 comments on commit 40967e1

Please sign in to comment.