Clean up of implementations using fully-connected networks
eriklindernoren committed May 12, 2018
1 parent e108a05 commit 9e3ac57
Showing 22 changed files with 378 additions and 435 deletions.
32 changes: 9 additions & 23 deletions implementations/aae/aae.py
@@ -31,15 +31,9 @@
 opt = parser.parse_args()
 print(opt)
 
-cuda = True if torch.cuda.is_available() else False
+img_shape = (opt.channels, opt.img_size, opt.img_size)
 
-def weights_init_normal(m):
-    classname = m.__class__.__name__
-    if classname.find('Conv') != -1:
-        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
-    elif classname.find('BatchNorm2d') != -1:
-        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
-        torch.nn.init.constant_(m.bias.data, 0.0)
+cuda = True if torch.cuda.is_available() else False
 
 def reparameterization(mu, logvar):
     std = torch.exp(logvar / 2)
@@ -52,7 +46,7 @@ def __init__(self):
         super(Encoder, self).__init__()
 
         self.model = nn.Sequential(
-            nn.Linear(opt.img_size**2, 512),
+            nn.Linear(int(np.prod(img_shape)), 512),
             nn.LeakyReLU(0.2, inplace=True),
             nn.Linear(512, 512),
             nn.BatchNorm1d(512),
@@ -67,9 +61,7 @@ def forward(self, img):
         x = self.model(img_flat)
         mu = self.mu(x)
         logvar = self.logvar(x)
-
         z = reparameterization(mu, logvar)
-
         return z
 
 class Decoder(nn.Module):
@@ -82,13 +74,13 @@ def __init__(self):
             nn.Linear(512, 512),
             nn.BatchNorm1d(512),
             nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(512, opt.img_size**2),
+            nn.Linear(512, int(np.prod(img_shape))),
             nn.Tanh()
         )
 
-    def forward(self, noise):
-        img_flat = self.model(noise)
-        img = img_flat.view(img_flat.shape[0], opt.channels, opt.img_size, opt.img_size)
+    def forward(self, z):
+        img_flat = self.model(z)
+        img = img_flat.view(img_flat.shape[0], *img_shape)
         return img
 
 class Discriminator(nn.Module):
@@ -104,9 +96,8 @@ def __init__(self):
             nn.Sigmoid()
         )
 
-    def forward(self, latent):
-        validity = self.model(latent)
-
+    def forward(self, z):
+        validity = self.model(z)
         return validity
 
 # Use binary cross-entropy loss
@@ -125,11 +116,6 @@ def forward(self, latent):
     adversarial_loss.cuda()
     pixelwise_loss.cuda()
 
-# Initialize weights
-encoder.apply(weights_init_normal)
-decoder.apply(weights_init_normal)
-discriminator.apply(weights_init_normal)
-
 # Configure data loader
 os.makedirs('../../data/mnist', exist_ok=True)
 dataloader = torch.utils.data.DataLoader(
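
Side note, not part of the diff: the second hunk above truncates reparameterization after its first line. A sketch of the assumed remainder, mirroring the identical helper visible in the bicyclegan diff further down (Variable, Tensor and opt.latent_dim come from the surrounding script). Note also that int(np.prod(img_shape)) generalizes the old opt.img_size**2: for the MNIST defaults int(np.prod((1, 28, 28))) == 784, but the new form stays correct when channels > 1.

# Assumed completion of the truncated helper (matches the bicyclegan version below)
def reparameterization(mu, logvar):
    std = torch.exp(logvar / 2)  # logvar = log(sigma^2), so std = exp(logvar / 2)
    sampled_z = Variable(Tensor(np.random.normal(0, 1, (mu.size(0), opt.latent_dim))))
    z = sampled_z * std + mu     # z = mu + std * eps, with eps ~ N(0, I)
    return z
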
47 changes: 18 additions & 29 deletions implementations/bgan/bgan.py
@@ -32,47 +32,41 @@
 opt = parser.parse_args()
 print(opt)
 
-cuda = True if torch.cuda.is_available() else False
+img_shape = (opt.channels, opt.img_size, opt.img_size)
 
-def weights_init_normal(m):
-    classname = m.__class__.__name__
-    if classname.find('Linear') != -1:
-        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
-    elif classname.find('BatchNorm2d') != -1:
-        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
-        torch.nn.init.constant_(m.bias.data, 0.0)
+cuda = True if torch.cuda.is_available() else False
 
 class Generator(nn.Module):
     def __init__(self):
         super(Generator, self).__init__()
 
+        def block(in_feat, out_feat, normalize=True):
+            layers = [nn.Linear(in_feat, out_feat)]
+            if normalize:
+                layers.append(nn.BatchNorm1d(out_feat, 0.8))
+            layers.append(nn.LeakyReLU(0.2, inplace=True))
+            return layers
+
         self.model = nn.Sequential(
-            nn.Linear(opt.latent_dim, 128),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(128, 256),
-            nn.BatchNorm1d(256),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(256, 512),
-            nn.BatchNorm1d(512),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(512, 1024),
-            nn.BatchNorm1d(1024),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(1024, opt.img_size**2),
+            *block(opt.latent_dim, 128, normalize=False),
+            *block(128, 256),
+            *block(256, 512),
+            *block(512, 1024),
+            nn.Linear(1024, int(np.prod(img_shape))),
             nn.Tanh()
         )
 
-    def forward(self, noise):
-        img = self.model(noise)
-        img = img.view(img.shape[0], opt.channels, opt.img_size, opt.img_size)
+    def forward(self, z):
+        img = self.model(z)
+        img = img.view(img.shape[0], *img_shape)
         return img
 
 class Discriminator(nn.Module):
     def __init__(self):
         super(Discriminator, self).__init__()
 
         self.model = nn.Sequential(
-            nn.Linear(opt.img_size**2, 512),
+            nn.Linear(int(np.prod(img_shape)), 512),
             nn.LeakyReLU(0.2, inplace=True),
             nn.Linear(512, 256),
             nn.LeakyReLU(0.2, inplace=True),
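
Side note, not part of the diff: block returns a plain Python list of layers, and the * splat unpacks each list into positional arguments of nn.Sequential, which is what lets four block(...) calls replace the twelve explicit layer lines. Illustrative expansion; note, as an observation rather than part of the commit, that the second positional argument of nn.BatchNorm1d is eps, so BatchNorm1d(out_feat, 0.8) sets eps=0.8 (possibly intended as momentum):

# What a single *block(...) call contributes inside nn.Sequential:
# *block(128, 256)  ->  nn.Linear(128, 256),
#                       nn.BatchNorm1d(256, 0.8),   # 0.8 binds to eps
#                       nn.LeakyReLU(0.2, inplace=True)
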
@@ -83,7 +77,6 @@ def __init__(self):
     def forward(self, img):
         img_flat = img.view(img.shape[0], -1)
         validity = self.model(img_flat)
-
         return validity
 
 def boundary_seeking_loss(y_pred, y_true):
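
Side note, not part of the diff: only the signature of boundary_seeking_loss is visible here; its body lies outside the shown hunks. A hedged sketch of what the body presumably looks like, following the boundary-seeking GAN objective of driving D(G(z)) toward the decision boundary D = 0.5 (y_true is unused but kept for a loss-like signature):

# Assumed body of the generator loss: 0.5 * E[(log D(G(z)) - log(1 - D(G(z))))^2]
def boundary_seeking_loss(y_pred, y_true):
    return 0.5 * torch.mean((torch.log(y_pred) - torch.log(1 - y_pred)) ** 2)
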
@@ -104,10 +97,6 @@ def boundary_seeking_loss(y_pred, y_true):
     discriminator.cuda()
     discriminator_loss.cuda()
 
-# Initialize weights
-generator.apply(weights_init_normal)
-discriminator.apply(weights_init_normal)
-
 # Configure data loader
 os.makedirs('../../data/mnist', exist_ok=True)
 mnist_loader = torch.utils.data.DataLoader(
122 changes: 71 additions & 51 deletions implementations/bicyclegan/bicyclegan.py
@@ -25,31 +25,23 @@
 parser.add_argument('--epoch', type=int, default=0, help='epoch to start training from')
 parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
 parser.add_argument('--dataset_name', type=str, default="edges2shoes", help='name of the dataset')
-parser.add_argument('--batch_size', type=int, default=1, help='size of the batches')
+parser.add_argument('--batch_size', type=int, default=8, help='size of the batches')
 parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
 parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
 parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of second order momentum of gradient')
 parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
-parser.add_argument('--img_height', type=int, default=256, help='size of image height')
-parser.add_argument('--img_width', type=int, default=256, help='size of image width')
+parser.add_argument('--img_height', type=int, default=128, help='size of image height')
+parser.add_argument('--img_width', type=int, default=128, help='size of image width')
 parser.add_argument('--channels', type=int, default=3, help='number of image channels')
-parser.add_argument('--latent_dim', type=int, default=8, help='dimensionality of latent representation')
-parser.add_argument('--sample_interval', type=int, default=1000, help='interval between sampling of images from generators')
+parser.add_argument('--latent_dim', type=int, default=8, help='number of latent codes')
+parser.add_argument('--sample_interval', type=int, default=400, help='interval between sampling of images from generators')
 parser.add_argument('--checkpoint_interval', type=int, default=-1, help='interval between model checkpoints')
 opt = parser.parse_args()
 print(opt)
 
 os.makedirs('images/%s' % opt.dataset_name, exist_ok=True)
 os.makedirs('saved_models/%s' % opt.dataset_name, exist_ok=True)
 
-def weights_init_normal(m):
-    classname = m.__class__.__name__
-    if classname.find('Conv') != -1:
-        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
-    elif classname.find('BatchNorm2d') != -1:
-        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
-        torch.nn.init.constant_(m.bias.data, 0.0)
-
 img_shape = (opt.channels, opt.img_height, opt.img_width)
 
 # Loss functions
@@ -61,8 +53,8 @@ def weights_init_normal(m):
 cuda = True if torch.cuda.is_available() else False
 
 # Calculate outputs of multilevel PatchGAN
-patch1 = (opt.batch_size, 1, opt.img_height // 2**3, opt.img_width // 2**3)
-patch2 = (opt.batch_size, 1, opt.img_height // 2**4, opt.img_width // 2**4)
+patch1 = (1, opt.img_height // 2**2, opt.img_width // 2**2)
+patch2 = (1, opt.img_height // 2**3, opt.img_width // 2**3)
 
 # Initialize generator, encoder and discriminators
 generator = Generator(opt.latent_dim, img_shape)
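
Side note, not part of the diff: the patch shapes lose their batch dimension here because the adversarial ground truths are now built per batch inside the training loop (the np.ones((real_A.size(0), *patch1)) lines further down). Worked out with the new 128x128 defaults:

# Per-sample output shapes of the two PatchGAN levels (assuming the new defaults):
patch1 = (1, 128 // 2**2, 128 // 2**2)   # -> (1, 32, 32)
patch2 = (1, 128 // 2**3, 128 // 2**3)   # -> (1, 16, 16)
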
@@ -105,41 +97,46 @@ def weights_init_normal(m):
 Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor
 
-# Adversarial ground truths
-valid1 = Variable(Tensor(np.ones(patch1)), requires_grad=False)
-valid2 = Variable(Tensor(np.ones(patch2)), requires_grad=False)
-fake1 = Variable(Tensor(np.zeros(patch1)), requires_grad=False)
-fake2 = Variable(Tensor(np.zeros(patch2)), requires_grad=False)
-
 # Dataset loader
 transforms_ = [ transforms.Resize((opt.img_height, opt.img_width), Image.BICUBIC),
                 transforms.ToTensor(),
                 transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]
 dataloader = DataLoader(ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_),
                         batch_size=opt.batch_size, shuffle=True, num_workers=opt.n_cpu)
 val_dataloader = DataLoader(ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_, mode='val'),
-                        batch_size=5, shuffle=True, num_workers=1)
+                        batch_size=8, shuffle=True, num_workers=1)
 
 def sample_images(batches_done):
     """Saves a generated sample from the validation set"""
     imgs = next(iter(val_dataloader))
-    real_A = Variable(imgs['A'].type(Tensor))
-    sampled_z = Variable(Tensor(np.random.normal(0, 1, (5, opt.latent_dim))))
-    fake_B = generator(real_A, sampled_z)
-    real_B = Variable(imgs['B'].type(Tensor))
-    img_sample = torch.cat((real_A.data, fake_B.data, real_B.data), 0)
-    save_image(img_sample, 'images/%s/%s.png' % (opt.dataset_name, batches_done), nrow=5, normalize=True)
-
-# ----------
-# Training
-# ----------
+    img_samples = None
+    for img_A, img_B in zip(imgs['A'], imgs['B']):
+        # Repeat input image once for each of the 8 latent samples
+        real_A = img_A.view(1, *img_A.shape).repeat(8, 1, 1, 1)
+        real_A = Variable(real_A.type(Tensor))
+        # Get interpolated noise in [-1, 1]
+        sampled_z = np.repeat(np.linspace(-1, 1, 8)[:, np.newaxis], opt.latent_dim, 1)
+        sampled_z = Variable(Tensor(sampled_z))
+        # Generator samples
+        fake_B = generator(real_A, sampled_z)
+        # Concatenate samples horizontally
+        fake_B = torch.cat([x for x in fake_B.data.cpu()], -1)
+        img_sample = torch.cat((img_A, fake_B), -1)
+        img_sample = img_sample.view(1, *img_sample.shape)
+        # Concatenate with previous samples vertically
+        img_samples = img_sample if img_samples is None else torch.cat((img_samples, img_sample), -2)
+    save_image(img_samples, 'images/%s/%s.png' % (opt.dataset_name, batches_done), nrow=5, normalize=True)
 
 def reparameterization(mu, logvar):
     std = torch.exp(logvar / 2)
     sampled_z = Variable(Tensor(np.random.normal(0, 1, (mu.size(0), opt.latent_dim))))
     z = sampled_z * std + mu
     return z
 
+# ----------
+# Training
+# ----------
+
 prev_time = time.time()
 for epoch in range(opt.epoch, opt.n_epochs):
     for i, batch in enumerate(dataloader):
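
Side note, not part of the diff: the rewritten sample_images builds the sample sheet row by row; each validation image becomes one row showing the input followed by eight translations whose latent code sweeps linearly from -1 to 1, a standard way to visualize the multimodal outputs BicycleGAN is trained for. A rough shape walk-through, assuming the new 3x128x128 defaults:

# real_A:    (8, 3, 128, 128)     input repeated once per latent sample
# sampled_z: (8, opt.latent_dim)  rows interpolate from -1 to 1
# fake_B:    (8, 3, 128, 128) -> width-wise cat -> (3, 128, 8 * 128)
# one row:   torch.cat((img_A, fake_B), -1) -> (3, 128, 9 * 128)
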
@@ -148,57 +145,80 @@ def reparameterization(mu, logvar):
         real_A = Variable(batch['A'].type(Tensor))
         real_B = Variable(batch['B'].type(Tensor))
 
-        # -----------------------------
+        # Adversarial ground truths
+        valid1 = Variable(Tensor(np.ones((real_A.size(0), *patch1))), requires_grad=False)
+        valid2 = Variable(Tensor(np.ones((real_A.size(0), *patch2))), requires_grad=False)
+        fake1 = Variable(Tensor(np.zeros((real_A.size(0), *patch1))), requires_grad=False)
+        fake2 = Variable(Tensor(np.zeros((real_A.size(0), *patch2))), requires_grad=False)
+
+        #-------------------------------
         #  Train Generator and Encoder
-        # -----------------------------
+        #-------------------------------
 
         optimizer_E.zero_grad()
         optimizer_G.zero_grad()
 
+        #----------
+        # cVAE-GAN
+        #----------
+
+        # Produce output using encoding of B (cVAE-GAN)
         mu, logvar = encoder(real_B)
         encoded_z = reparameterization(mu, logvar)
         fake_B = generator(real_A, encoded_z)
 
-        # Produce output using sampled z (cLR-GAN)
-        sampled_z = Variable(Tensor(np.random.normal(0, 1, (mu.size(0), opt.latent_dim))))
-        _fake_B = generator(real_A, sampled_z)
+        # Discriminator evaluates generated samples
+        VAE_validity1, VAE_validity2 = D_VAE(fake_B)
 
         # Pixelwise loss of translated image by VAE
         loss_pixel = pixelwise_loss(fake_B, real_B)
 
         # Kullback-Leibler divergence of encoded B
         loss_kl = torch.sum(0.5 * (mu**2 + torch.exp(logvar) - logvar - 1))
+        # Adversarial loss
+        loss_VAE_GAN = (adversarial_loss(VAE_validity1, valid1) + \
+                        adversarial_loss(VAE_validity2, valid2)) / 2
 
-        # Discriminators evaluate generated samples
-        VAE_validity1, VAE_validity2 = D_VAE(fake_B)
+        #---------
+        # cLR-GAN
+        #---------
 
+        # Produce output using sampled z (cLR-GAN)
+        sampled_z = Variable(Tensor(np.random.normal(0, 1, (real_A.size(0), opt.latent_dim))))
+        _fake_B = generator(real_A, sampled_z)
+
+        # Discriminator evaluates generated samples
         LR_validity1, LR_validity2 = D_LR(_fake_B)
 
-        # Adversarial losses
-        loss_VAE_GAN = (adversarial_loss(VAE_validity1, valid1) + \
-                        adversarial_loss(VAE_validity2, valid2)) / 2
+        # cLR Loss: Adversarial loss
         loss_LR_GAN = (adversarial_loss(LR_validity1, valid1) + \
                        adversarial_loss(LR_validity2, valid2)) / 2
 
-        # Shared losses between encoder and generator
+        #----------------------------------
+        # Total Loss (Generator + Encoder)
+        #----------------------------------
+
         loss_GE = loss_VAE_GAN + \
                   loss_LR_GAN + \
                   lambda_pixel * loss_pixel + \
                   lambda_kl * loss_kl
 
-        loss_GE.backward()
+        loss_GE.backward(retain_graph=True)
         optimizer_E.step()
 
+        #---------------------
+        # Generator Only Loss
+        #---------------------
+
         # Latent L1 loss
-        _mu, _ = encoder(generator(real_A, sampled_z))
+        _mu, _ = encoder(_fake_B)
         loss_latent = lambda_latent * latent_loss(_mu, sampled_z)
 
         loss_latent.backward()
         optimizer_G.step()
 
-        # --------------------------------
+        #----------------------------------
         #  Train Discriminator (cVAE-GAN)
-        # --------------------------------
+        #----------------------------------
 
         optimizer_D_VAE.zero_grad()
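
Side note, not part of the diff: two details in the generator/encoder step above are easy to miss. The summed KL term is the closed-form divergence between the encoder's diagonal Gaussian and the standard normal prior, and retain_graph=True is needed because loss_latent.backward() reuses the graph through _fake_B that the first backward pass would otherwise free. Annotated excerpt, with logvar = log(sigma^2):

# KL( N(mu, diag(sigma^2)) || N(0, I) ) = 0.5 * sum( mu^2 + sigma^2 - log(sigma^2) - 1 )
loss_kl = torch.sum(0.5 * (mu**2 + torch.exp(logvar) - logvar - 1))

# First backward keeps the shared graph through _fake_B alive for the latent L1
# pass; without retain_graph=True the second backward would raise a RuntimeError.
loss_GE.backward(retain_graph=True)
# (optimizer_E.step() and the latent loss computation happen in between)
loss_latent.backward()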

@@ -218,9 +238,9 @@ def reparameterization(mu, logvar):
         loss_D_VAE.backward()
         optimizer_D_VAE.step()
 
-        # -------------------------------
+        #---------------------------------
         #  Train Discriminator (cLR-GAN)
-        # -------------------------------
+        #---------------------------------
 
         optimizer_D_LR.zero_grad()
 
@@ -235,7 +255,7 @@ def reparameterization(mu, logvar):
                        adversarial_loss(pred_gen2, fake2)) / 2
 
         # Total loss
-        loss_D_LR = 0.5 * (loss_real + loss_fake)
+        loss_D_LR = (loss_real + loss_fake) / 2
 
         loss_D_LR.backward()
         optimizer_D_LR.step()