Commit

updated network w/ spectral norm
Andrewzh112 committed Dec 13, 2020
1 parent 1c11aaa commit 3264012
Showing 5 changed files with 133 additions and 41 deletions.
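For context: spectral normalization (Miyato et al., 2018) divides a layer's weight by an estimate of its largest singular value, keeping the layer roughly 1-Lipschitz and stabilizing GAN training. PyTorch ships this as torch.nn.utils.spectral_norm, which the new SN_* wrapper layers in this commit build on. A minimal sketch of the utility on its own (illustrative, not code from this commit):

import torch
from torch import nn

# wrap an existing module; its weight is re-normalized by the estimated
# largest singular value before every forward pass
conv = nn.utils.spectral_norm(nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1))

x = torch.randn(8, 3, 64, 64)
y = conv(x)        # -> torch.Size([8, 64, 32, 32])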
87 changes: 71 additions & 16 deletions networks/layers.py
@@ -6,21 +6,35 @@


class ConvNormAct(nn.Module):
def __init__(self, in_channels, out_channels, mode=None, activation='relu', normalization='bn', kernel_size=None):
def __init__(self, in_channels, out_channels, conv_type='basic', mode=None, activation='relu', normalization='bn', kernel_size=None):
super().__init__()
# typical convolution configs
# type of convolution
        if conv_type == 'basic' and mode in (None, 'down'):
            conv = nn.Conv2d
        elif conv_type == 'sn' and mode in (None, 'down'):
            conv = SN_Conv2d
        elif conv_type == 'basic' and mode == 'up':
            conv = nn.ConvTranspose2d
        elif conv_type == 'sn' and mode == 'up':
            conv = SN_ConvTranspose2d
        else:
            raise NotImplementedError('Please only choose conv_type in [basic, sn] and mode in [None, down, up]')

if mode == 'up':
if kernel_size is None:
kernel_size = 4
conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, 2, 1, bias=False)
conv = conv(in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, stride=2, padding=1, bias=False)
elif mode == 'down':
if kernel_size is None:
kernel_size = 4
conv = nn.Conv2d(in_channels, out_channels, kernel_size, 2, 1, bias=False)
conv = conv(in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, stride=2, padding=1, bias=False)
else:
if kernel_size is None:
kernel_size = 3
conv = nn.Conv2d(in_channels, out_channels, kernel_size, 1, 1, bias=False)
conv = conv(in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, stride=1, padding=1, bias=False)

# normalization
# TODO GroupNorm
@@ -31,15 +45,15 @@ def __init__(self, in_channels, out_channels, mode=None, activation='relu', norm
elif normalization == 'in':
norm = nn.InstanceNorm2d(out_channels)
else:
raise NotImplementedError
raise NotImplementedError('Please only choose normalization [bn, ln, in]')

# activations
if activation == 'relu':
act = nn.ReLU()
elif activation == 'lrelu':
act = nn.LeakyReLU(0.2)
else:
raise NotImplementedError
raise NotImplementedError('Please only choose activation [relu, lrelu]')

self.block = nn.Sequential(
conv,
@@ -61,15 +75,15 @@ def __init__(self, in_channels, activation, normalization):
elif normalization == 'in':
norm = nn.InstanceNorm2d(in_channels)
else:
raise NotImplementedError
raise NotImplementedError('Please only choose normalization [bn, ln, in]')

# activations
if activation == 'relu':
act = nn.ReLU()
elif activation == 'lrelu':
act = nn.LeakyReLU(0.2)
else:
raise NotImplementedError
raise NotImplementedError('Please only choose activation [relu, lrelu]')

self.resblock = nn.Sequential(
nn.Conv2d(
@@ -129,14 +143,55 @@ def predict(self, x):
return torch.argmax(predictions, dim=1)


class SN_Conv2d(nn.Module):
    """nn.Conv2d wrapped with spectral normalization."""
    def __init__(self, eps=1e-12, **kwargs):
        super().__init__()
        # spectral_norm wraps the module once at construction time, so its weight
        # is re-normalized by its largest singular value on every forward pass
        self.conv = nn.utils.spectral_norm(nn.Conv2d(**kwargs), eps=eps)

    def forward(self, x):
        return self.conv(x)


class SN_ConvTranspose2d(nn.Module):
    """nn.ConvTranspose2d wrapped with spectral normalization."""
    def __init__(self, eps=1e-12, **kwargs):
        super().__init__()
        self.conv = nn.utils.spectral_norm(nn.ConvTranspose2d(**kwargs), eps=eps)

    def forward(self, x):
        return self.conv(x)


class SN_Linear(nn.Module):
    """nn.Linear wrapped with spectral normalization."""
    def __init__(self, eps=1e-12, **kwargs):
        super().__init__()
        self.fc = nn.utils.spectral_norm(nn.Linear(**kwargs), eps=eps)

    def forward(self, x):
        return self.fc(x)


class SN_Embedding(nn.Module):
    """nn.Embedding wrapped with spectral normalization."""
    def __init__(self, eps=1e-12, **kwargs):
        super().__init__()
        self.embed = nn.utils.spectral_norm(nn.Embedding(**kwargs), eps=eps)

    def forward(self, x):
        return self.embed(x)


class SA_Conv2d(nn.Module):
"""SAGAN"""
def __init__(self, in_channels, K=8, down_sample=True):
def __init__(self, in_channels, conv=SN_Conv2d, K=8, down_sample=True):
super().__init__()
self.f = nn.Conv2d(in_channels, in_channels // K, kernel_size=1)
self.g = nn.Conv2d(in_channels, in_channels // K, kernel_size=1)
self.h = nn.Conv2d(in_channels, in_channels // 2, kernel_size=1)
self.v = nn.Conv2d(in_channels // 2, in_channels, kernel_size=1)

self.f = conv(in_channels=in_channels, out_channels=in_channels // K, kernel_size=1)
self.g = conv(in_channels=in_channels, out_channels=in_channels // K, kernel_size=1)
self.h = conv(in_channels=in_channels, out_channels=in_channels // 2, kernel_size=1)
self.v = conv(in_channels=in_channels // 2, out_channels=in_channels, kernel_size=1)

# adaptive attention weight
self.gamma = nn.Parameter(torch.tensor(0., requires_grad=True))
@@ -172,8 +227,8 @@ def forward(self, x):
HW_prime = HW_prime // 4 # update (HW)'<-(HW) // 4

g = g.view(B, C // K, HW_prime) # B x (C/K) x (HW)'
h = h.view(B, 4 * C // K, HW_prime) # B x (C/2) x (HW)'
h = h.view(B, C // 2, HW_prime) # B x (C/2) x (HW)'

beta = self._dot_product_softmax(f, g) # B x (HW) x (HW)'
s = torch.einsum('ijk,ilk->ijl', h, beta).view(B, 4 * C // K, H, W) # B x (C/2) x H x W
s = torch.einsum('ijk,ilk->ijl', h, beta).view(B, C // 2, H, W) # B x (C/2) x H x W
return self.gamma * self.v(s) + x # B x C x H x W
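A usage sketch for the new SN_* wrappers above (illustrative, not part of the commit): they are meant to be called with keyword arguments only, since a positional first argument would bind to eps rather than to the wrapped module's parameters.

import torch
from networks.layers import SN_Conv2d, SN_Linear

sn_conv = SN_Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False)
sn_fc = SN_Linear(in_features=128, out_features=1)

x = torch.randn(4, 64, 32, 32)
feats = sn_conv(x).mean(dim=(2, 3))  # global average pool -> (4, 128)
logits = sn_fc(feats)                # -> (4, 1)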
13 changes: 6 additions & 7 deletions networks/utils.py
@@ -15,13 +15,12 @@ def initialize_modules(model, nonlinearity='leaky_relu'):
nn.init.constant_(m.bias, 0)


def put_in_list(item):
    if not isinstance(item, (list, tuple)) and item is not None:
item = [item]
return item


def load_weights(state_dict_path, models, model_names, optimizers=[], optimizer_names=[], return_val=None, return_vals=None):
def put_in_list(item):
        if not isinstance(item, (list, tuple)) and item is not None:
item = [item]
return item

    models = put_in_list(models)
model_names = put_in_list(model_names)
optimizers = put_in_list(optimizers)
@@ -37,7 +36,7 @@ def load_weights(state_dict_path, models, model_names, optimizers=[], optimizer_
optimizer.load_state_dict(state_dict[optimizer_name])

if return_val is not None:
return state_dict[key]
return state_dict[return_val]

if return_vals is not None:
return {key: state_dict[key] for key in return_vals}
28 changes: 27 additions & 1 deletion sagan/loss.py
@@ -2,7 +2,7 @@
import torch


class SAGAN_Hinge_loss(nn.Module):
class Hinge_loss(nn.Module):
def __init__(self, reduction='mean'):
super().__init__()
assert reduction in ('sum', 'mean')
@@ -26,3 +26,29 @@ def _discriminator_loss(self, real_logits, fake_logits):
return loss.mean(0)
elif self.reduction == 'sum':
return loss.sum(0)
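# For reference, the standard hinge GAN objective used by SAGAN, which the
# collapsed body of Hinge_loss above presumably follows; hinge_d_loss and
# hinge_g_loss below are hypothetical helpers, not part of this file.
import torch.nn.functional as F

def hinge_d_loss(real_logits, fake_logits):
    # discriminator: push real logits above +1 and fake logits below -1
    return F.relu(1. - real_logits).mean() + F.relu(1. + fake_logits).mean()

def hinge_g_loss(fake_logits):
    # generator: raise the discriminator's score on generated samples
    return -fake_logits.mean()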


class Wasserstein_GP_Loss(nn.Module):
def __init__(self, reduction='mean'):
super().__init__()
assert reduction in ('sum', 'mean')
self.reduction = reduction

def forward(self, fake_logits, mode, real_logits=None):
assert mode in ('generator', 'discriminator', 'gradient penalty')
if mode == 'generator':
return self._generator_loss(fake_logits)
elif mode == 'discriminator':
return self._discriminator_loss(real_logits, fake_logits)
        else:
            return self._grad_penalty_loss()

def _generator_loss(self, fake_logits):
return - fake_logits.mean()

    def _discriminator_loss(self, real_logits, fake_logits):
return - real_logits.mean() + fake_logits.mean()

def _grad_penalty_loss(self):
# TODO
pass
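The gradient-penalty term above is left as a TODO. For reference, the standard WGAN-GP penalty (Gulrajani et al., 2017) penalizes the critic's gradient norm on random interpolates between real and fake samples; a self-contained sketch of that formulation, not the author's implementation:

import torch

def gradient_penalty(discriminator, reals, fakes):
    # sample points on straight lines between real and fake images
    alpha = torch.rand(reals.size(0), 1, 1, 1, device=reals.device)
    interpolates = (alpha * reals + (1 - alpha) * fakes).requires_grad_(True)
    d_interpolates = discriminator(interpolates)

    # gradient of the critic's output w.r.t. the interpolated inputs
    grads = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]

    # penalize deviation of the per-sample gradient norm from 1
    grads = grads.view(grads.size(0), -1)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()

In the full objective this term is added to the critic loss with a weight, typically lambda = 10.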
28 changes: 15 additions & 13 deletions sagan/model.py
@@ -1,20 +1,21 @@
from torch import nn
from networks.layers import ConvNormAct, SA_Conv2d
from networks.layers import (ConvNormAct, SN_Linear,
SN_Conv2d, SN_ConvTranspose2d, SA_Conv2d)
from networks.utils import initialize_modules


class Discriminator(nn.Module):
def __init__(self, img_channels, h_dim, img_size):
super().__init__()
self.disc = nn.Sequential(
nn.Conv2d(img_channels, h_dim, 4, 2, 1),
ConvNormAct(h_dim, h_dim*2, 'down', activation='lrelu', normalization='bn'),
ConvNormAct(h_dim*2, h_dim*4, 'down', activation='lrelu', normalization='bn'),
ConvNormAct(h_dim*4, h_dim*8, 'down', activation='lrelu', normalization='bn'),
SN_Conv2d(in_channels=img_channels, out_channels=h_dim, kernel_size=4, stride=2, padding=1),
ConvNormAct(h_dim, h_dim*2, 'sn', 'down', activation='lrelu', normalization='bn'),
ConvNormAct(h_dim*2, h_dim*4, 'sn', 'down', activation='lrelu', normalization='bn'),
ConvNormAct(h_dim*4, h_dim*8, 'sn', 'down', activation='lrelu', normalization='bn'),
nn.AdaptiveAvgPool2d(1),
)
self.in_features = h_dim*8
self.fc = nn.Linear(self.in_features, 1)
self.fc = SN_Linear(in_features=self.in_features, out_features=1)
initialize_modules(self)

def forward(self, x):
@@ -28,16 +29,17 @@ def __init__(self, h_dim, z_dim, img_channels, img_size):
super().__init__()
self.min_hw = (img_size // (2 ** 5)) ** 2
self.h_dim = h_dim
self.project = nn.Linear(z_dim, h_dim*8 * self.min_hw ** 2)
self.project = SN_Linear(in_features=z_dim, out_features=h_dim*8 * self.min_hw ** 2, bias=False)
self.gen = nn.Sequential(
nn.BatchNorm2d(h_dim*8, momentum=0.9),
nn.ReLU(),
ConvNormAct(h_dim*8, h_dim*4, 'up', activation='relu', normalization='bn'),
ConvNormAct(h_dim*4, h_dim*2, 'up', activation='relu', normalization='bn'),
ConvNormAct(h_dim*2, h_dim, 'up', activation='relu', normalization='bn'),
SA_Conv2d(h_dim),
nn.ConvTranspose2d(h_dim, img_channels, 4, 2, 1),
nn.Sigmoid()
ConvNormAct(h_dim*8, h_dim*4, 'sn', 'up', activation='relu', normalization='bn'),
ConvNormAct(h_dim*4, h_dim*2, 'sn', 'up', activation='relu', normalization='bn'),
SA_Conv2d(h_dim*2),
ConvNormAct(h_dim*2, h_dim, 'sn', 'up', activation='relu', normalization='bn'),
SN_ConvTranspose2d(in_channels=h_dim, out_channels=img_channels, kernel_size=4,
stride=2, padding=1),
nn.Tanh()
)
initialize_modules(self)

18 changes: 14 additions & 4 deletions sagan/train.py
@@ -11,7 +11,7 @@
from dcgan.data import get_loaders
from networks.utils import load_weights
from sagan.model import Generator, Discriminator
from sagan.loss import SAGAN_Hinge_loss
from sagan.loss import Hinge_loss, Wasserstein_GP_Loss


parser = argparse.ArgumentParser()
@@ -30,11 +30,12 @@
# training parameters
parser.add_argument('--lr_G', type=float, default=0.0004, help='Learning rate for generator')
parser.add_argument('--lr_D', type=float, default=0.0004, help='Learning rate for discriminator')
parser.add_argument('--betas', type=tuple, default=(0.0, 0.99), help='Betas for Adam optimizer')
parser.add_argument('--betas', type=tuple, default=(0.0, 0.9), help='Betas for Adam optimizer')
parser.add_argument('--n_epochs', type=int, default=50, help='Number of epochs')
parser.add_argument('--batch_size', type=int, default=256, help='Batch size')
parser.add_argument('--continue_train', action="store_true", default=False, help='Whether to continue training from a saved checkpoint')
parser.add_argument('--devices', type=int, nargs='+', default=[0, 1], help='List of training devices')
parser.add_argument('--criterion', type=str, default='hinge', help='Loss function [hinge, wasserstein-gp]')

# logging parameters
parser.add_argument('--data_path', type=str, default='data/img_align_celeba', help='Path to where image data is located')
@@ -60,7 +61,12 @@ def train():
G = torch.nn.DataParallel(Generator(opt.h_dim, opt.z_dim, opt.img_channels, opt.img_size), device_ids=opt.devices).to(device)
D = torch.nn.DataParallel(Discriminator(opt.img_channels, opt.h_dim, opt.img_size), device_ids=opt.devices).to(device)

criterion = SAGAN_Hinge_loss()
if opt.criterion == 'wasserstein-gp':
criterion = Wasserstein_GP_Loss()
elif opt.criterion == 'hinge':
criterion = Hinge_loss()
else:
raise NotImplementedError('Please choose criterion [hinge, wasserstein-gp]')
optimizer_G = torch.optim.Adam(G.parameters(), lr=opt.lr_G, betas=opt.betas)
optimizer_D = torch.optim.Adam(D.parameters(), lr=opt.lr_D, betas=opt.betas)

@@ -98,6 +104,9 @@ def train():

# compute losses
d_loss = criterion(fake_logits=logits_fake, real_logits=logits_real, mode='discriminator')
if opt.criterion == 'wasserstein-gp':
# TODO
continue
g_loss = criterion(fake_logits=D(fakes), mode='generator')

# update discriminator
@@ -125,10 +134,11 @@
# generate image from fixed noise vector
with torch.no_grad():
samples = G(fixed_z)
samples = (samples + 1) / 2

# save locally
if opt.save_local_samples:
torchvision.utils.save_image(samples, f'{opt.sample_dir}/Interval_{ckpt_iter}.png')
torchvision.utils.save_image(samples, f'{opt.sample_dir}/Interval_{ckpt_iter}.{opt.img_ext}')

# save sample and loss to tensorboard
writer.add_image('Generated Images', torchvision.utils.make_grid(samples), global_step=ckpt_iter)
