
Commit

added wgangp loss, changed sagan model architecture and training step
Andrewzh112 committed Dec 14, 2020
1 parent f5e6312 commit 7b6915f
Showing 5 changed files with 49 additions and 27 deletions.
8 changes: 4 additions & 4 deletions networks/layers.py
@@ -146,7 +146,7 @@ def predict(self, x):
class SN_Conv2d(nn.Module):
def __init__(self, eps=1e-12, **kwargs):
super().__init__()
self.conv = nn.utils.spectral_norm(nn.Conv2d(**kwargs), eps)
self.conv = nn.utils.spectral_norm(nn.Conv2d(**kwargs), eps=eps)

def forward(self, x):
return self.conv(x)
@@ -155,7 +155,7 @@ def forward(self, x):
class SN_ConvTranspose2d(nn.Module):
def __init__(self, eps=1e-12, **kwargs):
super().__init__()
self.conv = nn.utils.spectral_norm(nn.ConvTranspose2d(**kwargs), self.eps)
self.conv = nn.utils.spectral_norm(nn.ConvTranspose2d(**kwargs), eps=eps)

def forward(self, x):
return self.conv(x)
@@ -164,7 +164,7 @@ def forward(self, x):
class SN_Linear(nn.Module):
def __init__(self, eps=1e-12, **kwargs):
super().__init__()
self.fc = nn.utils.spectral_norm(nn.Linear(**kwargs), eps)
self.fc = nn.utils.spectral_norm(nn.Linear(**kwargs), eps=eps)

def forward(self, x):
return self.fc(x)
@@ -173,7 +173,7 @@ def forward(self, x):
class SN_Embedding(nn.Module):
def __init__(self, eps=1e-12, **kwargs):
super().__init__()
self.embed = nn.utils.spectral_norm(nn.Embedding(**kwargs), eps)
self.embed = nn.utils.spectral_norm(nn.Embedding(**kwargs), eps=eps)

def forward(self, x):
        return self.embed(x)
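Note on the four changes above: the second positional parameter of torch.nn.utils.spectral_norm is name (the attribute to normalize), not eps, so the old calls put the tolerance in the name slot (and the SN_ConvTranspose2d variant also referenced self.eps, which was never set); passing eps by keyword is the fix. A minimal sketch of the corrected pattern, with illustrative layer sizes:

import torch.nn as nn

# spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None)
# per the PyTorch docs, so eps must be passed by keyword; a positional float
# would land in the name slot instead.
conv = nn.utils.spectral_norm(nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1), eps=1e-12)
print(type(conv).__name__)  # Conv2d -- the same module is returned, with a spectral-norm hook on its weight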
3 changes: 2 additions & 1 deletion networks/utils.py
@@ -12,7 +12,8 @@ def initialize_modules(model, nonlinearity='leaky_relu'):
)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.Linear)):
nn.init.normal_(m.weight, 0.0, 0.02)
nn.init.constant_(m.bias, 0)
if m.bias is not None:
nn.init.constant_(m.bias, 0)


def load_weights(state_dict_path, models, model_names, optimizers=[], optimizer_names=[], return_val=None, return_vals=None):
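The added guard is needed because some layers in this repo are built without a bias (for example, the generator's SN_Linear projection in sagan/model.py is created with bias=False), so m.bias is None and the unconditional init would raise. A small illustration of the guarded pattern:

import torch.nn as nn

fc = nn.Linear(128, 512, bias=False)  # mirrors a bias=False layer such as the projection in sagan/model.py
print(fc.bias)                        # None
if fc.bias is not None:               # the added guard skips bias-free layers instead of crashing
    nn.init.constant_(fc.bias, 0)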
26 changes: 19 additions & 7 deletions sagan/loss.py
@@ -1,5 +1,6 @@
from torch import nn
import torch
from torch.autograd import Variable


class Hinge_loss(nn.Module):
@@ -29,26 +30,37 @@ def _discriminator_loss(self, real_logits, fake_logits):


class Wasserstein_GP_Loss(nn.Module):
def __init__(self, reduction='mean'):
def __init__(self, lambda_gp=10, reduction='mean'):
super().__init__()
assert reduction in ('sum', 'mean')
self.reduction = reduction
self.lambda_gp = lambda_gp

def forward(self, fake_logits, mode, real_logits=None):
assert mode in ('generator', 'discriminator', 'gradient penalty')
assert mode in ('generator', 'discriminator')
if mode == 'generator':
return self._generator_loss(fake_logits)
elif mode == 'discriminator':
return self._discriminator_loss(real_logits, fake_logits)
else:
self._grad_penalty_loss()

def _generator_loss(self, fake_logits):
return - fake_logits.mean()

    def _discriminator_loss(self, real_logits, fake_logits):
return - real_logits.mean() + fake_logits.mean()

def get_interpolates(self, reals, fakes):
alpha = torch.rand(reals.size(0), 1, 1, 1).expand_as(reals).to(reals.device)
interpolates = alpha * reals.data + ((1 - alpha) * fakes.data)
return Variable(interpolates, requires_grad=True)

def _grad_penalty_loss(self):
# TODO
pass
def grad_penalty_loss(self, interpolates, interpolate_logits):
gradients = torch.autograd.grad(outputs=interpolate_logits,
inputs=interpolates,
grad_outputs=interpolate_logits.new_ones(interpolate_logits.size()),
create_graph=True,
retain_graph=True,
only_inputs=True)[0]
gradients = gradients.view(gradients.size(0), -1)
gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.lambda_gp
return gradient_penalty
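A minimal smoke test for the new gradient-penalty path, sketched with a stand-in linear critic and made-up shapes (not the repo's Discriminator) and assuming the repo root is importable as the sagan package:

import torch
from sagan.loss import Wasserstein_GP_Loss

D = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 64 * 64, 1))  # stand-in critic
criterion = Wasserstein_GP_Loss(lambda_gp=10)
reals, fakes = torch.randn(4, 3, 64, 64), torch.randn(4, 3, 64, 64)

interpolates = criterion.get_interpolates(reals, fakes)          # x_hat = alpha * real + (1 - alpha) * fake
gp = criterion.grad_penalty_loss(interpolates, D(interpolates))  # lambda_gp * mean((||dD/dx_hat||_2 - 1)^2)
gp.backward()                              # create_graph=True above lets the penalty reach D's weights
assert D[1].weight.grad is not None        # the critic's weight received a gradient from the penalty term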
5 changes: 4 additions & 1 deletion sagan/model.py
@@ -11,7 +11,9 @@ def __init__(self, img_channels, h_dim, img_size):
SN_Conv2d(in_channels=img_channels, out_channels=h_dim, kernel_size=4, stride=2, padding=1),
ConvNormAct(h_dim, h_dim*2, 'sn', 'down', activation='lrelu', normalization='bn'),
ConvNormAct(h_dim*2, h_dim*4, 'sn', 'down', activation='lrelu', normalization='bn'),
SA_Conv2d(h_dim*4),
ConvNormAct(h_dim*4, h_dim*8, 'sn', 'down', activation='lrelu', normalization='bn'),
ConvNormAct(h_dim*8, h_dim*8, 'sn', 'down', activation='lrelu', normalization='bn'),
nn.AdaptiveAvgPool2d(1),
)
self.in_features = h_dim*8
@@ -27,12 +29,13 @@ def forward(self, x):
class Generator(nn.Module):
def __init__(self, h_dim, z_dim, img_channels, img_size):
super().__init__()
self.min_hw = (img_size // (2 ** 4)) ** 2
self.min_hw = (img_size // (2 ** 5)) ** 2
self.h_dim = h_dim
self.project = SN_Linear(in_features=z_dim, out_features=h_dim*8 * self.min_hw ** 2, bias=False)
self.gen = nn.Sequential(
nn.BatchNorm2d(h_dim*8, momentum=0.9),
nn.ReLU(),
ConvNormAct(h_dim*8, h_dim*8, 'sn', 'up', activation='relu', normalization='bn'),
ConvNormAct(h_dim*8, h_dim*4, 'sn', 'up', activation='relu', normalization='bn'),
ConvNormAct(h_dim*4, h_dim*2, 'sn', 'up', activation='relu', normalization='bn'),
SA_Conv2d(h_dim*2),
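On the spatial bookkeeping behind the min_hw change: the Discriminator now applies five stride-2 convolutions (the initial SN_Conv2d plus four downsampling ConvNormAct blocks) before pooling, and the Generator gains a matching upsampling stage, so its starting resolution divides img_size by 2**5 instead of 2**4. A quick check, assuming 64x64 inputs (an assumed value of --img_size, not shown in this diff):

img_size = 64                       # assumed; whatever --img_size is used in training
downsamples = 5                     # stride-2 stages in the Discriminator before AdaptiveAvgPool2d
print(img_size // 2 ** downsamples) # 2 -- the side length implied by dividing img_size by 2**5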
34 changes: 20 additions & 14 deletions sagan/train.py
@@ -28,9 +28,10 @@
parser.add_argument('--download', action="store_true", default=False, help='If auto download CelebA dataset')

# training parameters
parser.add_argument('--lr_G', type=float, default=0.0004, help='Learning rate for generator')
parser.add_argument('--lr_G', type=float, default=0.0001, help='Learning rate for generator')
parser.add_argument('--lr_D', type=float, default=0.0004, help='Learning rate for discriminator')
parser.add_argument('--betas', type=tuple, default=(0.0, 0.9), help='Betas for Adam optimizer')
parser.add_argument('--lambda_gp', type=float, default=10., help='Gradient penalty term')
parser.add_argument('--n_epochs', type=int, default=50, help='Number of epochs')
parser.add_argument('--batch_size', type=int, default=256, help='Batch size')
parser.add_argument('--continue_train', action="store_true", default=False, help='Whether to continue training from saved weights')
@@ -62,7 +63,7 @@ def train():
D = torch.nn.DataParallel(Discriminator(opt.img_channels, opt.h_dim, opt.img_size), device_ids=opt.devices).to(device)

if opt.criterion == 'wasserstein-gp':
criterion = Wasserstein_GP_Loss()
criterion = Wasserstein_GP_Loss(opt.lambda_gp)
elif opt.criterion == 'hinge':
criterion = Hinge_loss()
else:
@@ -97,28 +98,33 @@ def train():
reals = reals.to(device)
z = torch.randn(reals.size(0), opt.z_dim).to(device)

# forward
# forward generator
optimizer_G.zero_grad()
fakes = G(z)

# compute loss & update gen
g_loss = criterion(fake_logits=D(fakes), mode='generator')
g_loss.backward()
optimizer_G.step()

# forward discriminator
optimizer_D.zero_grad()
logits_fake = D(fakes.detach())
logits_real = D(reals)

# compute losses
# compute loss & update disc
d_loss = criterion(fake_logits=logits_fake, real_logits=logits_real, mode='discriminator')

# if wgangp, calculate gradient penalty and add to current d_loss
if opt.criterion == 'wasserstein-gp':
# TODO
continue
g_loss = criterion(fake_logits=D(fakes), mode='generator')
interpolates = criterion.get_interpolates(reals, fakes)
interpolated_logits = D(interpolates)
grad_penalty = criterion.grad_penalty_loss(interpolates, interpolated_logits)
d_loss = d_loss + grad_penalty

# update discriminator
optimizer_D.zero_grad()
d_loss.backward()
optimizer_D.step()

# update generator
optimizer_G.zero_grad()
g_loss.backward()
optimizer_G.step()

# logging
d_losses.append(d_loss.item())
g_losses.append(g_loss.item())
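Finally, a compact dry run of the reordered training step with toy stand-ins for G and D (assumed shapes and models, not the repo's SAGAN networks); it only verifies that the generator-first update and the penalty-augmented discriminator loss run end to end:

import torch
from sagan.loss import Wasserstein_GP_Loss

G = torch.nn.Linear(128, 3 * 8 * 8)                                          # stand-in generator
D = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, 1))   # stand-in critic
criterion = Wasserstein_GP_Loss(lambda_gp=10.)
optimizer_G = torch.optim.Adam(G.parameters(), lr=1e-4, betas=(0.0, 0.9))
optimizer_D = torch.optim.Adam(D.parameters(), lr=4e-4, betas=(0.0, 0.9))

reals, z = torch.randn(16, 3, 8, 8), torch.randn(16, 128)

# generator forward + update first, as in the new step
optimizer_G.zero_grad()
fakes = G(z).view(-1, 3, 8, 8)
g_loss = criterion(fake_logits=D(fakes), mode='generator')
g_loss.backward()
optimizer_G.step()

# discriminator update, with the gradient penalty folded into d_loss
optimizer_D.zero_grad()
d_loss = criterion(fake_logits=D(fakes.detach()), real_logits=D(reals), mode='discriminator')
interpolates = criterion.get_interpolates(reals, fakes)
d_loss = d_loss + criterion.grad_penalty_loss(interpolates, D(interpolates))
d_loss.backward()
optimizer_D.step()
print(g_loss.item(), d_loss.item())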
