diff --git a/implementations/aae/aae.py b/implementations/aae/aae.py
index 69be62fc..405e9df6 100644
--- a/implementations/aae/aae.py
+++ b/implementations/aae/aae.py
@@ -31,15 +31,9 @@
 opt = parser.parse_args()
 print(opt)

-cuda = True if torch.cuda.is_available() else False
+img_shape = (opt.channels, opt.img_size, opt.img_size)

-def weights_init_normal(m):
-    classname = m.__class__.__name__
-    if classname.find('Conv') != -1:
-        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
-    elif classname.find('BatchNorm2d') != -1:
-        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
-        torch.nn.init.constant_(m.bias.data, 0.0)
+cuda = True if torch.cuda.is_available() else False

 def reparameterization(mu, logvar):
     std = torch.exp(logvar / 2)
@@ -52,7 +46,7 @@ def __init__(self):
         super(Encoder, self).__init__()

         self.model = nn.Sequential(
-            nn.Linear(opt.img_size**2, 512),
+            nn.Linear(int(np.prod(img_shape)), 512),
             nn.LeakyReLU(0.2, inplace=True),
             nn.Linear(512, 512),
             nn.BatchNorm1d(512),
@@ -67,9 +61,7 @@ def forward(self, img):
         x = self.model(img_flat)
         mu = self.mu(x)
         logvar = self.logvar(x)
-
         z = reparameterization(mu, logvar)
-
         return z

 class Decoder(nn.Module):
@@ -82,13 +74,13 @@ def __init__(self):
             nn.Linear(512, 512),
             nn.BatchNorm1d(512),
             nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(512, opt.img_size**2),
+            nn.Linear(512, int(np.prod(img_shape))),
             nn.Tanh()
         )

-    def forward(self, noise):
-        img_flat = self.model(noise)
-        img = img_flat.view(img_flat.shape[0], opt.channels, opt.img_size, opt.img_size)
+    def forward(self, z):
+        img_flat = self.model(z)
+        img = img_flat.view(img_flat.shape[0], *img_shape)
         return img

 class Discriminator(nn.Module):
@@ -104,9 +96,8 @@ def __init__(self):
             nn.Sigmoid()
         )

-    def forward(self, latent):
-        validity = self.model(latent)
-
+    def forward(self, z):
+        validity = self.model(z)
         return validity

 # Use binary cross-entropy loss
@@ -125,11 +116,6 @@ def forward(self, latent):
     adversarial_loss.cuda()
     pixelwise_loss.cuda()

-# Initialize weights
-encoder.apply(weights_init_normal)
-decoder.apply(weights_init_normal)
-discriminator.apply(weights_init_normal)
-
 # Configure data loader
 os.makedirs('../../data/mnist', exist_ok=True)
 dataloader = torch.utils.data.DataLoader(
diff --git a/implementations/bgan/bgan.py b/implementations/bgan/bgan.py
index dcb396d2..8d9895ca 100644
--- a/implementations/bgan/bgan.py
+++ b/implementations/bgan/bgan.py
@@ -32,39 +32,33 @@
 opt = parser.parse_args()
 print(opt)

-cuda = True if torch.cuda.is_available() else False
+img_shape = (opt.channels, opt.img_size, opt.img_size)

-def weights_init_normal(m):
-    classname = m.__class__.__name__
-    if classname.find('Linear') != -1:
-        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
-    elif classname.find('BatchNorm2d') != -1:
-        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
-        torch.nn.init.constant_(m.bias.data, 0.0)
+cuda = True if torch.cuda.is_available() else False

 class Generator(nn.Module):
     def __init__(self):
         super(Generator, self).__init__()

+        def block(in_feat, out_feat, normalize=True):
+            layers = [nn.Linear(in_feat, out_feat)]
+            if normalize:
+                layers.append(nn.BatchNorm1d(out_feat, 0.8))
+            layers.append(nn.LeakyReLU(0.2, inplace=True))
+            return layers
+
         self.model = nn.Sequential(
-            nn.Linear(opt.latent_dim, 128),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(128, 256),
-            nn.BatchNorm1d(256),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(256, 512),
-            nn.BatchNorm1d(512),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(512, 1024),
-            nn.BatchNorm1d(1024),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(1024, opt.img_size**2),
+            *block(opt.latent_dim, 128, normalize=False),
+            *block(128, 256),
+            *block(256, 512),
+            *block(512, 1024),
+            nn.Linear(1024, int(np.prod(img_shape))),
             nn.Tanh()
         )

-    def forward(self, noise):
-        img = self.model(noise)
-        img = img.view(img.shape[0], opt.channels, opt.img_size, opt.img_size)
+    def forward(self, z):
+        img = self.model(z)
+        img = img.view(img.shape[0], *img_shape)
         return img

 class Discriminator(nn.Module):
@@ -72,7 +66,7 @@ def __init__(self):
         super(Discriminator, self).__init__()

         self.model = nn.Sequential(
-            nn.Linear(opt.img_size**2, 512),
+            nn.Linear(int(np.prod(img_shape)), 512),
             nn.LeakyReLU(0.2, inplace=True),
             nn.Linear(512, 256),
             nn.LeakyReLU(0.2, inplace=True),
@@ -83,7 +77,6 @@ def __init__(self):
     def forward(self, img):
         img_flat = img.view(img.shape[0], -1)
         validity = self.model(img_flat)
-
         return validity

 def boundary_seeking_loss(y_pred, y_true):
@@ -104,10 +97,6 @@ def boundary_seeking_loss(y_pred, y_true):
     discriminator.cuda()
     discriminator_loss.cuda()

-# Initialize weights
-generator.apply(weights_init_normal)
-discriminator.apply(weights_init_normal)
-
 # Configure data loader
 os.makedirs('../../data/mnist', exist_ok=True)
 mnist_loader = torch.utils.data.DataLoader(
diff --git a/implementations/bicyclegan/bicyclegan.py b/implementations/bicyclegan/bicyclegan.py
index 39bb9e4e..76c75d64 100644
--- a/implementations/bicyclegan/bicyclegan.py
+++ b/implementations/bicyclegan/bicyclegan.py
@@ -25,16 +25,16 @@
 parser.add_argument('--epoch', type=int, default=0, help='epoch to start training from')
 parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
 parser.add_argument('--dataset_name', type=str, default="edges2shoes", help='name of the dataset')
-parser.add_argument('--batch_size', type=int, default=1, help='size of the batches')
+parser.add_argument('--batch_size', type=int, default=8, help='size of the batches')
 parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
 parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
 parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient')
 parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
-parser.add_argument('--img_height', type=int, default=256, help='size of image height')
-parser.add_argument('--img_width', type=int, default=256, help='size of image width')
+parser.add_argument('--img_height', type=int, default=128, help='size of image height')
+parser.add_argument('--img_width', type=int, default=128, help='size of image width')
 parser.add_argument('--channels', type=int, default=3, help='number of image channels')
-parser.add_argument('--latent_dim', type=int, default=8, help='dimensionality of latent representation')
-parser.add_argument('--sample_interval', type=int, default=1000, help='interval between sampling of images from generators')
+parser.add_argument('--latent_dim', type=int, default=8, help='number of latent codes')
+parser.add_argument('--sample_interval', type=int, default=400, help='interval between sampling of images from generators')
 parser.add_argument('--checkpoint_interval', type=int, default=-1, help='interval between model checkpoints')
 opt = parser.parse_args()
 print(opt)
@@ -42,14 +42,6 @@
 os.makedirs('images/%s' % opt.dataset_name, exist_ok=True)
 os.makedirs('saved_models/%s' % opt.dataset_name, exist_ok=True)

-def weights_init_normal(m):
-    classname = m.__class__.__name__
-    if classname.find('Conv') != -1:
-        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
-    elif classname.find('BatchNorm2d') != -1:
-        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
-        torch.nn.init.constant_(m.bias.data, 0.0)
-
 img_shape = (opt.channels, opt.img_height, opt.img_width)

 # Loss functions
@@ -61,8 +53,8 @@ def weights_init_normal(m):
 cuda = True if torch.cuda.is_available() else False

 # Calculate outputs of multilevel PatchGAN
-patch1 = (opt.batch_size, 1, opt.img_height // 2**3, opt.img_width // 2**3)
-patch2 = (opt.batch_size, 1, opt.img_height // 2**4, opt.img_width // 2**4)
+patch1 = (1, opt.img_height // 2**2, opt.img_width // 2**2)
+patch2 = (1, opt.img_height // 2**3, opt.img_width // 2**3)

 # Initialize generator, encoder and discriminators
 generator = Generator(opt.latent_dim, img_shape)
@@ -105,12 +97,6 @@ def weights_init_normal(m):
 Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

-# Adversarial ground truths
-valid1 = Variable(Tensor(np.ones(patch1)), requires_grad=False)
-valid2 = Variable(Tensor(np.ones(patch2)), requires_grad=False)
-fake1 = Variable(Tensor(np.zeros(patch1)), requires_grad=False)
-fake2 = Variable(Tensor(np.zeros(patch2)), requires_grad=False)
-
 # Dataset loader
 transforms_ = [ transforms.Resize((opt.img_height, opt.img_width), Image.BICUBIC),
                 transforms.ToTensor(),
@@ -118,21 +104,28 @@ def weights_init_normal(m):
 dataloader = DataLoader(ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_),
                         batch_size=opt.batch_size, shuffle=True, num_workers=opt.n_cpu)
 val_dataloader = DataLoader(ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_, mode='val'),
-                            batch_size=5, shuffle=True, num_workers=1)
+                            batch_size=8, shuffle=True, num_workers=1)

 def sample_images(batches_done):
     """Saves a generated sample from the validation set"""
     imgs = next(iter(val_dataloader))
-    real_A = Variable(imgs['A'].type(Tensor))
-    sampled_z = Variable(Tensor(np.random.normal(0, 1, (5, opt.latent_dim))))
-    fake_B = generator(real_A, sampled_z)
-    real_B = Variable(imgs['B'].type(Tensor))
-    img_sample = torch.cat((real_A.data, fake_B.data, real_B.data), 0)
-    save_image(img_sample, 'images/%s/%s.png' % (opt.dataset_name, batches_done), nrow=5, normalize=True)
-
-# ----------
-#  Training
-# ----------
+    img_samples = None
+    for img_A, img_B in zip(imgs['A'], imgs['B']):
+        # Repeat input image once per sampled z
+        real_A = img_A.view(1, *img_A.shape).repeat(8, 1, 1, 1)
+        real_A = Variable(real_A.type(Tensor))
+        # Get interpolated noise [-1, 1]
+        sampled_z = np.repeat(np.linspace(-1, 1, 8)[:, np.newaxis], opt.latent_dim, 1)
+        sampled_z = Variable(Tensor(sampled_z))
+        # Generate samples
+        fake_B = generator(real_A, sampled_z)
+        # Concatenate samples horizontally
+        fake_B = torch.cat([x for x in fake_B.data.cpu()], -1)
+        img_sample = torch.cat((img_A, fake_B), -1)
+        img_sample = img_sample.view(1, *img_sample.shape)
+        # Concatenate with previous samples vertically
+        img_samples = img_sample if img_samples is None else torch.cat((img_samples, img_sample), -2)
+    save_image(img_samples, 'images/%s/%s.png' % (opt.dataset_name, batches_done), nrow=5, normalize=True)

 def reparameterization(mu, logvar):
     std = torch.exp(logvar / 2)
@@ -140,6 +133,10 @@ def reparameterization(mu, logvar):
     z = sampled_z * std + mu
     return z

+# ----------
+#  Training
+# ----------
+
 prev_time = time.time()
 for epoch in range(opt.epoch, opt.n_epochs):
     for i, batch in enumerate(dataloader):
@@ -148,57 +145,80 @@ def reparameterization(mu, logvar):
         real_A = Variable(batch['A'].type(Tensor))
         real_B = Variable(batch['B'].type(Tensor))

-        # -----------------------------
+        # Adversarial ground truths
+        valid1 = Variable(Tensor(np.ones((real_A.size(0), *patch1))), requires_grad=False)
+        valid2 = Variable(Tensor(np.ones((real_A.size(0), *patch2))), requires_grad=False)
+        fake1 = Variable(Tensor(np.zeros((real_A.size(0), *patch1))), requires_grad=False)
+        fake2 = Variable(Tensor(np.zeros((real_A.size(0), *patch2))), requires_grad=False)
+
+        #-------------------------------
         #  Train Generator and Encoder
-        # -----------------------------
+        #-------------------------------

         optimizer_E.zero_grad()
         optimizer_G.zero_grad()

+        #----------
+        # cVAE-GAN
+        #----------
+
         # Produce output using encoding of B (cVAE-GAN)
         mu, logvar = encoder(real_B)
         encoded_z = reparameterization(mu, logvar)
         fake_B = generator(real_A, encoded_z)

-        # Produce output using sampled z (cLR-GAN)
-        sampled_z = Variable(Tensor(np.random.normal(0, 1, (mu.size(0), opt.latent_dim))))
-        _fake_B = generator(real_A, sampled_z)
+        # Discriminator evaluates generated samples
+        VAE_validity1, VAE_validity2 = D_VAE(fake_B)

         # Pixelwise loss of translated image by VAE
         loss_pixel = pixelwise_loss(fake_B, real_B)
-
         # Kullback-Leibler divergence of encoded B
         loss_kl = torch.sum(0.5 * (mu**2 + torch.exp(logvar) - logvar - 1))
+        # Adversarial loss
+        loss_VAE_GAN = (adversarial_loss(VAE_validity1, valid1) + \
+                        adversarial_loss(VAE_validity2, valid2)) / 2

-        # Discriminators evaluate generated samples
-        VAE_validity1, VAE_validity2 = D_VAE(fake_B)
+        #---------
+        # cLR-GAN
+        #---------
+
+        # Produce output using sampled z (cLR-GAN)
+        sampled_z = Variable(Tensor(np.random.normal(0, 1, (real_A.size(0), opt.latent_dim))))
+        _fake_B = generator(real_A, sampled_z)
+
+        # Discriminator evaluates generated samples
         LR_validity1, LR_validity2 = D_LR(_fake_B)

-        # Adversarial losses
-        loss_VAE_GAN = (adversarial_loss(VAE_validity1, valid1) + \
-                        adversarial_loss(VAE_validity2, valid2)) / 2
+        # cLR Loss: Adversarial loss
         loss_LR_GAN = (adversarial_loss(LR_validity1, valid1) + \
                        adversarial_loss(LR_validity2, valid2)) / 2

-        # Shared losses between encoder and generator
+        #----------------------------------
+        # Total Loss (Generator + Encoder)
+        #----------------------------------
+
         loss_GE = loss_VAE_GAN + \
                   loss_LR_GAN + \
                   lambda_pixel * loss_pixel + \
                   lambda_kl * loss_kl

-        loss_GE.backward()
+        loss_GE.backward(retain_graph=True)
        optimizer_E.step()

+        #---------------------
+        # Generator Only Loss
+        #---------------------
+
         # Latent L1 loss
-        _mu, _ = encoder(generator(real_A, sampled_z))
+        _mu, _ = encoder(_fake_B)
         loss_latent = lambda_latent * latent_loss(_mu, sampled_z)

         loss_latent.backward()
         optimizer_G.step()

-        # --------------------------------
+        #----------------------------------
         #  Train Discriminator (cVAE-GAN)
-        # --------------------------------
+        #----------------------------------

         optimizer_D_VAE.zero_grad()
@@ -218,9 +238,9 @@ def reparameterization(mu, logvar):
         loss_D_VAE.backward()
         optimizer_D_VAE.step()

-        # -------------------------------
+        #---------------------------------
         #  Train Discriminator (cLR-GAN)
-        # -------------------------------
+        #---------------------------------

         optimizer_D_LR.zero_grad()
@@ -235,7 +255,7 @@ def reparameterization(mu, logvar):
                      adversarial_loss(pred_gen2, fake2)) / 2

         # Total loss
-        loss_D_LR = 0.5 * (loss_real + loss_fake)
+        loss_D_LR = (loss_real + loss_fake) / 2

         loss_D_LR.backward()
         optimizer_D_LR.step()
diff --git a/implementations/bicyclegan/models.py b/implementations/bicyclegan/models.py
index 907b10a8..3ab3d855 100644
--- a/implementations/bicyclegan/models.py
+++ b/implementations/bicyclegan/models.py
@@ -5,6 +5,16 @@
 from torchvision.models import resnet18

+
+def weights_init_normal(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
+    elif classname.find('BatchNorm2d') != -1:
+        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
+        torch.nn.init.constant_(m.bias.data, 0.0)
+
+
 ##############################
 #           U-NET
 ##############################
@@ -12,17 +22,13 @@
 class UNetDown(nn.Module):
     def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
         super(UNetDown, self).__init__()
-        model = [nn.Conv2d(in_size, out_size, 3, stride=2, padding=1)]
-
+        layers = [nn.Conv2d(in_size, out_size, 4, stride=2, padding=1, bias=False)]
         if normalize:
-            model.append(nn.BatchNorm2d(out_size, 0.8))
-
-        model.append(nn.LeakyReLU(0.2, inplace=True))
-
+            layers.append(nn.InstanceNorm2d(out_size, affine=True, track_running_stats=True))
+        layers.append(nn.LeakyReLU(0.2, inplace=True))
         if dropout:
-            model.append(nn.Dropout(dropout))
-
-        self.model = nn.Sequential(*model)
+            layers.append(nn.Dropout(dropout))
+        self.model = nn.Sequential(*layers)

     def forward(self, x):
         return self.model(x)
@@ -30,19 +36,19 @@ def forward(self, x):
 class UNetUp(nn.Module):
     def __init__(self, in_size, out_size, dropout=0.0):
         super(UNetUp, self).__init__()
-        model = [ nn.Upsample(scale_factor=2),
-                  nn.Conv2d(in_size, out_size, 3, stride=1, padding=1),
-                  nn.BatchNorm2d(out_size, 0.8),
-                  nn.LeakyReLU(0.2, inplace=True) ]
+        layers = [ nn.ConvTranspose2d(in_size, out_size, 4, stride=2, padding=1, bias=False),
+                   nn.InstanceNorm2d(out_size, affine=True, track_running_stats=True),
+                   nn.ReLU(inplace=True)]

         if dropout:
-            model.append(nn.Dropout(dropout))
+            layers.append(nn.Dropout(dropout))

-        self.model = nn.Sequential(*model)
+        self.model = nn.Sequential(*layers)

     def forward(self, x, skip_input):
         x = self.model(x)
-        out = torch.cat((x, skip_input), 1)
-        return out
+        x = torch.cat((x, skip_input), 1)
+
+        return x

 class Generator(nn.Module):
     def __init__(self, latent_dim, img_shape):
@@ -57,20 +63,17 @@ def __init__(self, latent_dim, img_shape):
         self.down4 = UNetDown(256, 512)
         self.down5 = UNetDown(512, 512)
         self.down6 = UNetDown(512, 512)
-        self.down7 = UNetDown(512, 512)
-        self.down8 = UNetDown(512, 512, normalize=False)
+        self.down7 = UNetDown(512, 512, normalize=False)

         self.up1 = UNetUp(512, 512)
         self.up2 = UNetUp(1024, 512)
         self.up3 = UNetUp(1024, 512)
-        self.up4 = UNetUp(1024, 512)
-        self.up5 = UNetUp(1024, 256)
-        self.up6 = UNetUp(512, 128)
-        self.up7 = UNetUp(256, 64)
+        self.up4 = UNetUp(1024, 256)
+        self.up5 = UNetUp(512, 128)
+        self.up6 = UNetUp(256, 64)

-        final = [ nn.Upsample(scale_factor=2),
-                  nn.Conv2d(128, channels, 3, 1, 1),
+        final = [ nn.ConvTranspose2d(128, channels, 4, stride=2, padding=1),
                   nn.Tanh() ]
         self.final = nn.Sequential(*final)
@@ -84,16 +87,14 @@ def forward(self, x, z):
         d5 = self.down5(d4)
         d6 = self.down6(d5)
         d7 = self.down7(d6)
-        d8 = self.down8(d7)
-        u1 = self.up1(d8, d7)
-        u2 = self.up2(u1, d6)
-        u3 = self.up3(u2, d5)
-        u4 = self.up4(u3, d4)
-        u5 = self.up5(u4, d3)
-        u6 = self.up6(u5, d2)
-        u7 = self.up7(u6, d1)
+        u1 = self.up1(d7, d6)
+        u2 = self.up2(u1, d5)
+        u3 = self.up3(u2, d4)
+        u4 = self.up4(u3, d3)
+        u5 = self.up5(u4, d2)
+        u6 = self.up6(u5, d1)

-        return self.final(u7)
+        return self.final(u6)

 ##############################
 #        Encoder
@@ -106,16 +107,16 @@ def __init__(self, latent_dim):
         resnet18_model = resnet18(pretrained=True)

         # Extracts features at the last fully-connected
-        self.feature_extractor = nn.Sequential(*list(resnet18_model.children())[:-2])
-        self.avgpool = nn.AvgPool2d(kernel_size=8, stride=8)
+        self.feature_extractor = nn.Sequential(*list(resnet18_model.children())[:-3])
+        self.pooling = nn.AvgPool2d(kernel_size=8, stride=8, padding=0)

         # Output is mu and log(var) for reparameterization trick used in VAEs
-        self.fc_mu = nn.Linear(512, latent_dim)
-        self.fc_logvar = nn.Linear(512, latent_dim)
+        self.fc_mu = nn.Linear(256, latent_dim)
+        self.fc_logvar = nn.Linear(256, latent_dim)

     def forward(self, img):
         out = self.feature_extractor(img)
-        out = self.avgpool(out)
+        out = self.pooling(out)
         out = out.view(out.size(0), -1)
         mu = self.fc_mu(out)
         logvar = self.fc_logvar(out)
@@ -130,27 +131,25 @@ class Discriminator(nn.Module):
     def __init__(self, in_channels=3):
         super(Discriminator, self).__init__()

-        def discriminator_block(in_filters, out_filters, stride, normalize):
-            """Returns layers of each discriminator block"""
-            layers = [nn.Conv2d(in_filters, out_filters, 3, stride, 1)]
+        def downsample_block(in_filters, out_filters, normalize):
+            """Returns layers of each downsample block"""
+            layers = [nn.Conv2d(in_filters, out_filters, 4, 2, 1)]
             if normalize:
-                layers.append(nn.BatchNorm2d(out_filters))
+                layers.append(nn.InstanceNorm2d(out_filters, affine=True, track_running_stats=True))
             layers.append(nn.LeakyReLU(0.2, inplace=True))
             return layers

-        # Down sampling
-        self.conv = nn.Sequential(
-            *discriminator_block(in_channels, 64, 2, False),
-            *discriminator_block(64, 128, 2, True),
-            *discriminator_block(128, 256, 2, True),
+        self.d1 = nn.Sequential(
+            *downsample_block(in_channels, 64, False),
+            *downsample_block(64, 128, True),
+            nn.Conv2d(128, 1, 3, 1, 1)
         )
-
-        # Two output patches
-        self.out1 = nn.Conv2d(256, 1, 3, 1, 1)
-        self.out2 = nn.Sequential(
-            *discriminator_block(256, 512, 2, True),
-            nn.Conv2d(512, 1, 3, 1, 1)
+        self.d2 = nn.Sequential(
+            *downsample_block(in_channels, 64, False),
+            *downsample_block(64, 128, True),
+            *downsample_block(128, 256, True),
+            nn.Conv2d(256, 1, 3, 1, 1)
         )

     def forward(self, img):
-        x = self.conv(img)
-        return self.out1(x), self.out2(x)
+        return self.d1(img), self.d2(img)
diff --git a/implementations/cgan/cgan.py b/implementations/cgan/cgan.py
index 574e548f..645cbb98 100644
--- a/implementations/cgan/cgan.py
+++ b/implementations/cgan/cgan.py
@@ -31,15 +31,9 @@
 opt = parser.parse_args()
 print(opt)

-cuda = True if torch.cuda.is_available() else False
+img_shape = (opt.channels, opt.img_size, opt.img_size)

-def weights_init_normal(m):
-    classname = m.__class__.__name__
-    if classname.find('Conv') != -1:
-        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
-    elif classname.find('BatchNorm2d') != -1:
-        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
-        torch.nn.init.constant_(m.bias.data, 0.0)
+cuda = True if torch.cuda.is_available() else False

 class Generator(nn.Module):
     def __init__(self):
@@ -47,19 +41,19 @@ def __init__(self):

         self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)

+        def block(in_feat, out_feat, normalize=True):
+            layers = [nn.Linear(in_feat, out_feat)]
+            if normalize:
+                layers.append(nn.BatchNorm1d(out_feat, 0.8))
+            layers.append(nn.LeakyReLU(0.2, inplace=True))
+            return layers
+
         self.model = nn.Sequential(
-            nn.Linear(opt.latent_dim+opt.n_classes, 128),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(128, 256),
-            nn.BatchNorm1d(256),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(256, 512),
-            nn.BatchNorm1d(512),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(512, 1024),
-            nn.BatchNorm1d(1024),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(1024, opt.channels*opt.img_size**2),
+            *block(opt.latent_dim+opt.n_classes, 128, normalize=False),
+            *block(128, 256),
+            *block(256, 512),
+            *block(512, 1024),
+            nn.Linear(1024, int(np.prod(img_shape))),
             nn.Tanh()
         )
@@ -67,8 +61,7 @@ def forward(self, noise, labels):
         # Concatenate label embedding and image to produce input
         gen_input = torch.cat((self.label_emb(labels), noise), -1)
         img = self.model(gen_input)
-        # Reshape to image shape
-        img = img.view(img.size(0), opt.channels, opt.img_size, opt.img_size)
+        img = img.view(img.size(0), *img_shape)
         return img

 class Discriminator(nn.Module):
@@ -78,7 +71,7 @@ def __init__(self):
         self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)

         self.model = nn.Sequential(
-            nn.Linear(opt.n_classes + opt.img_size**2, 512),
+            nn.Linear(opt.n_classes + int(np.prod(img_shape)), 512),
             nn.LeakyReLU(0.2, inplace=True),
             nn.Linear(512, 512),
             nn.Dropout(0.4),
@@ -91,9 +84,8 @@ def __init__(self):

     def forward(self, img, labels):
         # Concatenate label embedding and image to produce input
-        d_input = torch.cat((img.view(img.size(0), -1), self.label_embedding(labels)), -1)
-        validity = self.model(d_input)
-
+        d_in = torch.cat((img.view(img.size(0), -1), self.label_embedding(labels)), -1)
+        validity = self.model(d_in)
         return validity

 # Loss functions
@@ -110,10 +102,6 @@ def forward(self, img, labels):
     adversarial_loss.cuda()
     auxiliary_loss.cuda()

-# Initialize weights
-generator.apply(weights_init_normal)
-discriminator.apply(weights_init_normal)
-
 # Configure data loader
 os.makedirs('../../data/mnist', exist_ok=True)
 dataloader = torch.utils.data.DataLoader(
diff --git a/implementations/cyclegan/models.py b/implementations/cyclegan/models.py
index 1e671647..73094648 100644
--- a/implementations/cyclegan/models.py
+++ b/implementations/cyclegan/models.py
@@ -2,6 +2,15 @@
 import torch.nn.functional as F
 import torch

+def weights_init_normal(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
+    elif classname.find('BatchNorm2d') != -1:
+        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
+        torch.nn.init.constant_(m.bias.data, 0.0)
+
+
 ##############################
 #           RESNET
 ##############################
diff --git a/implementations/cyclegan/utils.py b/implementations/cyclegan/utils.py
index 435597e6..3465cfc1 100644
--- a/implementations/cyclegan/utils.py
+++ b/implementations/cyclegan/utils.py
@@ -40,11 +40,3 @@ def __init__(self, n_epochs, offset, decay_start_epoch):

     def step(self, epoch):
         return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch)/(self.n_epochs - self.decay_start_epoch)
-
-def weights_init_normal(m):
-    classname = m.__class__.__name__
-    if classname.find('Conv') != -1:
-        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
-    elif classname.find('BatchNorm2d') != -1:
-        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
-        torch.nn.init.constant_(m.bias.data, 0.0)
diff --git a/implementations/dcgan/dcgan.py b/implementations/dcgan/dcgan.py
index 2a47efda..2e772d5e 100644
--- a/implementations/dcgan/dcgan.py
+++ b/implementations/dcgan/dcgan.py
@@ -61,8 +61,8 @@ def __init__(self):
             nn.Tanh()
         )

-    def forward(self, noise):
-        out = self.l1(noise)
+    def forward(self, z):
+        out = self.l1(z)
         out = out.view(out.shape[0], 128, self.init_size, self.init_size)
         img = self.conv_blocks(out)
         return img
diff --git a/implementations/dualgan/datasets.py b/implementations/dualgan/datasets.py
index d973acd0..f276c041 100644
--- a/implementations/dualgan/datasets.py
+++ b/implementations/dualgan/datasets.py
@@ -1,6 +1,7 @@
 import glob
 import random
 import os
+import numpy as np

 from torch.utils.data import Dataset
 from PIL import Image
@@ -10,18 +11,23 @@ class ImageDataset(Dataset):
     def __init__(self, root, transforms_=None, mode='train'):
         self.transform = transforms.Compose(transforms_)

-        self.files = sorted(glob.glob(os.path.join(root, '%s' % mode) + '/*.*'))
+        self.files = sorted(glob.glob(os.path.join(root, mode) + '/*.*'))

     def __getitem__(self, index):
-        img_pair = self.transform(Image.open(self.files[index % len(self.files)]))
-        _, h, w = img_pair.shape
-        half_w = int(w/2)
+        img = Image.open(self.files[index % len(self.files)])
+        w, h = img.size
+        img_A = img.crop((0, 0, w/2, h))
+        img_B = img.crop((w/2, 0, w, h))

-        item_A = img_pair[:, :, :half_w]
-        item_B = img_pair[:, :, half_w:]
+        if np.random.random() < 0.5:
+            img_A = Image.fromarray(np.array(img_A)[:, ::-1, :], 'RGB')
+            img_B = Image.fromarray(np.array(img_B)[:, ::-1, :], 'RGB')

-        return {'A': item_A, 'B': item_B}
+        img_A = self.transform(img_A)
+        img_B = self.transform(img_B)
+
+        return {'A': img_A, 'B': img_B}

     def __len__(self):
         return len(self.files)
diff --git a/implementations/dualgan/dualgan.py b/implementations/dualgan/dualgan.py
index d251dd62..9e8fad1e 100644
--- a/implementations/dualgan/dualgan.py
+++ b/implementations/dualgan/dualgan.py
@@ -28,16 +28,17 @@
 parser = argparse.ArgumentParser()
 parser.add_argument('--epoch', type=int, default=0, help='epoch to start training from')
 parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
-parser.add_argument('--batch_size', type=int, default=16, help='size of the batches')
+parser.add_argument('--batch_size', type=int, default=8, help='size of the batches')
 parser.add_argument('--dataset_name', type=str, default='edges2shoes', help='name of the dataset')
 parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
 parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
 parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient')
 parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
-parser.add_argument('--img_size', type=int, default=64, help='size of each image dimension')
+parser.add_argument('--img_size', type=int, default=128, help='size of each image dimension')
 parser.add_argument('--channels', type=int, default=3, help='number of image channels')
 parser.add_argument('--n_critic', type=int, default=5, help='number of training steps for discriminator per iter')
 parser.add_argument('--sample_interval', type=int, default=200, help='interval betwen image samples')
+parser.add_argument('--checkpoint_interval', type=int, default=-1, help='interval between model checkpoints')
 opt = parser.parse_args()
 print(opt)
@@ -46,26 +47,14 @@

 img_shape = (opt.channels, opt.img_size, opt.img_size)

-# Calculate output of image discriminator (PatchGAN)
-patch = int(opt.img_size / 2**4)
-patch = (1, patch, patch)
-
 cuda = True if torch.cuda.is_available() else False

-def weights_init_normal(m):
-    classname = m.__class__.__name__
-    if classname.find('Conv') != -1:
-        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
-    elif classname.find('BatchNorm2d') != -1:
-        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
-        torch.nn.init.constant_(m.bias.data, 0.0)
-
 # Loss function
 cycle_loss = torch.nn.L1Loss()

 # Loss weights
 lambda_adv = 1
-lambda_cycle = 100
+lambda_cycle = 10
 lambda_gp = 10

 # Initialize generator and discriminator
@@ -95,11 +84,13 @@ def weights_init_normal(m):
 D_B.apply(weights_init_normal)

 # Configure data loader
-transforms_ = [ transforms.Resize((opt.img_size, opt.img_size*2), Image.BICUBIC),
+transforms_ = [ transforms.Resize((opt.img_size, opt.img_size), Image.BICUBIC),
                 transforms.ToTensor(),
                 transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]
 dataloader = DataLoader(ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_),
                         batch_size=opt.batch_size, shuffle=True, num_workers=opt.n_cpu)
+val_dataloader = DataLoader(ImageDataset("../../data/%s" % opt.dataset_name, mode='val', transforms_=transforms_),
+                            batch_size=16, shuffle=True, num_workers=1)

 # Optimizers
 optimizer_G = torch.optim.Adam( itertools.chain(G_AB.parameters(), G_BA.parameters()),
@@ -116,23 +107,32 @@ def compute_gradient_penalty(D, real_samples, fake_samples):
     alpha = FloatTensor(np.random.random((real_samples.size(0), 1, 1, 1)))
     # Get random interpolation between real and fake samples
     interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
-    d_interpolates = D(interpolates)
-    fake = Variable(FloatTensor(np.ones(d_interpolates.shape)), requires_grad=False)
+    validity = D(interpolates)
+    fake = Variable(FloatTensor(np.ones(validity.shape)), requires_grad=False)
     # Get gradient w.r.t. interpolates
-    gradients = autograd.grad(outputs=d_interpolates, inputs=interpolates,
+    gradients = autograd.grad(outputs=validity, inputs=interpolates,
                               grad_outputs=fake, create_graph=True,
                               retain_graph=True, only_inputs=True)[0]
     gradients = gradients.view(gradients.size(0), -1)
     gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
     return gradient_penalty

+def sample_images(batches_done):
+    """Saves a generated sample from the test set"""
+    imgs = next(iter(val_dataloader))
+    real_A = Variable(imgs['A'].type(FloatTensor))
+    fake_B = G_AB(real_A)
+    AB = torch.cat((real_A.data, fake_B.data), -2)
+    real_B = Variable(imgs['B'].type(FloatTensor))
+    fake_A = G_BA(real_B)
+    BA = torch.cat((real_B.data, fake_A.data), -2)
+    img_sample = torch.cat((AB, BA), 0)
+    save_image(img_sample, 'images/%s/%s.png' % (opt.dataset_name, batches_done), nrow=8, normalize=True)
+
 # ----------
 #  Training
 # ----------

-# Keeps the 5 latest image samples
-image_samples = []
-
 batches_done = 0
 prev_time = time.time()
 for epoch in range(opt.n_epochs):
@@ -150,8 +150,8 @@ def compute_gradient_penalty(D, real_samples, fake_samples):
         optimizer_D_B.zero_grad()

         # Generate a batch of images
-        fake_A = G_BA(imgs_B)
-        fake_B = G_AB(imgs_A)
+        fake_A = G_BA(imgs_B).detach()
+        fake_B = G_AB(imgs_A).detach()

         #----------
         # Domain A
@@ -171,8 +171,10 @@ def compute_gradient_penalty(D, real_samples, fake_samples):
         # Adversarial loss
         D_B_loss = -torch.mean(D_B(imgs_B)) + torch.mean(D_B(fake_B)) + lambda_gp * gp_B

-        D_A_loss.backward()
-        D_B_loss.backward()
+        # Total loss
+        D_loss = D_A_loss + D_B_loss
+
+        D_loss.backward()
         optimizer_D_A.step()
         optimizer_D_B.step()
@@ -193,9 +195,9 @@ def compute_gradient_penalty(D, real_samples, fake_samples):
             recov_B = G_AB(fake_A)

             # Adversarial loss
-            G_adv = (-torch.mean(D_A(fake_A)) - torch.mean(D_B(fake_B)))
+            G_adv = -torch.mean(D_A(fake_A)) - torch.mean(D_B(fake_B))
             # Cycle loss
-            G_cycle = cycle_loss(recov_B, imgs_B) + cycle_loss(recov_A, imgs_A)
+            G_cycle = cycle_loss(recov_A, imgs_A) + cycle_loss(recov_B, imgs_B)
             # Total loss
             G_loss = lambda_adv * G_adv + lambda_cycle * G_cycle
@@ -207,34 +209,22 @@ def compute_gradient_penalty(D, real_samples, fake_samples):
         #--------------

         # Determine approximate time left
-        batches_done = epoch * len(dataloader) + i
         batches_left = opt.n_epochs * len(dataloader) - batches_done
-        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
+        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time) / opt.n_critic)
         prev_time = time.time()

-        sys.stdout.write("\r[Epoch %d/%d] [Batch %d/%d] [D_A loss: %f] [D_B loss: %f] [G loss: %f, cycle: %f] ETA: %s" % (epoch, opt.n_epochs,
+        sys.stdout.write("\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, cycle: %f] ETA: %s" % (epoch, opt.n_epochs,
                                                         i, len(dataloader),
-                                                        D_A_loss.item(), D_B_loss.item(),
-                                                        G_adv.data.item(), G_cycle.item(),
-                                                        time_left))
-
-
-
+                                                        D_loss.item(), G_adv.data.item(),
+                                                        G_cycle.item(), time_left))

-        # Create image sample
-        ABA = torch.cat((imgs_A[:1].data, fake_B[:1].data, recov_A[:1].data), -2)
-        BAB = torch.cat((imgs_B[:1].data, fake_A[:1].data, recov_B[:1].data), -2)
-        sample = torch.cat((ABA, BAB), -2)
-        image_samples.append(sample)
-        if len(image_samples) > 5:
-            image_samples.pop(0)
-        # Check sample interval => save sample if there
-        if batches_done % opt.sample_interval == 0:
-            save_image(torch.cat(image_samples, -1), 'images/%s/%d.png' % (opt.dataset_name, batches_done), nrow=2, normalize=True)
+        # Check sample interval => save sample if there
+        if batches_done % opt.sample_interval == 0:
+            sample_images(batches_done)

-        batches_done += opt.n_critic
+        batches_done += 1

     if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
         # Save model checkpoints
diff --git a/implementations/dualgan/models.py b/implementations/dualgan/models.py
index 4a42eacd..d83ae75b 100644
--- a/implementations/dualgan/models.py
+++ b/implementations/dualgan/models.py
@@ -4,21 +4,28 @@
 from torchvision.models import vgg19
 import math

+def weights_init_normal(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
+    elif classname.find('BatchNorm2d') != -1:
+        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
+        torch.nn.init.constant_(m.bias.data, 0.0)
+
 ##############################
-#      U-NET Generator
+#           U-NET
 ##############################

 class UNetDown(nn.Module):
     def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
         super(UNetDown, self).__init__()
-        model = [ nn.Conv2d(in_size, out_size, 3, stride=2, padding=1)]
+        layers = [nn.Conv2d(in_size, out_size, 4, stride=2, padding=1, bias=False)]
         if normalize:
-            model.append(nn.BatchNorm2d(out_size, 0.8))
-        model.append(nn.LeakyReLU(0.2, inplace=True))
+            layers.append(nn.InstanceNorm2d(out_size, affine=True))
+        layers.append(nn.LeakyReLU(0.2, inplace=True))
         if dropout:
-            model.append(nn.Dropout(dropout))
-
-        self.model = nn.Sequential(*model)
+            layers.append(nn.Dropout(dropout))
+        self.model = nn.Sequential(*layers)

     def forward(self, x):
         return self.model(x)
@@ -26,59 +33,61 @@ def forward(self, x):
 class UNetUp(nn.Module):
     def __init__(self, in_size, out_size, dropout=0.0):
         super(UNetUp, self).__init__()
-        model = [ nn.Upsample(scale_factor=2),
-                  nn.Conv2d(in_size, out_size, 3, stride=1, padding=1),
-                  nn.BatchNorm2d(out_size, 0.8),
-                  nn.LeakyReLU(0.2, inplace=True)]
+        layers = [ nn.ConvTranspose2d(in_size, out_size, 4, stride=2, padding=1, bias=False),
+                   nn.InstanceNorm2d(out_size, affine=True),
+                   nn.ReLU(inplace=True)]
         if dropout:
-            model += [nn.Dropout(dropout)]
+            layers.append(nn.Dropout(dropout))

-        self.model = nn.Sequential(*model)
+        self.model = nn.Sequential(*layers)

     def forward(self, x, skip_input):
         x = self.model(x)
-        out = torch.cat((x, skip_input), 1)
-        return out
+        x = torch.cat((x, skip_input), 1)
+
+        return x

 class Generator(nn.Module):
-    def __init__(self, in_channels=3, out_channels=3):
+    def __init__(self, channels=3):
         super(Generator, self).__init__()

-        self.down1 = UNetDown(in_channels, 64, normalize=False)
+        self.down1 = UNetDown(channels, 64, normalize=False)
         self.down2 = UNetDown(64, 128)
-        self.down3 = UNetDown(128, 256, dropout=0.5)
+        self.down3 = UNetDown(128, 256)
         self.down4 = UNetDown(256, 512, dropout=0.5)
         self.down5 = UNetDown(512, 512, dropout=0.5)
-        self.down6 = UNetDown(512, 512, dropout=0.5, normalize=False)
+        self.down6 = UNetDown(512, 512, dropout=0.5)
+        self.down7 = UNetDown(512, 512, dropout=0.5, normalize=False)

         self.up1 = UNetUp(512, 512, dropout=0.5)
         self.up2 = UNetUp(1024, 512, dropout=0.5)
-        self.up3 = UNetUp(1024, 256, dropout=0.5)
-        self.up4 = UNetUp(512, 128)
-        self.up5 = UNetUp(256, 64)
-
-
-        final = [ nn.Upsample(scale_factor=2),
-                  nn.Conv2d(128, out_channels, 3, 1, 1),
-                  nn.Tanh() ]
-        self.final = nn.Sequential(*final)
+        self.up3 = UNetUp(1024, 512, dropout=0.5)
+        self.up4 = UNetUp(1024, 256)
+        self.up5 = UNetUp(512, 128)
+        self.up6 = UNetUp(256, 64)
+
+        self.final = nn.Sequential(
+            nn.ConvTranspose2d(128, channels, 4, stride=2, padding=1),
+            nn.Tanh()
+        )

     def forward(self, x):
-        # U-Net generator with skip connections from encoder to decoder
+        # Propagate the image through the U-Net, with skip connections
         d1 = self.down1(x)
         d2 = self.down2(d1)
         d3 = self.down3(d2)
         d4 = self.down4(d3)
         d5 = self.down5(d4)
         d6 = self.down6(d5)
-        u1 = self.up1(d6, d5)
-        u2 = self.up2(u1, d4)
-        u3 = self.up3(u2, d3)
-        u4 = self.up4(u3, d2)
-        u5 = self.up5(u4, d1)
-
-        return self.final(u5)
+        d7 = self.down7(d6)
+        u1 = self.up1(d7, d6)
+        u2 = self.up2(u1, d5)
+        u3 = self.up3(u2, d4)
+        u4 = self.up4(u3, d3)
+        u5 = self.up5(u4, d2)
+        u6 = self.up6(u5, d1)
+        return self.final(u6)

 ##############################
 #        Discriminator
@@ -88,20 +97,20 @@ class Discriminator(nn.Module):
     def __init__(self, in_channels=3):
         super(Discriminator, self).__init__()

-        def discrimintor_block(in_features, out_features, normalization=True):
+        def discrimintor_block(in_features, out_features, normalize=True):
             """Discriminator block"""
-            layers = [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1)]
-            if normalization:
+            layers = [nn.Conv2d(in_features, out_features, 4, stride=2, padding=1)]
+            if normalize:
                 layers.append(nn.BatchNorm2d(out_features, 0.8))
             layers.append(nn.LeakyReLU(0.2, inplace=True))
             return layers

         self.model = nn.Sequential(
-            *discrimintor_block(in_channels, 64, False),
+            *discrimintor_block(in_channels, 64, normalize=False),
             *discrimintor_block(64, 128),
             *discrimintor_block(128, 256),
-            *discrimintor_block(256, 512),
-            nn.Conv2d(512, 1, 3, 1, 1)
+            nn.ZeroPad2d((1, 0, 1, 0)),
+            nn.Conv2d(256, 1, kernel_size=4)
         )

     def forward(self, img):
diff --git a/implementations/gan/gan.py b/implementations/gan/gan.py
index 77ba7a23..2efdbb2d 100644
--- a/implementations/gan/gan.py
+++ b/implementations/gan/gan.py
@@ -30,39 +30,31 @@
 opt = parser.parse_args()
 print(opt)

-cuda = True if torch.cuda.is_available() else False
+img_shape = (opt.channels, opt.img_size, opt.img_size)

-def weights_init_normal(m):
-    classname = m.__class__.__name__
-    if classname.find('Linear') != -1:
-        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
-    elif classname.find('BatchNorm') != -1:
-        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
-        torch.nn.init.constant_(m.bias.data, 0.0)
+cuda = True if torch.cuda.is_available() else False

 class Generator(nn.Module):
     def __init__(self):
         super(Generator, self).__init__()

+        def block(in_feat, out_feat):
+            layers = [ nn.Linear(in_feat, out_feat),
+                       nn.LeakyReLU(0.2, inplace=True)]
+            return layers
+
         self.model = nn.Sequential(
-            nn.Linear(opt.latent_dim, 128),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(128, 256),
-            nn.BatchNorm1d(256),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(256, 512),
-            nn.BatchNorm1d(512),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(512, 1024),
-            nn.BatchNorm1d(1024),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(1024, opt.img_size**2),
+            *block(opt.latent_dim, 128),
+            *block(128, 256),
+            *block(256, 512),
+            *block(512, 1024),
+            nn.Linear(1024, int(np.prod(img_shape))),
             nn.Tanh()
         )

-    def forward(self, noise):
-        img = self.model(noise)
-        img = img.view(img.shape[0], opt.channels, opt.img_size, opt.img_size)
+    def forward(self, z):
+        img = self.model(z)
+        img = img.view(img.size(0), *img_shape)
         return img

 class Discriminator(nn.Module):
@@ -70,7 +62,7 @@ def __init__(self):
         super(Discriminator, self).__init__()

         self.model = nn.Sequential(
-            nn.Linear(opt.img_size**2, 512),
+            nn.Linear(int(np.prod(img_shape)), 512),
             nn.LeakyReLU(0.2, inplace=True),
             nn.Linear(512, 256),
             nn.LeakyReLU(0.2, inplace=True),
@@ -79,7 +71,7 @@ def __init__(self):
         )

     def forward(self, img):
-        img_flat = img.view(img.shape[0], -1)
+        img_flat = img.view(img.size(0), -1)
         validity = self.model(img_flat)

         return validity
@@ -96,10 +88,6 @@ def forward(self, img):
     discriminator.cuda()
     adversarial_loss.cuda()

-# Initialize weights
-generator.apply(weights_init_normal)
-discriminator.apply(weights_init_normal)
-
 # Configure data loader
 os.makedirs('../../data/mnist', exist_ok=True)
 dataloader = torch.utils.data.DataLoader(
@@ -124,8 +112,8 @@ def forward(self, img):
     for i, (imgs, _) in enumerate(dataloader):

         # Adversarial ground truths
-        valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
-        fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)
+        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
+        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

         # Configure input
         real_imgs = Variable(imgs.type(Tensor))
diff --git a/implementations/lsgan/lsgan.py b/implementations/lsgan/lsgan.py
index 3a73299a..7f882aab 100644
--- a/implementations/lsgan/lsgan.py
+++ b/implementations/lsgan/lsgan.py
@@ -48,7 +48,6 @@ def __init__(self):
         self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128*self.init_size**2))

         self.conv_blocks = nn.Sequential(
-            nn.BatchNorm2d(128),
             nn.Upsample(scale_factor=2),
             nn.Conv2d(128, 128, 3, stride=1, padding=1),
             nn.BatchNorm2d(128, 0.8),
@@ -61,8 +60,8 @@ def __init__(self):
             nn.Tanh()
         )

-    def forward(self, noise):
-        out = self.l1(noise)
+    def forward(self, z):
+        out = self.l1(z)
         out = out.view(out.shape[0], 128, self.init_size, self.init_size)
         img = self.conv_blocks(out)
         return img
diff --git a/implementations/pix2pix/models.py b/implementations/pix2pix/models.py
index 76551a43..a4a98c54 100644
--- a/implementations/pix2pix/models.py
+++ b/implementations/pix2pix/models.py
@@ -2,6 +2,14 @@
 import torch.nn.functional as F
 import torch

+def weights_init_normal(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
+    elif classname.find('BatchNorm2d') != -1:
+        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
+        torch.nn.init.constant_(m.bias.data, 0.0)
+
 ##############################
 #           U-NET
 ##############################
diff --git a/implementations/pix2pix/pix2pix.py b/implementations/pix2pix/pix2pix.py
index 23102eb6..a8ed7c6a 100644
--- a/implementations/pix2pix/pix2pix.py
+++ b/implementations/pix2pix/pix2pix.py
@@ -42,14 +42,6 @@
 os.makedirs('images/%s' % opt.dataset_name, exist_ok=True)
 os.makedirs('saved_models/%s' % opt.dataset_name, exist_ok=True)

-def weights_init_normal(m):
-    classname = m.__class__.__name__
-    if classname.find('Conv') != -1:
-        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
-    elif classname.find('BatchNorm2d') != -1:
-        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
-        torch.nn.init.constant_(m.bias.data, 0.0)
-
 cuda = True if torch.cuda.is_available() else False

 # Loss functions
diff --git a/implementations/softmax_gan/softmax_gan.py b/implementations/softmax_gan/softmax_gan.py
index c1206199..b2c26e64 100644
--- a/implementations/softmax_gan/softmax_gan.py
+++ b/implementations/softmax_gan/softmax_gan.py
@@ -30,39 +30,33 @@
 opt = parser.parse_args()
 print(opt)

-cuda = True if torch.cuda.is_available() else False
+img_shape = (opt.channels, opt.img_size, opt.img_size)

-def weights_init_normal(m):
-    classname = m.__class__.__name__
-    if classname.find('Linear') != -1:
-        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
-    elif classname.find('BatchNorm') != -1:
-        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
-        torch.nn.init.constant_(m.bias.data, 0.0)
+cuda = True if torch.cuda.is_available() else False

 class Generator(nn.Module):
     def __init__(self):
         super(Generator, self).__init__()

+        def block(in_feat, out_feat, normalize=True):
+            layers = [nn.Linear(in_feat, out_feat)]
+            if normalize:
+                layers.append(nn.BatchNorm1d(out_feat, 0.8))
+            layers.append(nn.LeakyReLU(0.2, inplace=True))
+            return layers
+
         self.model = nn.Sequential(
-            nn.Linear(opt.latent_dim, 128),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(128, 256),
-            nn.BatchNorm1d(256),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(256, 512),
-            nn.BatchNorm1d(512),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(512, 1024),
-            nn.BatchNorm1d(1024),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(1024, opt.img_size**2),
+            *block(opt.latent_dim, 128, normalize=False),
+            *block(128, 256),
+            *block(256, 512),
+            *block(512, 1024),
+            nn.Linear(1024, int(np.prod(img_shape))),
             nn.Tanh()
         )

-    def forward(self, noise):
-        img = self.model(noise)
-        img = img.view(img.shape[0], opt.channels, opt.img_size, opt.img_size)
+    def forward(self, z):
+        img = self.model(z)
+        img = img.view(img.shape[0], *img_shape)
         return img

 class Discriminator(nn.Module):
@@ -95,10 +89,6 @@ def forward(self, img):
     discriminator.cuda()
     adversarial_loss.cuda()

-# Initialize weights
-generator.apply(weights_init_normal)
-discriminator.apply(weights_init_normal)
-
 # Configure data loader
 os.makedirs('../../data/mnist', exist_ok=True)
 dataloader = torch.utils.data.DataLoader(
diff --git a/implementations/srgan/models.py b/implementations/srgan/models.py
index 5f003cda..89d7194c 100644
--- a/implementations/srgan/models.py
+++ b/implementations/srgan/models.py
@@ -4,6 +4,14 @@
 from torchvision.models import vgg19
 import math

+def weights_init_normal(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
+    elif classname.find('BatchNorm') != -1:
+        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
+        torch.nn.init.constant_(m.bias.data, 0.0)
+
 class FeatureExtractor(nn.Module):
     def __init__(self):
         super(FeatureExtractor, self).__init__()
diff --git a/implementations/srgan/srgan.py b/implementations/srgan/srgan.py
index e52f1e50..f94afec3 100644
--- a/implementations/srgan/srgan.py
+++ b/implementations/srgan/srgan.py
@@ -51,14 +51,6 @@

 cuda = True if torch.cuda.is_available() else False

-def weights_init_normal(m):
-    classname = m.__class__.__name__
-    if classname.find('Conv') != -1:
-        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
-    elif classname.find('BatchNorm') != -1:
-        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
-        torch.nn.init.constant_(m.bias.data, 0.0)
-
 # Calculate output of image discriminator (PatchGAN)
 patch_h, patch_w = int(opt.hr_height / 2**4), int(opt.hr_width / 2**4)
 patch = (opt.batch_size, 1, patch_h, patch_w)
diff --git a/implementations/stargan/models.py b/implementations/stargan/models.py
index f1e10405..a96e68f6 100644
--- a/implementations/stargan/models.py
+++ b/implementations/stargan/models.py
@@ -2,6 +2,11 @@
 import torch.nn.functional as F
 import torch

+def weights_init_normal(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
+
 ##############################
 #           RESNET
 ##############################
diff --git a/implementations/stargan/stargan.py b/implementations/stargan/stargan.py
index e8d6fa81..ef0bdb71 100644
--- a/implementations/stargan/stargan.py
+++ b/implementations/stargan/stargan.py
@@ -59,11 +59,6 @@
 opt = parser.parse_args()
 print(opt)

-def weights_init_normal(m):
-    classname = m.__class__.__name__
-    if classname.find('Conv') != -1:
-        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
-
 c_dim = len(opt.selected_attrs)
 img_shape = (opt.channels, opt.img_height, opt.img_width)

diff --git a/implementations/wgan/wgan.py b/implementations/wgan/wgan.py
index 2fc6f692..51a5d0b4 100644
--- a/implementations/wgan/wgan.py
+++ b/implementations/wgan/wgan.py
@@ -31,39 +31,33 @@
 opt = parser.parse_args()
 print(opt)

-cuda = True if torch.cuda.is_available() else False
+img_shape = (opt.channels, opt.img_size, opt.img_size)

-def weights_init_normal(m):
-    classname = m.__class__.__name__
-    if classname.find('Conv') != -1:
-        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
-    elif classname.find('BatchNorm') != -1:
-        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
-        torch.nn.init.constant_(m.bias.data, 0.0)
+cuda = True if torch.cuda.is_available() else False

 class Generator(nn.Module):
     def __init__(self):
         super(Generator, self).__init__()

+        def block(in_feat, out_feat, normalize=True):
+            layers = [nn.Linear(in_feat, out_feat)]
+            if normalize:
+                layers.append(nn.BatchNorm1d(out_feat, 0.8))
+            layers.append(nn.LeakyReLU(0.2, inplace=True))
+            return layers
+
         self.model = nn.Sequential(
-            nn.Linear(opt.latent_dim, 128),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(128, 256),
-            nn.BatchNorm1d(256),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(256, 512),
-            nn.BatchNorm1d(512),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(512, 1024),
-            nn.BatchNorm1d(1024),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(1024, opt.img_size**2),
+            *block(opt.latent_dim, 128, normalize=False),
+            *block(128, 256),
+            *block(256, 512),
+            *block(512, 1024),
+            nn.Linear(1024, int(np.prod(img_shape))),
             nn.Tanh()
         )

-    def forward(self, noise):
-        img = self.model(noise)
-        img = img.view(img.shape[0], opt.channels, opt.img_size, opt.img_size)
+    def forward(self, z):
+        img = self.model(z)
+        img = img.view(img.shape[0], *img_shape)
         return img

 class Discriminator(nn.Module):
@@ -71,7 +65,7 @@ def __init__(self):
         super(Discriminator, self).__init__()

         self.model = nn.Sequential(
-            nn.Linear(opt.img_size**2, 512),
+            nn.Linear(int(np.prod(img_shape)), 512),
             nn.LeakyReLU(0.2, inplace=True),
             nn.Linear(512, 256),
             nn.LeakyReLU(0.2, inplace=True),
@@ -81,7 +75,6 @@ def __init__(self):
     def forward(self, img):
         img_flat = img.view(img.shape[0], -1)
         validity = self.model(img_flat)
-
         return validity

 # Initialize generator and discriminator
@@ -92,10 +85,6 @@ def forward(self, img):
     generator.cuda()
     discriminator.cuda()

-# Initialize weights
-generator.apply(weights_init_normal)
-discriminator.apply(weights_init_normal)
-
 # Configure data loader
 os.makedirs('../../data/mnist', exist_ok=True)
 dataloader = torch.utils.data.DataLoader(
diff --git a/implementations/wgan_gp/wgan_gp.py b/implementations/wgan_gp/wgan_gp.py
index 76177de1..beae53ce 100644
--- a/implementations/wgan_gp/wgan_gp.py
+++ b/implementations/wgan_gp/wgan_gp.py
@@ -34,39 +34,33 @@
 opt = parser.parse_args()
 print(opt)

-cuda = True if torch.cuda.is_available() else False
+img_shape = (opt.channels, opt.img_size, opt.img_size)

-def weights_init_normal(m):
-    classname = m.__class__.__name__
-    if classname.find('Conv') != -1:
-        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
-    elif classname.find('BatchNorm') != -1:
-        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
-        torch.nn.init.constant_(m.bias.data, 0.0)
+cuda = True if torch.cuda.is_available() else False

 class Generator(nn.Module):
     def __init__(self):
         super(Generator, self).__init__()

+        def block(in_feat, out_feat, normalize=True):
+            layers = [nn.Linear(in_feat, out_feat)]
+            if normalize:
+                layers.append(nn.BatchNorm1d(out_feat, 0.8))
+            layers.append(nn.LeakyReLU(0.2, inplace=True))
+            return layers
+
         self.model = nn.Sequential(
-            nn.Linear(opt.latent_dim, 128),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(128, 256),
-            nn.BatchNorm1d(256),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(256, 512),
-            nn.BatchNorm1d(512),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(512, 1024),
-            nn.BatchNorm1d(1024),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Linear(1024, opt.img_size**2),
+            *block(opt.latent_dim, 128, normalize=False),
+            *block(128, 256),
+            *block(256, 512),
+            *block(512, 1024),
+            nn.Linear(1024, int(np.prod(img_shape))),
             nn.Tanh()
         )

-    def forward(self, noise):
-        img = self.model(noise)
-        img = img.view(img.shape[0], opt.channels, opt.img_size, opt.img_size)
+    def forward(self, z):
+        img = self.model(z)
+        img = img.view(img.shape[0], *img_shape)
         return img

 class Discriminator(nn.Module):
@@ -74,7 +68,7 @@ def __init__(self):
         super(Discriminator, self).__init__()

         self.model = nn.Sequential(
-            nn.Linear(opt.img_size**2, 512),
+            nn.Linear(int(np.prod(img_shape)), 512),
             nn.LeakyReLU(0.2, inplace=True),
             nn.Linear(512, 256),
             nn.LeakyReLU(0.2, inplace=True),
@@ -84,7 +78,6 @@ def __init__(self):
     def forward(self, img):
         img_flat = img.view(img.shape[0], -1)
         validity = self.model(img_flat)
-
         return validity

 # Loss weight for gradient penalty
@@ -98,10 +91,6 @@ def forward(self, img):
     generator.cuda()
     discriminator.cuda()

-# Initialize weights
-generator.apply(weights_init_normal)
-discriminator.apply(weights_init_normal)
-
 # Configure data loader
 os.makedirs('../../data/mnist', exist_ok=True)
 dataloader = torch.utils.data.DataLoader(
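
Note on the recurring refactor above: the MNIST scripts (gan, bgan, cgan, softmax_gan, wgan, wgan_gp) all replace their hand-written Linear/BatchNorm1d/LeakyReLU stacks with a small block() helper and size the final layer from img_shape instead of opt.img_size**2. A minimal runnable sketch of that pattern, assuming illustrative MNIST-style values (latent_dim=100, 1x28x28 images) that are not taken from the diff:

    import numpy as np
    import torch
    import torch.nn as nn

    latent_dim = 100
    img_shape = (1, 28, 28)

    def block(in_feat, out_feat, normalize=True):
        # Linear -> (optional BatchNorm1d) -> LeakyReLU, as in the refactored scripts
        layers = [nn.Linear(in_feat, out_feat)]
        if normalize:
            layers.append(nn.BatchNorm1d(out_feat, 0.8))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return layers

    generator = nn.Sequential(
        *block(latent_dim, 128, normalize=False),
        *block(128, 256),
        *block(256, 512),
        *block(512, 1024),
        nn.Linear(1024, int(np.prod(img_shape))),  # flat output sized from img_shape
        nn.Tanh(),
    )

    z = torch.randn(16, latent_dim)
    imgs = generator(z).view(16, *img_shape)  # -> (16, 1, 28, 28)

One caveat worth flagging: BatchNorm1d's second positional argument is eps, so nn.BatchNorm1d(out_feat, 0.8) sets eps=0.8; if 0.8 was intended as the momentum, it would need to be passed by keyword.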
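The weights_init_normal changes follow one rule: the DCGAN-style init is deleted from the MLP scripts entirely (falling back to PyTorch's default Linear init) and relocated next to the model definitions for the convolutional models (bicyclegan, cyclegan, dualgan, pix2pix, srgan, stargan). A sketch of how the relocated helper is applied; the toy network here is illustrative, not from the diff:

    import torch
    import torch.nn as nn

    def weights_init_normal(m):
        # N(0, 0.02) for conv weights, N(1, 0.02) / 0 for BatchNorm2d scale / bias
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find('BatchNorm2d') != -1:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
            torch.nn.init.constant_(m.bias.data, 0.0)

    net = nn.Sequential(
        nn.Conv2d(3, 64, 4, stride=2, padding=1),
        nn.BatchNorm2d(64),
        nn.LeakyReLU(0.2, inplace=True),
    )
    net.apply(weights_init_normal)  # Module.apply() visits every submodule recursively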
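Likewise, the bicyclegan change that moves valid1/valid2/fake1/fake2 inside the batch loop matters because the PatchGAN target must match the discriminator output for the current batch, and the final batch of an epoch can be smaller than opt.batch_size. A sketch of the sizing logic under the 128x128 default above, so patch1 is 1x32x32 after two stride-2 downsamples:

    import torch

    img_size = 128
    patch1 = (1, img_size // 2**2, img_size // 2**2)

    for batch_size in (8, 5):  # the last batch of an epoch may be short
        real_A = torch.randn(batch_size, 3, img_size, img_size)
        valid1 = torch.ones((real_A.size(0), *patch1))   # target sized per batch
        fake1 = torch.zeros((real_A.size(0), *patch1))
        assert valid1.shape == (batch_size, *patch1)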