From 42c78323b081741ba7e7c18a20cf562f79288255 Mon Sep 17 00:00:00 2001 From: zsdonghao Date: Fri, 11 Dec 2020 12:14:20 +0800 Subject: [PATCH] bug fixes for loading data --- sagan/train.py | 6 +++--- vae/main.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sagan/train.py b/sagan/train.py index 369a1cc..a47da8b 100644 --- a/sagan/train.py +++ b/sagan/train.py @@ -17,8 +17,8 @@ parser = argparse.ArgumentParser() # model parameters -parser.add_argument('--h_dim', type=float, default=64, help='model dimensions multiplier') -parser.add_argument('--z_dim', type=float, default=100, help='dimension of random noise latent vector') +parser.add_argument('--h_dim', type=int, default=64, help='model dimensions multiplier') +parser.add_argument('--z_dim', type=int, default=100, help='dimension of random noise latent vector') # data paramters parser.add_argument('--img_size', type=int, default=128, help='H, W of the input images') @@ -65,7 +65,7 @@ def train(): optimizer_D = torch.optim.Adam(D.parameters(), lr=opt.lr_D, betas=opt.betas) loader = get_loaders(opt.data_path, opt.img_ext, opt.crop_size, - opt.img_size, opt.batch_size) + opt.img_size, opt.batch_size, opt.download) # sample fixed z to see progress through training fixed_z = torch.randn(opt.sample_size, opt.z_dim).to(device) diff --git a/vae/main.py b/vae/main.py index c4edba5..ee9848b 100644 --- a/vae/main.py +++ b/vae/main.py @@ -19,6 +19,7 @@ parser.add_argument('--img_channels', type=int, default=3, help='Numer of channels for images') parser.add_argument('--img_size', type=int, default=64, help='H, W of the input images') parser.add_argument('--crop_size', type=int, default=128, help='H, W of the input images') +parser.add_argument('--download', action="store_true", default=False, help='If auto download CelebA dataset') # model params parser.add_argument('--z_dim', type=float, default=100, help='dimension of random noise latent vector')