
Commit 7b6915f

added wgangp loss, changed sagan model architecture and training step
1 parent f5e6312 commit 7b6915f

5 files changed: +49, -27 lines

networks/layers.py

Lines changed: 4 additions & 4 deletions
@@ -146,7 +146,7 @@ def predict(self, x):
 class SN_Conv2d(nn.Module):
     def __init__(self, eps=1e-12, **kwargs):
         super().__init__()
-        self.conv = nn.utils.spectral_norm(nn.Conv2d(**kwargs), eps)
+        self.conv = nn.utils.spectral_norm(nn.Conv2d(**kwargs), eps=eps)

     def forward(self, x):
         return self.conv(x)
@@ -155,7 +155,7 @@ def forward(self, x):
 class SN_ConvTranspose2d(nn.Module):
     def __init__(self, eps=1e-12, **kwargs):
         super().__init__()
-        self.conv = nn.utils.spectral_norm(nn.ConvTranspose2d(**kwargs), self.eps)
+        self.conv = nn.utils.spectral_norm(nn.ConvTranspose2d(**kwargs), eps=eps)

     def forward(self, x):
         return self.conv(x)
@@ -164,7 +164,7 @@ def forward(self, x):
 class SN_Linear(nn.Module):
     def __init__(self, eps=1e-12, **kwargs):
         super().__init__()
-        self.fc = nn.utils.spectral_norm(nn.Linear(**kwargs), eps)
+        self.fc = nn.utils.spectral_norm(nn.Linear(**kwargs), eps=eps)

     def forward(self, x):
         return self.fc(x)
@@ -173,7 +173,7 @@ def forward(self, x):
 class SN_Embedding(nn.Module):
     def __init__(self, eps=1e-12, **kwargs):
         super().__init__()
-        self.embed = nn.utils.spectral_norm(nn.Embedding(**kwargs), eps)
+        self.embed = nn.utils.spectral_norm(nn.Embedding(**kwargs), eps=eps)

     def forward(self, x):
         return self.Embedding(x)
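
The fix in all four wrappers is passing `eps` by keyword: the second positional parameter of `torch.nn.utils.spectral_norm` is `name`, not `eps`, so the old calls bound the epsilon value to the attribute name (and `SN_ConvTranspose2d` additionally referenced a `self.eps` that was never set). A minimal sketch of the corrected wrapper in use; the constructor arguments below are illustrative:

```python
import torch
from torch import nn

class SN_Conv2d(nn.Module):
    """Spectrally normalized Conv2d, mirroring the patched wrapper in networks/layers.py."""
    def __init__(self, eps=1e-12, **kwargs):
        super().__init__()
        # eps must be passed by keyword; positionally it would land in `name`
        self.conv = nn.utils.spectral_norm(nn.Conv2d(**kwargs), eps=eps)

    def forward(self, x):
        return self.conv(x)

# quick shape check with illustrative hyperparameters
layer = SN_Conv2d(in_channels=3, out_channels=64, kernel_size=4, stride=2, padding=1)
print(layer(torch.randn(2, 3, 64, 64)).shape)  # torch.Size([2, 64, 32, 32])
```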

networks/utils.py

Lines changed: 2 additions & 1 deletion
@@ -12,7 +12,8 @@ def initialize_modules(model, nonlinearity='leaky_relu'):
             )
         elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.Linear)):
             nn.init.normal_(m.weight, 0.0, 0.02)
-            nn.init.constant_(m.bias, 0)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)


 def load_weights(state_dict_path, models, model_names, optimizers=[], optimizer_names=[], return_val=None, return_vals=None):
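
The new `if m.bias is not None` guard matters because layers built with `bias=False` (for example the generator's `SN_Linear(..., bias=False)` projection in `sagan/model.py`) expose `bias` as `None`, and `nn.init.constant_(None, 0)` raises. A small sketch of the patched branch applied on its own; the helper name is made up for the demo and only mirrors this part of `initialize_modules`:

```python
from torch import nn

def init_norm_and_linear(m):
    # mirrors the patched branch of initialize_modules (a sketch, not the full function)
    if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.Linear)):
        nn.init.normal_(m.weight, 0.0, 0.02)
        if m.bias is not None:  # layers created with bias=False have m.bias == None
            nn.init.constant_(m.bias, 0)

model = nn.Sequential(nn.Linear(8, 16, bias=False), nn.Linear(16, 4))
model.apply(init_norm_and_linear)  # no longer crashes on the bias-free layer
```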

sagan/loss.py

Lines changed: 19 additions & 7 deletions
@@ -1,5 +1,6 @@
 from torch import nn
 import torch
+from torch.autograd import Variable


 class Hinge_loss(nn.Module):
@@ -29,26 +30,37 @@ def _discriminator_loss(self, real_logits, fake_logits):


 class Wasserstein_GP_Loss(nn.Module):
-    def __init__(self, reduction='mean'):
+    def __init__(self, lambda_gp=10, reduction='mean'):
         super().__init__()
         assert reduction in ('sum', 'mean')
         self.reduction = reduction
+        self.lambda_gp = lambda_gp

     def forward(self, fake_logits, mode, real_logits=None):
-        assert mode in ('generator', 'discriminator', 'gradient penalty')
+        assert mode in ('generator', 'discriminator')
         if mode == 'generator':
             return self._generator_loss(fake_logits)
         elif mode == 'discriminator':
             return self._discriminator_loss(real_logits, fake_logits)
-        else:
-            self._grad_penalty_loss()

     def _generator_loss(self, fake_logits):
         return - fake_logits.mean()

     def __discriminator_loss(self, real_logits, fake_logits):
         return - real_logits.mean() + fake_logits.mean()
+
+    def get_interpolates(self, reals, fakes):
+        alpha = torch.rand(reals.size(0), 1, 1, 1).expand_as(reals).to(reals.device)
+        interpolates = alpha * reals.data + ((1 - alpha) * fakes.data)
+        return Variable(interpolates, requires_grad=True)

-    def _grad_penalty_loss(self):
-        # TODO
-        pass
+    def grad_penalty_loss(self, interpolates, interpolate_logits):
+        gradients = torch.autograd.grad(outputs=interpolate_logits,
+                                        inputs=interpolates,
+                                        grad_outputs=interpolate_logits.new_ones(interpolate_logits.size()),
+                                        create_graph=True,
+                                        retain_graph=True,
+                                        only_inputs=True)[0]
+        gradients = gradients.view(gradients.size(0), -1)
+        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.lambda_gp
+        return gradient_penalty
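
The two new methods implement the gradient penalty from WGAN-GP (Gulrajani et al., 2017): sample x_hat = alpha * x_real + (1 - alpha) * x_fake with alpha ~ U(0, 1) per example, then add lambda_gp * E[(||grad_{x_hat} D(x_hat)||_2 - 1)^2] to the critic loss. (`Variable` is a legacy no-op wrapper in current PyTorch; `requires_grad_(True)` does the same job.) Below is a self-contained reference sketch of the same computation with a toy critic; the critic module and tensor shapes are illustrative, not the repo's:

```python
import torch
from torch import nn

def gradient_penalty(critic, reals, fakes, lambda_gp=10.0):
    # per-sample interpolation factor, broadcast over C, H, W
    alpha = torch.rand(reals.size(0), 1, 1, 1, device=reals.device).expand_as(reals)
    interpolates = (alpha * reals + (1 - alpha) * fakes).requires_grad_(True)
    logits = critic(interpolates)
    grads = torch.autograd.grad(outputs=logits, inputs=interpolates,
                                grad_outputs=torch.ones_like(logits),
                                create_graph=True, retain_graph=True,
                                only_inputs=True)[0]
    grads = grads.view(grads.size(0), -1)
    # penalize deviation of the per-sample gradient norm from 1
    return lambda_gp * ((grads.norm(2, dim=1) - 1) ** 2).mean()

# illustrative usage with a toy critic and random data
critic = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 1))
gp = gradient_penalty(critic, torch.randn(8, 3, 32, 32), torch.randn(8, 3, 32, 32))
gp.backward()  # contributes gradients to the critic's parameters
```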

sagan/model.py

Lines changed: 4 additions & 1 deletion
@@ -11,7 +11,9 @@ def __init__(self, img_channels, h_dim, img_size):
             SN_Conv2d(in_channels=img_channels, out_channels=h_dim, kernel_size=4, stride=2, padding=1),
             ConvNormAct(h_dim, h_dim*2, 'sn', 'down', activation='lrelu', normalization='bn'),
             ConvNormAct(h_dim*2, h_dim*4, 'sn', 'down', activation='lrelu', normalization='bn'),
+            SA_Conv2d(h_dim*4),
             ConvNormAct(h_dim*4, h_dim*8, 'sn', 'down', activation='lrelu', normalization='bn'),
+            ConvNormAct(h_dim*8, h_dim*8, 'sn', 'down', activation='lrelu', normalization='bn'),
             nn.AdaptiveAvgPool2d(1),
         )
         self.in_features = h_dim*8
@@ -27,12 +29,13 @@ def forward(self, x):
 class Generator(nn.Module):
     def __init__(self, h_dim, z_dim, img_channels, img_size):
         super().__init__()
-        self.min_hw = (img_size // (2 ** 4)) ** 2
+        self.min_hw = (img_size // (2 ** 5)) ** 2
         self.h_dim = h_dim
         self.project = SN_Linear(in_features=z_dim, out_features=h_dim*8 * self.min_hw ** 2, bias=False)
         self.gen = nn.Sequential(
             nn.BatchNorm2d(h_dim*8, momentum=0.9),
             nn.ReLU(),
+            ConvNormAct(h_dim*8, h_dim*8, 'sn', 'up', activation='relu', normalization='bn'),
             ConvNormAct(h_dim*8, h_dim*4, 'sn', 'up', activation='relu', normalization='bn'),
             ConvNormAct(h_dim*4, h_dim*2, 'sn', 'up', activation='relu', normalization='bn'),
             SA_Conv2d(h_dim*2),
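
With the added `SA_Conv2d` self-attention block (which keeps spatial size) and the extra downsampling `ConvNormAct`, the discriminator now applies five stride-2 stages, so the generator's starting grid is computed with `2 ** 5` instead of `2 ** 4` and gains a matching extra upsampling block. A quick arithmetic sketch of the spatial bookkeeping, assuming every kernel_size=4, stride=2, padding=1 stage exactly halves (or, transposed, doubles) H and W, that the generator's unchanged tail supplies the remaining upsampling stages, and taking img_size = 128 purely as an illustration:

```python
# Spatial-size bookkeeping for the updated SAGAN (a sketch; img_size = 128 is illustrative).
img_size = 128

# Discriminator: 5 stride-2 stages, each halving H and W; self-attention keeps the size.
d_sizes = [img_size // 2 ** i for i in range(6)]
print(d_sizes)   # [128, 64, 32, 16, 8, 4]; AdaptiveAvgPool2d(1) then collapses the rest

# Generator: project z onto an (img_size // 2 ** 5)-sized grid, then double back up.
start_hw = img_size // 2 ** 5
g_sizes = [start_hw * 2 ** i for i in range(6)]
print(g_sizes)   # [4, 8, 16, 32, 64, 128]
```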

sagan/train.py

Lines changed: 20 additions & 14 deletions
@@ -28,9 +28,10 @@
 parser.add_argument('--download', action="store_true", default=False, help='If auto download CelebA dataset')

 # training parameters
-parser.add_argument('--lr_G', type=float, default=0.0004, help='Learning rate for generator')
+parser.add_argument('--lr_G', type=float, default=0.0001, help='Learning rate for generator')
 parser.add_argument('--lr_D', type=float, default=0.0004, help='Learning rate for discriminator')
 parser.add_argument('--betas', type=tuple, default=(0.0, 0.9), help='Betas for Adam optimizer')
+parser.add_argument('--lambda_gp', type=float, default=10., help='Gradient penalty term')
 parser.add_argument('--n_epochs', type=int, default=50, help='Number of epochs')
 parser.add_argument('--batch_size', type=int, default=256, help='Batch size')
 parser.add_argument('--continue_train', action="store_true", default=False, help='Whether to save samples locally')
@@ -62,7 +63,7 @@ def train():
     D = torch.nn.DataParallel(Discriminator(opt.img_channels, opt.h_dim, opt.img_size), device_ids=opt.devices).to(device)

     if opt.criterion == 'wasserstein-gp':
-        criterion = Wasserstein_GP_Loss()
+        criterion = Wasserstein_GP_Loss(opt.lambda_gp)
     elif opt.criterion == 'hinge':
         criterion = Hinge_loss()
     else:
@@ -97,28 +98,33 @@ def train():
             reals = reals.to(device)
             z = torch.randn(reals.size(0), opt.z_dim).to(device)

-            # forward
+            # forward generator
+            optimizer_G.zero_grad()
             fakes = G(z)
+
+            # compute loss & update gen
+            g_loss = criterion(fake_logits=D(fakes), mode='generator')
+            g_loss.backward()
+            optimizer_G.step()
+
+            # forward discriminator
+            optimizer_D.zero_grad()
             logits_fake = D(fakes.detach())
             logits_real = D(reals)

-            # compute losses
+            # compute loss & update disc
             d_loss = criterion(fake_logits=logits_fake, real_logits=logits_real, mode='discriminator')
+
+            # if wgangp, calculate gradient penalty and add to current d_loss
             if opt.criterion == 'wasserstein-gp':
-                # TODO
-                continue
-            g_loss = criterion(fake_logits=D(fakes), mode='generator')
+                interpolates = criterion.get_interpolates(reals, fakes)
+                interpolated_logits = D(interpolates)
+                grad_penalty = criterion.grad_penalty_loss(interpolates, interpolated_logits)
+                d_loss = d_loss + grad_penalty

-            # update discriminator
-            optimizer_D.zero_grad()
             d_loss.backward()
             optimizer_D.step()

-            # update generator
-            optimizer_G.zero_grad()
-            g_loss.backward()
-            optimizer_G.step()
-
             # logging
             d_losses.append(d_loss.item())
             g_losses.append(g_loss.item())
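
The training step is reorganized so the generator is updated first on freshly generated `fakes`, and the discriminator is then updated on `fakes.detach()` and `reals`, with the gradient penalty added only when the `wasserstein-gp` criterion is selected; together with the lowered `lr_G` (0.0001 vs 0.0004 for the discriminator) this follows the usual two-timescale (TTUR) setup for SAGAN/WGAN-GP training. A condensed sketch of the resulting per-batch step, assuming `G`, `D`, `criterion`, the optimizers, `opt`, and `device` are constructed as in `train()`:

```python
# One training iteration, condensed from the updated sagan/train.py (a sketch, not the full loop).
z = torch.randn(reals.size(0), opt.z_dim).to(device)

# generator step
optimizer_G.zero_grad()
fakes = G(z)
g_loss = criterion(fake_logits=D(fakes), mode='generator')
g_loss.backward()
optimizer_G.step()

# discriminator step; fakes are detached so no gradients flow back into G
optimizer_D.zero_grad()
d_loss = criterion(fake_logits=D(fakes.detach()), real_logits=D(reals), mode='discriminator')
if opt.criterion == 'wasserstein-gp':
    interpolates = criterion.get_interpolates(reals, fakes)
    d_loss = d_loss + criterion.grad_penalty_loss(interpolates, D(interpolates))
d_loss.backward()
optimizer_D.step()
```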
