From cd3b55e0664ac8cbef03c1b71004873091adc91a Mon Sep 17 00:00:00 2001
From: Andrew Zhao
Date: Mon, 14 Dec 2020 11:18:53 +0800
Subject: [PATCH] refactored data modules

---
 data/unlabelled.py | 34 ++++++++++++++++++++++++++++++++++
 dcgan/data.py      |  4 +++-
 dcgan/models.py    |  2 +-
 dcgan/train.py     |  8 ++++----
 sagan/train.py     |  6 +++---
 vae/main.py        |  7 ++++---
 6 files changed, 49 insertions(+), 12 deletions(-)
 create mode 100644 data/unlabelled.py

diff --git a/data/unlabelled.py b/data/unlabelled.py
new file mode 100644
index 0000000..871adf8
--- /dev/null
+++ b/data/unlabelled.py
@@ -0,0 +1,34 @@
+from glob import glob
+import os
+from PIL import Image
+from torchvision import transforms
+from torchvision.datasets import CelebA
+import torch
+from torch.utils.data import Dataset, DataLoader
+
+
+class celebA(Dataset):
+    def __init__(self, data_path, img_ext, crop_size, img_size):
+        self.celebs = glob(data_path + '/*' + img_ext)
+        Range = transforms.Lambda(lambda X: 2 * X - 1.)
+        self.transforms = transforms.Compose([
+            transforms.CenterCrop(crop_size),
+            transforms.Resize(img_size),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            Range
+        ])
+
+    def __len__(self):
+        return len(self.celebs)
+
+    def __getitem__(self, index):
+        img = Image.open(self.celebs[index]).convert('RGB')
+        return self.transforms(img)
+
+
+def get_celeba_loaders(data_path, img_ext, crop_size, img_size, batch_size, download):
+    dataset = celebA(data_path, img_ext, crop_size, img_size)
+    if download:
+        dataset = CelebA('celebA', transform=dataset.transforms, download=True)
+    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=28)
diff --git a/dcgan/data.py b/dcgan/data.py
index 049956e..04bb063 100644
--- a/dcgan/data.py
+++ b/dcgan/data.py
@@ -10,11 +10,13 @@ class celebA(Dataset):
 
     def __init__(self, data_path, img_ext, crop_size, img_size):
         self.celebs = glob(data_path + '/*' + img_ext)
+        Range = transforms.Lambda(lambda X: 2 * X - 1.)
         self.transforms = transforms.Compose([
             transforms.CenterCrop(crop_size),
             transforms.Resize(img_size),
             transforms.RandomHorizontalFlip(),
-            transforms.ToTensor()
+            transforms.ToTensor(),
+            Range
         ])
 
     def __len__(self):
diff --git a/dcgan/models.py b/dcgan/models.py
index 67780ac..5090fd5 100644
--- a/dcgan/models.py
+++ b/dcgan/models.py
@@ -31,7 +31,7 @@ def __init__(self, h_dim, z_dim, img_channels, img_size):
             conv_bn_relu(h_dim*4, h_dim*2, 4, 2, 'relu', 'down'),
             conv_bn_relu(h_dim*2, h_dim, 4, 2, 'relu', 'down'),
             nn.ConvTranspose2d(h_dim, img_channels, 4, 2, 1),
-            nn.Sigmoid()
+            nn.Tanh()
         )
 
         initialize_weights(self)
diff --git a/dcgan/train.py b/dcgan/train.py
index c45916b..f31c4c8 100644
--- a/dcgan/train.py
+++ b/dcgan/train.py
@@ -5,7 +5,7 @@
 import os
 from torch.utils.tensorboard import SummaryWriter
 from datetime import datetime
-from dcgan.data import get_loaders
+from data.unlabelled import get_celeba_loaders
 from dcgan.models import Generator, Discriminator
 
 
@@ -40,8 +40,8 @@ def train():
         os.mkdir(opt.log_dir)
     if not os.path.isdir(opt.sample_dir):
         os.mkdir(opt.sample_dir)
-    loader = get_loaders(opt.data_path, opt.img_ext, opt.crop_size,
-                         opt.img_size, opt.batch_size, opt.download)
+    loader = get_celeba_loaders(opt.data_path, opt.img_ext, opt.crop_size,
+                                opt.img_size, opt.batch_size, opt.download)
     G = Generator(opt.h_dim, opt.z_dim, opt.img_channels, opt.img_size)
     D = Discriminator(opt.img_channels, opt.h_dim, opt.img_size)
     optimizer_G = torch.optim.Adam(G.parameters(), lr=opt.lr, betas=opt.betas)
@@ -77,7 +77,7 @@ def train():
         with torch.no_grad():
             samples = G(fixed_z)
 
-        # samples = ((samples + 1) / 2).view(-1, opt.img_channels, opt.img_size, opt.img_size)
+        samples = ((samples + 1) / 2).view(-1, opt.img_channels, opt.img_size, opt.img_size)
         writer.add_image('Generated Images', torchvision.utils.make_grid(samples), global_step=epoch)
         writer.add_scalars("Train Losses", {
             "Discriminator Loss": sum(d_losses) / len(d_losses),
diff --git a/sagan/train.py b/sagan/train.py
index c191a1d..de3af21 100644
--- a/sagan/train.py
+++ b/sagan/train.py
@@ -8,7 +8,7 @@
 from datetime import datetime
 from pathlib import Path
 
-from dcgan.data import get_loaders
+from data.unlabelled import get_celeba_loaders
 from networks.utils import load_weights
 from sagan.model import Generator, Discriminator
 from sagan.loss import Hinge_loss, Wasserstein_GP_Loss
@@ -71,8 +71,8 @@ def train():
     optimizer_G = torch.optim.Adam(G.parameters(), lr=opt.lr_G, betas=opt.betas)
     optimizer_D = torch.optim.Adam(D.parameters(), lr=opt.lr_D, betas=opt.betas)
 
-    loader = get_loaders(opt.data_path, opt.img_ext, opt.crop_size,
-                         opt.img_size, opt.batch_size, opt.download)
+    loader = get_celeba_loaders(opt.data_path, opt.img_ext, opt.crop_size,
+                                opt.img_size, opt.batch_size, opt.download)
 
     # sample fixed z to see progress through training
     fixed_z = torch.randn(opt.sample_size, opt.z_dim).to(device)
diff --git a/vae/main.py b/vae/main.py
index ee9848b..7406512 100644
--- a/vae/main.py
+++ b/vae/main.py
@@ -8,7 +8,7 @@
 from datetime import datetime
 
 from networks.utils import initialize_modules
-from dcgan.data import get_loaders
+from data.unlabelled import get_celeba_loaders
 from vae.model import VAE
 from vae.loss import VAELoss
 
@@ -57,8 +57,8 @@
     Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True)
     Path(args.log_dir).mkdir(parents=True, exist_ok=True)
 
-    loader = get_loaders(args.data_path, args.img_ext, args.crop_size,
-                         args.img_size, args.batch_size, args.download)
+    loader = get_celeba_loaders(args.data_path, args.img_ext, args.crop_size,
+                                args.img_size, args.batch_size, args.download)
 
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     # initialize model, instantiate opt & scheduler & loss fn
@@ -104,6 +104,7 @@
         model.eval()
         with torch.no_grad():
            sampled_images = model.module.sample(fixed_z)
+        sampled_images = (sampled_images + 1) / 2
 
         # log images and losses & save model parameters
         writer.add_image('Fixed Generated Images', torchvision.utils.make_grid(sampled_images), global_step=epoch)
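
Usage sketch (not part of the diff above): a minimal example of how a training
script would call the new shared loader after this refactor, assuming the
repository root is on PYTHONPATH and that CelebA images sit under data/celebA
with a .jpg extension. The path, crop, and size values below are placeholders,
not values taken from the repo's configs.

    from data.unlabelled import get_celeba_loaders

    # Placeholder arguments; in the repo these come from the opt/args parsed in each train script.
    loader = get_celeba_loaders(data_path='data/celebA',  # hypothetical image directory
                                img_ext='.jpg',           # matched by glob(data_path + '/*' + img_ext)
                                crop_size=148,
                                img_size=64,
                                batch_size=128,
                                download=False)           # True switches to torchvision's CelebA dataset

    for imgs in loader:
        # Batches land in [-1, 1] because of the Range lambda, matching the
        # generator's Tanh output; rescale with (imgs + 1) / 2 before logging.
        break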