From 7b6915f04c2517b89d307ac8e3aed9a77d53ea2a Mon Sep 17 00:00:00 2001
From: Andrew Zhao
Date: Mon, 14 Dec 2020 11:10:27 +0800
Subject: [PATCH] added wgangp loss, changed sagan model architecture and training step

---
 networks/layers.py |  8 ++++----
 networks/utils.py  |  3 ++-
 sagan/loss.py      | 26 +++++++++++++++++++-------
 sagan/model.py     |  5 ++++-
 sagan/train.py     | 34 ++++++++++++++++++++--------------
 5 files changed, 49 insertions(+), 27 deletions(-)

diff --git a/networks/layers.py b/networks/layers.py
index f0d6c6e..e149519 100644
--- a/networks/layers.py
+++ b/networks/layers.py
@@ -146,7 +146,7 @@ def predict(self, x):
 class SN_Conv2d(nn.Module):
     def __init__(self, eps=1e-12, **kwargs):
         super().__init__()
-        self.conv = nn.utils.spectral_norm(nn.Conv2d(**kwargs), eps)
+        self.conv = nn.utils.spectral_norm(nn.Conv2d(**kwargs), eps=eps)

     def forward(self, x):
         return self.conv(x)
@@ -155,7 +155,7 @@ def forward(self, x):
 class SN_ConvTranspose2d(nn.Module):
     def __init__(self, eps=1e-12, **kwargs):
         super().__init__()
-        self.conv = nn.utils.spectral_norm(nn.ConvTranspose2d(**kwargs), self.eps)
+        self.conv = nn.utils.spectral_norm(nn.ConvTranspose2d(**kwargs), eps=eps)

     def forward(self, x):
         return self.conv(x)
@@ -164,7 +164,7 @@ def forward(self, x):
 class SN_Linear(nn.Module):
     def __init__(self, eps=1e-12, **kwargs):
         super().__init__()
-        self.fc = nn.utils.spectral_norm(nn.Linear(**kwargs), eps)
+        self.fc = nn.utils.spectral_norm(nn.Linear(**kwargs), eps=eps)

     def forward(self, x):
         return self.fc(x)
@@ -173,7 +173,7 @@ def forward(self, x):
 class SN_Embedding(nn.Module):
     def __init__(self, eps=1e-12, **kwargs):
         super().__init__()
-        self.embed = nn.utils.spectral_norm(nn.Embedding(**kwargs), eps)
+        self.embed = nn.utils.spectral_norm(nn.Embedding(**kwargs), eps=eps)

     def forward(self, x):
         return self.embed(x)
diff --git a/networks/utils.py b/networks/utils.py
index 142b164..07a22fe 100644
--- a/networks/utils.py
+++ b/networks/utils.py
@@ -12,7 +12,8 @@ def initialize_modules(model, nonlinearity='leaky_relu'):
             )
         elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.Linear)):
             nn.init.normal_(m.weight, 0.0, 0.02)
-            nn.init.constant_(m.bias, 0)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)


 def load_weights(state_dict_path, models, model_names, optimizers=[], optimizer_names=[], return_val=None, return_vals=None):
diff --git a/sagan/loss.py b/sagan/loss.py
index 2b02f13..63f2067 100644
--- a/sagan/loss.py
+++ b/sagan/loss.py
@@ -1,5 +1,6 @@
 from torch import nn
 import torch
+from torch.autograd import Variable


 class Hinge_loss(nn.Module):
@@ -29,26 +30,37 @@ def _discriminator_loss(self, real_logits, fake_logits):


 class Wasserstein_GP_Loss(nn.Module):
-    def __init__(self, reduction='mean'):
+    def __init__(self, lambda_gp=10, reduction='mean'):
         super().__init__()
         assert reduction in ('sum', 'mean')
         self.reduction = reduction
+        self.lambda_gp = lambda_gp

     def forward(self, fake_logits, mode, real_logits=None):
-        assert mode in ('generator', 'discriminator', 'gradient penalty')
+        assert mode in ('generator', 'discriminator')
         if mode == 'generator':
             return self._generator_loss(fake_logits)
         elif mode == 'discriminator':
             return self._discriminator_loss(real_logits, fake_logits)
-        else:
-            self._grad_penalty_loss()

     def _generator_loss(self, fake_logits):
         return - fake_logits.mean()

     def _discriminator_loss(self, real_logits, fake_logits):
         return - real_logits.mean() + fake_logits.mean()
+
+    def get_interpolates(self, reals, fakes):
+        alpha = torch.rand(reals.size(0), 1, 1, 1).expand_as(reals).to(reals.device)
+        interpolates = alpha * reals.data + ((1 - alpha) * fakes.data)
+        return Variable(interpolates, requires_grad=True)

-    def _grad_penalty_loss(self):
-        # TODO
-        pass
+    def grad_penalty_loss(self, interpolates, interpolate_logits):
+        gradients = torch.autograd.grad(outputs=interpolate_logits,
+                                        inputs=interpolates,
+                                        grad_outputs=interpolate_logits.new_ones(interpolate_logits.size()),
+                                        create_graph=True,
+                                        retain_graph=True,
+                                        only_inputs=True)[0]
+        gradients = gradients.view(gradients.size(0), -1)
+        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.lambda_gp
+        return gradient_penalty
diff --git a/sagan/model.py b/sagan/model.py
index a951f42..2108399 100644
--- a/sagan/model.py
+++ b/sagan/model.py
@@ -11,7 +11,9 @@ def __init__(self, img_channels, h_dim, img_size):
             SN_Conv2d(in_channels=img_channels, out_channels=h_dim, kernel_size=4, stride=2, padding=1),
             ConvNormAct(h_dim, h_dim*2, 'sn', 'down', activation='lrelu', normalization='bn'),
             ConvNormAct(h_dim*2, h_dim*4, 'sn', 'down', activation='lrelu', normalization='bn'),
+            SA_Conv2d(h_dim*4),
             ConvNormAct(h_dim*4, h_dim*8, 'sn', 'down', activation='lrelu', normalization='bn'),
+            ConvNormAct(h_dim*8, h_dim*8, 'sn', 'down', activation='lrelu', normalization='bn'),
             nn.AdaptiveAvgPool2d(1),
         )
         self.in_features = h_dim*8
@@ -27,12 +29,13 @@ def forward(self, x):
 class Generator(nn.Module):
     def __init__(self, h_dim, z_dim, img_channels, img_size):
         super().__init__()
-        self.min_hw = (img_size // (2 ** 4)) ** 2
+        self.min_hw = (img_size // (2 ** 5)) ** 2
         self.h_dim = h_dim
         self.project = SN_Linear(in_features=z_dim, out_features=h_dim*8 * self.min_hw ** 2, bias=False)
         self.gen = nn.Sequential(
             nn.BatchNorm2d(h_dim*8, momentum=0.9),
             nn.ReLU(),
+            ConvNormAct(h_dim*8, h_dim*8, 'sn', 'up', activation='relu', normalization='bn'),
             ConvNormAct(h_dim*8, h_dim*4, 'sn', 'up', activation='relu', normalization='bn'),
             ConvNormAct(h_dim*4, h_dim*2, 'sn', 'up', activation='relu', normalization='bn'),
             SA_Conv2d(h_dim*2),
diff --git a/sagan/train.py b/sagan/train.py
index ca8b063..c191a1d 100644
--- a/sagan/train.py
+++ b/sagan/train.py
@@ -28,9 +28,10 @@
 parser.add_argument('--download', action="store_true", default=False, help='If auto download CelebA dataset')

 # training parameters
-parser.add_argument('--lr_G', type=float, default=0.0004, help='Learning rate for generator')
+parser.add_argument('--lr_G', type=float, default=0.0001, help='Learning rate for generator')
 parser.add_argument('--lr_D', type=float, default=0.0004, help='Learning rate for discriminator')
 parser.add_argument('--betas', type=tuple, default=(0.0, 0.9), help='Betas for Adam optimizer')
+parser.add_argument('--lambda_gp', type=float, default=10., help='Gradient penalty term')
 parser.add_argument('--n_epochs', type=int, default=50, help='Number of epochs')
 parser.add_argument('--batch_size', type=int, default=256, help='Batch size')
 parser.add_argument('--continue_train', action="store_true", default=False, help='Whether to save samples locally')
@@ -62,7 +63,7 @@ def train():
     D = torch.nn.DataParallel(Discriminator(opt.img_channels, opt.h_dim, opt.img_size), device_ids=opt.devices).to(device)

     if opt.criterion == 'wasserstein-gp':
-        criterion = Wasserstein_GP_Loss()
+        criterion = Wasserstein_GP_Loss(opt.lambda_gp)
     elif opt.criterion == 'hinge':
         criterion = Hinge_loss()
     else:
@@ -97,28 +98,33 @@ def train():
             reals = reals.to(device)
             z = torch.randn(reals.size(0), opt.z_dim).to(device)

-            # forward
+            # forward generator
+            optimizer_G.zero_grad()
             fakes = G(z)
+
+            # compute loss & update gen
+            g_loss = criterion(fake_logits=D(fakes), mode='generator')
+            g_loss.backward()
+            optimizer_G.step()
+
+            # forward discriminator
+            optimizer_D.zero_grad()
             logits_fake = D(fakes.detach())
             logits_real = D(reals)

-            # compute losses
+            # compute loss & update disc
             d_loss = criterion(fake_logits=logits_fake, real_logits=logits_real, mode='discriminator')
+
+            # if wgangp, calculate gradient penalty and add to current d_loss
             if opt.criterion == 'wasserstein-gp':
-                # TODO
-                continue
-            g_loss = criterion(fake_logits=D(fakes), mode='generator')
+                interpolates = criterion.get_interpolates(reals, fakes)
+                interpolated_logits = D(interpolates)
+                grad_penalty = criterion.grad_penalty_loss(interpolates, interpolated_logits)
+                d_loss = d_loss + grad_penalty

-            # update discriminator
-            optimizer_D.zero_grad()
             d_loss.backward()
             optimizer_D.step()

-            # update generator
-            optimizer_G.zero_grad()
-            g_loss.backward()
-            optimizer_G.step()
-
             # logging
             d_losses.append(d_loss.item())
             g_losses.append(g_loss.item())
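
Note for reviewers: below is a minimal, self-contained sketch of how the gradient-penalty path added in sagan/loss.py and sagan/train.py fits together. The tiny convolutional critic, the random reals/fakes tensors, and the image size are illustrative stand-ins, not part of this patch; the interpolation and penalty math mirror get_interpolates and grad_penalty_loss above.

    # WGAN-GP sketch (illustrative only; toy critic, random data, lambda_gp=10).
    import torch
    from torch import nn

    critic = nn.Sequential(
        nn.Conv2d(3, 8, 4, stride=2, padding=1),
        nn.LeakyReLU(0.2),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, 1),
    )

    reals = torch.randn(4, 3, 32, 32)
    fakes = torch.randn(4, 3, 32, 32)
    lambda_gp = 10.0

    # Interpolate between real and fake samples, one alpha per sample.
    alpha = torch.rand(reals.size(0), 1, 1, 1).expand_as(reals)
    interpolates = (alpha * reals + (1 - alpha) * fakes).requires_grad_(True)
    interpolate_logits = critic(interpolates)

    # Gradient of the critic output w.r.t. the interpolated inputs.
    gradients = torch.autograd.grad(
        outputs=interpolate_logits,
        inputs=interpolates,
        grad_outputs=torch.ones_like(interpolate_logits),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)

    # Penalize deviation of the gradient norm from 1, scaled by lambda_gp.
    grad_penalty = lambda_gp * ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    # Wasserstein critic loss plus the penalty, as in the updated training loop.
    d_loss = critic(fakes).mean() - critic(reals).mean() + grad_penalty
    d_loss.backward()

The penalty pushes the critic's gradient norm on the interpolated points toward 1, which is how WGAN-GP enforces the 1-Lipschitz constraint instead of weight clipping.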