From 32640129267bdaeecf44214303dc375b81064a20 Mon Sep 17 00:00:00 2001 From: Andrew Zhao Date: Sun, 13 Dec 2020 20:54:25 +0800 Subject: [PATCH] updated network w/ spectral norm --- networks/layers.py | 87 +++++++++++++++++++++++++++++++++++++--------- networks/utils.py | 13 ++++--- sagan/loss.py | 28 ++++++++++++++- sagan/model.py | 28 ++++++++------- sagan/train.py | 18 +++++++--- 5 files changed, 133 insertions(+), 41 deletions(-) diff --git a/networks/layers.py b/networks/layers.py index fd778ad..8662a51 100644 --- a/networks/layers.py +++ b/networks/layers.py @@ -6,21 +6,35 @@ class ConvNormAct(nn.Module): - def __init__(self, in_channels, out_channels, mode=None, activation='relu', normalization='bn', kernel_size=None): + def __init__(self, in_channels, out_channels, conv_type='basic', mode=None, activation='relu', normalization='bn', kernel_size=None): super().__init__() - # typical convolution configs + # type of convolution + if conv_type == 'basic' and mode in (None, 'down'): + conv = nn.Conv2d + elif conv_type == 'sn' and mode in (None, 'down'): + conv = SN_Conv2d + elif conv_type == 'basic' and mode == 'up': + conv = nn.ConvTranspose2d + elif conv_type == 'sn' and mode == 'up': + conv = SN_ConvTranspose2d + else: + raise NotImplementedError('Please only choose conv [basic, sn] and mode [None, down, up]') + if mode == 'up': if kernel_size is None: kernel_size = 4 - conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, 2, 1, bias=False) + conv = conv(in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, stride=2, padding=1, bias=False) elif mode == 'down': if kernel_size is None: kernel_size = 4 - conv = nn.Conv2d(in_channels, out_channels, kernel_size, 2, 1, bias=False) + conv = conv(in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, stride=2, padding=1, bias=False) else: if kernel_size is None: kernel_size = 3 - conv = nn.Conv2d(in_channels, 
out_channels, kernel_size, 1, 1, bias=False) + conv = conv(in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, stride=1, padding=1, bias=False) # normalization # TODO GroupNorm @@ -31,7 +45,7 @@ def __init__(self, in_channels, out_channels, mode=None, activation='relu', norm elif normalization == 'in': norm = nn.InstanceNorm2d(out_channels) else: - raise NotImplementedError + raise NotImplementedError('Please only choose normalization [bn, ln, in]') # activations if activation == 'relu': @@ -39,7 +53,7 @@ def __init__(self, in_channels, out_channels, mode=None, activation='relu', norm elif activation == 'lrelu': act = nn.LeakyReLU(0.2) else: - raise NotImplementedError + raise NotImplementedError('Please only choose activation [relu, lrelu]') self.block = nn.Sequential( conv, @@ -61,7 +75,7 @@ def __init__(self, in_channels, activation, normalization): elif normalization == 'in': norm = nn.InstanceNorm2d(in_channels) else: - raise NotImplementedError + raise NotImplementedError('Please only choose normalization [bn, ln, in]') # activations if activation == 'relu': @@ -69,7 +83,7 @@ def __init__(self, in_channels, activation, normalization): elif activation == 'lrelu': act = nn.LeakyReLU(0.2) else: - raise NotImplementedError + raise NotImplementedError('Please only choose activation [relu, lrelu]') self.resblock = nn.Sequential( nn.Conv2d( @@ -129,14 +143,55 @@ def predict(self, x): return torch.argmax(predictions, dim=1) +class SN_Conv2d(nn.Module): + def __init__(self, eps=1e-12, **kwargs): + super().__init__() + self.conv = nn.utils.spectral_norm(nn.Conv2d(**kwargs), eps=eps) + self.eps = eps + + def forward(self, x): + return self.conv(x) + + +class SN_ConvTranspose2d(nn.Module): + def __init__(self, eps=1e-12, **kwargs): + super().__init__() + self.conv = nn.utils.spectral_norm(nn.ConvTranspose2d(**kwargs), eps=eps) + self.eps = eps + + def forward(self, x): + return self.conv(x) + + +class SN_Linear(nn.Module): + def 
__init__(self, eps=1e-12, **kwargs): + super().__init__() + self.fc = nn.utils.spectral_norm(nn.Linear(**kwargs), eps=eps) + self.eps = eps + + def forward(self, x): + return self.fc(x) + + +class SN_Embedding(nn.Module): + def __init__(self, eps=1e-12, **kwargs): + super().__init__() + self.embed = nn.utils.spectral_norm(nn.Embedding(**kwargs), eps=eps) + self.eps = eps + + def forward(self, x): + return self.embed(x) + + class SA_Conv2d(nn.Module): """SAGAN""" - def __init__(self, in_channels, K=8, down_sample=True): + def __init__(self, in_channels, conv=SN_Conv2d, K=8, down_sample=True): super().__init__() - self.f = nn.Conv2d(in_channels, in_channels // K, kernel_size=1) - self.g = nn.Conv2d(in_channels, in_channels // K, kernel_size=1) - self.h = nn.Conv2d(in_channels, in_channels // 2, kernel_size=1) - self.v = nn.Conv2d(in_channels // 2, in_channels, kernel_size=1) + + self.f = conv(in_channels=in_channels, out_channels=in_channels // K, kernel_size=1) + self.g = conv(in_channels=in_channels, out_channels=in_channels // K, kernel_size=1) + self.h = conv(in_channels=in_channels, out_channels=in_channels // 2, kernel_size=1) + self.v = conv(in_channels=in_channels // 2, out_channels=in_channels, kernel_size=1) # adaptive attention weight self.gamma = nn.Parameter(torch.tensor(0., requires_grad=True)) @@ -172,8 +227,8 @@ def forward(self, x): HW_prime = HW_prime // 4 # update (HW)'<-(HW) // 4 g = g.view(B, C // K, HW_prime) # B x (C/K) x (HW)' - h = h.view(B, 4 * C // K, HW_prime) # B x (C/2) x (HW)' + h = h.view(B, C // 2, HW_prime) # B x (C/2) x (HW)' beta = self._dot_product_softmax(f, g) # B x (HW) x (HW)' - s = torch.einsum('ijk,ilk->ijl', h, beta).view(B, 4 * C // K, H, W) # B x (C/2) x H x W + s = torch.einsum('ijk,ilk->ijl', h, beta).view(B, C // 2, H, W) # B x (C/2) x H x W return self.gamma * self.v(s) + x # B x C x H x W diff --git a/networks/utils.py b/networks/utils.py index 4ca1ced..142b164 100644 --- a/networks/utils.py +++ 
b/networks/utils.py @@ -15,13 +15,12 @@ def initialize_modules(model, nonlinearity='leaky_relu'): nn.init.constant_(m.bias, 0) -def put_in_list(item): - if not isinstance(item, list, tuple) and item is not None: - item = [item] - return item - - def load_weights(state_dict_path, models, model_names, optimizers=[], optimizer_names=[], return_val=None, return_vals=None): + def put_in_list(item): + if not isinstance(item, (list, tuple)) and item is not None: + item = [item] + return item + model = put_in_list(models) model_names = put_in_list(model_names) optimizers = put_in_list(optimizers) @@ -37,7 +36,7 @@ def load_weights(state_dict_path, models, model_names, optimizers=[], optimizer_ optimizer.load_state_dict(state_dict[optimizer_name]) if return_val is not None: - return state_dict[key] + return state_dict[return_val] if return_vals is not None: return {key: state_dict[key] for key in return_vals} diff --git a/sagan/loss.py b/sagan/loss.py index fd97694..2b02f13 100644 --- a/sagan/loss.py +++ b/sagan/loss.py @@ -2,7 +2,7 @@ import torch -class SAGAN_Hinge_loss(nn.Module): +class Hinge_loss(nn.Module): def __init__(self, reduction='mean'): super().__init__() assert reduction in ('sum', 'mean') @@ -26,3 +26,29 @@ def _discriminator_loss(self, real_logits, fake_logits): return loss.mean(0) elif self.reduction == 'sum': return loss.sum(0) + + +class Wasserstein_GP_Loss(nn.Module): + def __init__(self, reduction='mean'): + super().__init__() + assert reduction in ('sum', 'mean') + self.reduction = reduction + + def forward(self, fake_logits, mode, real_logits=None): + assert mode in ('generator', 'discriminator', 'gradient penalty') + if mode == 'generator': + return self._generator_loss(fake_logits) + elif mode == 'discriminator': + return self._discriminator_loss(real_logits, fake_logits) + else: + return self._grad_penalty_loss() + + def _generator_loss(self, fake_logits): + return - fake_logits.mean() + + def _discriminator_loss(self, real_logits, fake_logits): + return 
- real_logits.mean() + fake_logits.mean() + + def _grad_penalty_loss(self): + # TODO + pass diff --git a/sagan/model.py b/sagan/model.py index 98f88a9..7de5a79 100644 --- a/sagan/model.py +++ b/sagan/model.py @@ -1,5 +1,6 @@ from torch import nn -from networks.layers import ConvNormAct, SA_Conv2d +from networks.layers import (ConvNormAct, SN_Linear, + SN_Conv2d, SN_ConvTranspose2d, SA_Conv2d) from networks.utils import initialize_modules @@ -7,14 +8,14 @@ class Discriminator(nn.Module): def __init__(self, img_channels, h_dim, img_size): super().__init__() self.disc = nn.Sequential( - nn.Conv2d(img_channels, h_dim, 4, 2, 1), - ConvNormAct(h_dim, h_dim*2, 'down', activation='lrelu', normalization='bn'), - ConvNormAct(h_dim*2, h_dim*4, 'down', activation='lrelu', normalization='bn'), - ConvNormAct(h_dim*4, h_dim*8, 'down', activation='lrelu', normalization='bn'), + SN_Conv2d(in_channels=img_channels, out_channels=h_dim, kernel_size=4, stride=2, padding=1), + ConvNormAct(h_dim, h_dim*2, 'sn', 'down', activation='lrelu', normalization='bn'), + ConvNormAct(h_dim*2, h_dim*4, 'sn', 'down', activation='lrelu', normalization='bn'), + ConvNormAct(h_dim*4, h_dim*8, 'sn', 'down', activation='lrelu', normalization='bn'), nn.AdaptiveAvgPool2d(1), ) self.in_features = h_dim*8 - self.fc = nn.Linear(self.in_features, 1) + self.fc = SN_Linear(in_features=self.in_features, out_features=1) initialize_modules(self) def forward(self, x): @@ -28,16 +29,17 @@ def __init__(self, h_dim, z_dim, img_channels, img_size): super().__init__() self.min_hw = (img_size // (2 ** 5)) ** 2 self.h_dim = h_dim - self.project = nn.Linear(z_dim, h_dim*8 * self.min_hw ** 2) + self.project = SN_Linear(in_features=z_dim, out_features=h_dim*8 * self.min_hw ** 2, bias=False) self.gen = nn.Sequential( nn.BatchNorm2d(h_dim*8, momentum=0.9), nn.ReLU(), - ConvNormAct(h_dim*8, h_dim*4, 'up', activation='relu', normalization='bn'), - ConvNormAct(h_dim*4, h_dim*2, 'up', activation='relu', normalization='bn'), - 
ConvNormAct(h_dim*2, h_dim, 'up', activation='relu', normalization='bn'), - SA_Conv2d(h_dim), - nn.ConvTranspose2d(h_dim, img_channels, 4, 2, 1), - nn.Sigmoid() + ConvNormAct(h_dim*8, h_dim*4, 'sn', 'up', activation='relu', normalization='bn'), + ConvNormAct(h_dim*4, h_dim*2, 'sn', 'up', activation='relu', normalization='bn'), + SA_Conv2d(h_dim*2), + ConvNormAct(h_dim*2, h_dim, 'sn', 'up', activation='relu', normalization='bn'), + SN_ConvTranspose2d(in_channels=h_dim, out_channels=img_channels, kernel_size=4, + stride=2, padding=1), + nn.Tanh() ) initialize_modules(self) diff --git a/sagan/train.py b/sagan/train.py index 0c54e0d..ca8b063 100644 --- a/sagan/train.py +++ b/sagan/train.py @@ -11,7 +11,7 @@ from dcgan.data import get_loaders from networks.utils import load_weights from sagan.model import Generator, Discriminator -from sagan.loss import SAGAN_Hinge_loss +from sagan.loss import Hinge_loss, Wasserstein_GP_Loss parser = argparse.ArgumentParser() @@ -30,11 +30,12 @@ # training parameters parser.add_argument('--lr_G', type=float, default=0.0004, help='Learning rate for generator') parser.add_argument('--lr_D', type=float, default=0.0004, help='Learning rate for discriminator') -parser.add_argument('--betas', type=tuple, default=(0.0, 0.99), help='Betas for Adam optimizer') +parser.add_argument('--betas', type=tuple, default=(0.0, 0.9), help='Betas for Adam optimizer') parser.add_argument('--n_epochs', type=int, default=50, help='Number of epochs') parser.add_argument('--batch_size', type=int, default=256, help='Batch size') parser.add_argument('--continue_train', action="store_true", default=False, help='Whether to save samples locally') parser.add_argument('--devices', type=list, default=[0, 1], help='List of training devices') +parser.add_argument('--criterion', type=str, default='hinge', help='Loss function [hinge, wasserstein-gp]') # logging parameters parser.add_argument('--data_path', type=str, default='data/img_align_celeba', help='Path to where image 
data is located') @@ -60,7 +61,12 @@ def train(): G = torch.nn.DataParallel(Generator(opt.h_dim, opt.z_dim, opt.img_channels, opt.img_size), device_ids=opt.devices).to(device) D = torch.nn.DataParallel(Discriminator(opt.img_channels, opt.h_dim, opt.img_size), device_ids=opt.devices).to(device) - criterion = SAGAN_Hinge_loss() + if opt.criterion == 'wasserstein-gp': + criterion = Wasserstein_GP_Loss() + elif opt.criterion == 'hinge': + criterion = Hinge_loss() + else: + raise NotImplementedError('Please choose criterion [hinge, wasserstein-gp]') optimizer_G = torch.optim.Adam(G.parameters(), lr=opt.lr_G, betas=opt.betas) optimizer_D = torch.optim.Adam(D.parameters(), lr=opt.lr_D, betas=opt.betas) @@ -98,6 +104,9 @@ def train(): # compute losses d_loss = criterion(fake_logits=logits_fake, real_logits=logits_real, mode='discriminator') + if opt.criterion == 'wasserstein-gp': + # TODO + continue g_loss = criterion(fake_logits=D(fakes), mode='generator') # update discriminator @@ -125,10 +134,11 @@ def train(): # generate image from fixed noise vector with torch.no_grad(): samples = G(fixed_z) + samples = (samples + 1) / 2 # save locally if opt.save_local_samples: - torchvision.utils.save_image(samples, f'{opt.sample_dir}/Interval_{ckpt_iter}.png') + torchvision.utils.save_image(samples, f'{opt.sample_dir}/Interval_{ckpt_iter}.{opt.img_ext}') # save sample and loss to tensorboard writer.add_image('Generated Images', torchvision.utils.make_grid(samples), global_step=ckpt_iter)