Skip to content

Commit

Permalink
refactored data modules
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrewzh112 committed Dec 14, 2020
1 parent 7b6915f commit cd3b55e
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 12 deletions.
34 changes: 34 additions & 0 deletions data/unlabelled.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from glob import glob
import os
from PIL import Image
from torchvision import transforms
from torchvision.datasets import CelebA
import torch
from torch.utils.data import Dataset, DataLoader


class celebA(Dataset):
def __init__(self, data_path, img_ext, crop_size, img_size):
self.celebs = glob(data_path + '/*' + img_ext)
Range = transforms.Lambda(lambda X: 2 * X - 1.)
self.transforms = transforms.Compose([
transforms.CenterCrop(crop_size),
transforms.Resize(img_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
Range
])

def __len__(self):
return len(self.celebs)

def __getitem__(self, index):
img = Image.open(self.celebs[index]).convert('RGB')
return self.transforms(img)


def get_celeba_loaders(data_path, img_ext, crop_size, img_size, batch_size, download):
dataset = celebA(data_path, img_ext, crop_size, img_size)
if download:
dataset = CelebA('celebA', transform=dataset.transforms, download=True)
return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=28)
4 changes: 3 additions & 1 deletion dcgan/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
class celebA(Dataset):
def __init__(self, data_path, img_ext, crop_size, img_size):
self.celebs = glob(data_path + '/*' + img_ext)
Range = transforms.Lambda(lambda X: 2 * X - 1.)
self.transforms = transforms.Compose([
transforms.CenterCrop(crop_size),
transforms.Resize(img_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor()
transforms.ToTensor(),
Range
])

def __len__(self):
Expand Down
2 changes: 1 addition & 1 deletion dcgan/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, h_dim, z_dim, img_channels, img_size):
conv_bn_relu(h_dim*4, h_dim*2, 4, 2, 'relu', 'down'),
conv_bn_relu(h_dim*2, h_dim, 4, 2, 'relu', 'down'),
nn.ConvTranspose2d(h_dim, img_channels, 4, 2, 1),
nn.Sigmoid()
nn.Tanh()
)
initialize_weights(self)

Expand Down
8 changes: 4 additions & 4 deletions dcgan/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from dcgan.data import get_loaders
from data.unlabelled import get_celeba_loaders
from dcgan.models import Generator, Discriminator


Expand Down Expand Up @@ -40,8 +40,8 @@ def train():
os.mkdir(opt.log_dir)
if not os.path.isdir(opt.sample_dir):
os.mkdir(opt.sample_dir)
loader = get_loaders(opt.data_path, opt.img_ext, opt.crop_size,
opt.img_size, opt.batch_size, opt.download)
loader = get_celeba_loaders(opt.data_path, opt.img_ext, opt.crop_size,
opt.img_size, opt.batch_size, opt.download)
G = Generator(opt.h_dim, opt.z_dim, opt.img_channels, opt.img_size)
D = Discriminator(opt.img_channels, opt.h_dim, opt.img_size)
optimizer_G = torch.optim.Adam(G.parameters(), lr=opt.lr, betas=opt.betas)
Expand Down Expand Up @@ -77,7 +77,7 @@ def train():

with torch.no_grad():
samples = G(fixed_z)
# samples = ((samples + 1) / 2).view(-1, opt.img_channels, opt.img_size, opt.img_size)
samples = ((samples + 1) / 2).view(-1, opt.img_channels, opt.img_size, opt.img_size)
writer.add_image('Generated Images', torchvision.utils.make_grid(samples), global_step=epoch)
writer.add_scalars("Train Losses", {
"Discriminator Loss": sum(d_losses) / len(d_losses),
Expand Down
6 changes: 3 additions & 3 deletions sagan/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from datetime import datetime
from pathlib import Path

from dcgan.data import get_loaders
from data.unlabelled import get_celeba_loaders
from networks.utils import load_weights
from sagan.model import Generator, Discriminator
from sagan.loss import Hinge_loss, Wasserstein_GP_Loss
Expand Down Expand Up @@ -71,8 +71,8 @@ def train():
optimizer_G = torch.optim.Adam(G.parameters(), lr=opt.lr_G, betas=opt.betas)
optimizer_D = torch.optim.Adam(D.parameters(), lr=opt.lr_D, betas=opt.betas)

loader = get_loaders(opt.data_path, opt.img_ext, opt.crop_size,
opt.img_size, opt.batch_size, opt.download)
loader = get_celeba_loaders(opt.data_path, opt.img_ext, opt.crop_size,
opt.img_size, opt.batch_size, opt.download)
# sample fixed z to see progress through training
fixed_z = torch.randn(opt.sample_size, opt.z_dim).to(device)

Expand Down
7 changes: 4 additions & 3 deletions vae/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from datetime import datetime

from networks.utils import initialize_modules
from dcgan.data import get_loaders
from data.unlabelled import get_celeba_loaders
from vae.model import VAE
from vae.loss import VAELoss

Expand Down Expand Up @@ -57,8 +57,8 @@
Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True)
Path(args.log_dir).mkdir(parents=True, exist_ok=True)

loader = get_loaders(args.data_path, args.img_ext, args.crop_size,
args.img_size, args.batch_size, args.download)
loader = get_celeba_loaders(args.data_path, args.img_ext, args.crop_size,
args.img_size, args.batch_size, args.download)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# initialize model, instantiate opt & scheduler & loss fn
Expand Down Expand Up @@ -104,6 +104,7 @@
model.eval()
with torch.no_grad():
sampled_images = model.module.sample(fixed_z)
sampled_images = (sampled_images + 1) / 2

# log images and losses & save model parameters
writer.add_image('Fixed Generated Images', torchvision.utils.make_grid(sampled_images), global_step=epoch)
Expand Down

0 comments on commit cd3b55e

Please sign in to comment.