diff --git a/networks/layers.py b/networks/layers.py
index 9990fdc..c022d52 100644
--- a/networks/layers.py
+++ b/networks/layers.py
@@ -81,3 +81,12 @@ def __init__(self, in_channels, activation, normalization):
 
     def forward(self, x):
         return self.resblock(x)
+
+
+class Reshape(nn.Module):
+    def __init__(self, *args):
+        super().__init__()
+        self.shape = args
+
+    def forward(self, x):
+        return x.view(self.shape)
diff --git a/vae/loss.py b/vae/loss.py
new file mode 100644
index 0000000..beb1930
--- /dev/null
+++ b/vae/loss.py
@@ -0,0 +1,19 @@
+from torch import nn
+import torch
+
+
+class VAELoss(nn.Module):
+    def __init__(self, args):
+        super().__init__()
+        if args.recon == 'l1':
+            self.recon = nn.L1Loss()
+        elif args.recon == 'l2':
+            self.recon = nn.MSELoss()
+
+    def _KL_Loss(self, mu, logvar):
+        return torch.mean(-0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim=1), dim=0)
+
+    def forward(self, x, x_hat, mu, logvar):
+        reconstruction_loss = self.recon(x_hat, x)
+        kl_loss = self._KL_Loss(mu, logvar)
+        return reconstruction_loss + kl_loss
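A quick sanity check of the new loss, outside the patch (a minimal sketch, not part of the diff; the argparse.Namespace below merely stands in for the real args object, and the 100-dimensional latent mirrors the z_dim default). With mu = 0 and logvar = 0 the KL term is exactly zero, so the printed value is just the L1 reconstruction error:

import torch
from argparse import Namespace
from vae.loss import VAELoss

criterion = VAELoss(Namespace(recon='l1'))
x = torch.rand(4, 3, 64, 64)                  # fake input batch
x_hat = torch.rand(4, 3, 64, 64)              # fake reconstruction
mu, logvar = torch.zeros(4, 100), torch.zeros(4, 100)
loss = criterion(x, x_hat, mu, logvar)        # reconstruction + KL
print(loss.item())                            # KL contributes 0 here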
diff --git a/vae/main.py b/vae/main.py
index e69de29..90ed48c 100644
--- a/vae/main.py
+++ b/vae/main.py
@@ -0,0 +1,73 @@
+import torch
+import torchvision
+import argparse
+from torch.utils.tensorboard import SummaryWriter
+from tqdm import tqdm
+from pathlib import Path
+
+from vae.data import get_loaders
+from vae.model import VAE
+from vae.loss import VAELoss
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--img_channels', type=int, default=3, help='Number of channels for images')
+parser.add_argument('--model_dim', type=int, default=64, help='Model dimensions multiplier')
+parser.add_argument('--z_dim', type=int, default=100, help='Dimension of the latent vector')
+parser.add_argument('--img_size', type=int, default=64, help='H, W of the input images')
+parser.add_argument('--crop_size', type=int, default=128, help='H, W of the crop applied to input images')
+parser.add_argument('--n_res_blocks', type=int, default=9, help='Number of ResNet blocks')
+parser.add_argument('--recon', type=str, default='l1', choices=['l1', 'l2'], help='Reconstruction loss type')
+parser.add_argument('--lr', type=float, default=0.0002, help='Learning rate')
+parser.add_argument('--betas', type=tuple, default=(0.5, 0.999), help='Betas for Adam optimizer')
+parser.add_argument('--epochs', type=int, default=200, help='Number of epochs')
+parser.add_argument('--batch_size', type=int, default=256, help='Batch size')
+parser.add_argument('--sample_size', type=int, default=32, help='Number of images to sample for logging')
+parser.add_argument('--log_dir', type=str, default='vae/logs', help='Path to where log files will be saved')
+parser.add_argument('--data_path', type=str, default='data/img_align_celeba', help='Path to where image data is located')
+parser.add_argument('--img_ext', type=str, default='.jpg', help='Image extensions')
+parser.add_argument('--checkpoint_dir', type=str, default='vae/model_weights', help='Path to where model weights will be saved')
+args = parser.parse_args()
+
+
+if __name__ == '__main__':
+    writer = SummaryWriter(args.log_dir)
+    Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True)
+    Path(args.log_dir).mkdir(parents=True, exist_ok=True)
+
+    loader = get_loaders(args)
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+    model = VAE(args).to(device)
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=args.betas)
+    scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lambda epoch: 0.95)
+    fixed_z = torch.randn(args.sample_size, args.z_dim).to(device)
+    criterion = VAELoss(args)
+
+    for epoch in tqdm(range(args.epochs)):
+        losses = []
+        for img in loader:
+            x = img.to(device)
+            x_hat, mu, logvar = model(x)
+            loss = criterion(x, x_hat, mu, logvar)
+            losses.append(loss.item())
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+        scheduler.step()
+
+        # logging & generating imgs from fixed vector
+        writer.add_scalar('Loss', sum(losses) / len(losses), global_step=epoch)
+        with torch.no_grad():
+            sampled_images = model.sample(fixed_z)
+            sampled_images = ((sampled_images + 1) / 2).view(-1, args.img_channels, args.img_size, args.img_size)
+            writer.add_image('Generated Images', torchvision.utils.make_grid(sampled_images), global_step=epoch)
+        tqdm.write(
+            f'Epoch {epoch + 1}/{args.epochs}, '
+            f'Loss: {sum(losses) / len(losses):.3f}'
+        )
+        torch.save({
+            'model': model.state_dict(),
+            'optimizer': optimizer.state_dict(),
+        }, f"{args.checkpoint_dir}/VAE.pth")
diff --git a/vae/model.py b/vae/model.py
index 5c77aba..e913fb4 100644
--- a/vae/model.py
+++ b/vae/model.py
@@ -1,6 +1,6 @@
 from torch import nn
 import torch
-from networks.layers import ConvNormAct, ResBlock
+from networks.layers import ConvNormAct, ResBlock, Reshape
 
 
 class VAE(nn.Module):
@@ -17,6 +17,15 @@ def __init__(self, args):
             nn.Linear((args.img_size // (2**4))**2 * args.model_dim * 8, args.z_dim * 2)
         )
         self.decoder = nn.Sequential(
+            nn.Linear(args.z_dim, (args.img_size // (2**4))**2 * args.model_dim * 8),
+            nn.BatchNorm1d((args.img_size // (2**4))**2 * args.model_dim * 8),
+            nn.ReLU(),
+            Reshape(
+                -1,  # infer the batch dimension; the last batch may be smaller than args.batch_size
+                args.model_dim * 8,
+                args.img_size // (2**4),
+                args.img_size // (2**4)
+            ),
             ConvNormAct(args.model_dim * 8, args.model_dim * 4, 'up'),
             ConvNormAct(args.model_dim * 4, args.model_dim * 2, 'up'),
             ConvNormAct(args.model_dim * 2, args.model_dim, 'up'),
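For reference, a minimal standalone sketch of the decoder head the patch builds (not part of the diff; it assumes the defaults img_size=64, model_dim=64, z_dim=100, so the Linear expands z to (64 // 2**4)**2 * 512 = 8192 features and Reshape turns that back into a 512x4x4 feature map for the upsampling ConvNormAct stack; Reshape is redefined locally so the snippet runs on its own):

import torch
from torch import nn


class Reshape(nn.Module):
    """Same behaviour as the module added to networks/layers.py."""
    def __init__(self, *args):
        super().__init__()
        self.shape = args

    def forward(self, x):
        return x.view(self.shape)


head = nn.Sequential(
    nn.Linear(100, 4 * 4 * 512),   # z_dim -> (img_size // 2**4)**2 * model_dim * 8
    nn.BatchNorm1d(4 * 4 * 512),
    nn.ReLU(),
    Reshape(-1, 512, 4, 4),        # flat features -> feature map for the 'up' ConvNormAct blocks
)

z = torch.randn(8, 100)
print(head(z).shape)               # torch.Size([8, 512, 4, 4])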