From 34814218ed5ed5ed4b5cb0395da3bb5bd438503c Mon Sep 17 00:00:00 2001
From: zsdonghao
Date: Thu, 3 Dec 2020 11:46:36 +0800
Subject: [PATCH] vae bug fixes

---
 moco/main.py       |  2 +-
 networks/layers.py |  2 +-
 simsiam/main.py    |  2 +-
 vae/loss.py        |  8 ++++----
 vae/main.py        | 11 ++++++-----
 vae/model.py       | 23 ++++++++++++++---------
 6 files changed, 27 insertions(+), 21 deletions(-)

diff --git a/moco/main.py b/moco/main.py
index 4c58074..a5de9eb 100644
--- a/moco/main.py
+++ b/moco/main.py
@@ -74,7 +74,7 @@
     test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False, num_workers=28)
 
     Path(args.check_point.split('/')[1]).mkdir(parents=True, exist_ok=True)
-    Path(args.logs_root.split('/')[1]).mkdir(parents=True, exist_ok=True)
+    Path(args.logs_root).mkdir(parents=True, exist_ok=True)
 
     f_q = torch.nn.DataParallel(MoCo(args), device_ids=[0, 1]).to(device)
     f_k = get_momentum_encoder(f_q)
diff --git a/networks/layers.py b/networks/layers.py
index c022d52..fbd03c2 100644
--- a/networks/layers.py
+++ b/networks/layers.py
@@ -69,7 +69,7 @@ def __init__(self, in_channels, activation, normalization):
                 padding=1,
             ),
             norm,
-            activation,
+            act,
             nn.Conv2d(
                 in_channels,
                 in_channels,
diff --git a/simsiam/main.py b/simsiam/main.py
index 07cbf25..cd19c71 100644
--- a/simsiam/main.py
+++ b/simsiam/main.py
@@ -79,7 +79,7 @@ def cosine_loss(p, z):
     writer = SummaryWriter(args.logs_root)
     model = SimSiam(args).to(device)
     Path(args.check_point.split('/')[1]).mkdir(parents=True, exist_ok=True)
-    Path(args.logs_root.split('/')[1]).mkdir(parents=True, exist_ok=True)
+    Path(args.logs_root).mkdir(parents=True, exist_ok=True)
 
     optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                 momentum=args.momentum, weight_decay=args.wd)
diff --git a/vae/loss.py b/vae/loss.py
index beb1930..323c474 100644
--- a/vae/loss.py
+++ b/vae/loss.py
@@ -3,12 +3,12 @@
 
 
 class VAELoss(nn.Module):
-    def __init__(self, args):
+    def __init__(self, recon=None):
         super().__init__()
-        if args.recon == 'l1':
-            self.recon = nn.L1Loss()
-        elif args.recon == 'l2':
+        if recon == 'l2':
             self.recon = nn.MSELoss()
+        else:
+            self.recon = nn.L1Loss()
 
     def _KL_Loss(self, mu, logvar):
         return torch.mean(-0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim=1), dim=0)
diff --git a/vae/main.py b/vae/main.py
index 90ed48c..c1c72d3 100644
--- a/vae/main.py
+++ b/vae/main.py
@@ -20,11 +20,12 @@
 parser.add_argument('--n_res_blocks', type=int, default=9, help='Number of ResNet Blocks for generators')
 parser.add_argument('--lr', type=float, default=0.0002, help='Learning rate for generators')
 parser.add_argument('--betas', type=tuple, default=(0.5, 0.999), help='Betas for Adam optimizer')
-parser.add_argument('--epochs', type=int, default=200, help='Number of epochs')
+parser.add_argument('--n_epochs', type=int, default=200, help='Number of epochs')
 parser.add_argument('--batch_size', type=int, default=256, help='Batch size')
 parser.add_argument('--sample_size', type=int, default=32, help='Size of sampled images')
 parser.add_argument('--log_dir', type=str, default='vae/logs', help='Path to where log files will be saved')
 parser.add_argument('--data_path', type=str, default='data/img_align_celeba', help='Path to where image data is located')
+parser.add_argument('--device_ids', type=list, default=[0, 1], help='List of GPU devices')
 parser.add_argument('--img_ext', type=str, default='.jpg', help='Image extentions')
 parser.add_argument('--checkpoint_dir', type=str, default='vae/model_weights', help='Path to where model weights will be saved')
 args = parser.parse_args()
@@ -32,13 +33,13 @@
 
 
 if __name__ == '__main__':
     writer = SummaryWriter(args.log_dir)
-    Path(args.check_point.split('/')[1]).mkdir(parents=True, exist_ok=True)
-    Path(args.log_dir.split('/')[1]).mkdir(parents=True, exist_ok=True)
+    Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True)
+    Path(args.log_dir).mkdir(parents=True, exist_ok=True)
     loader = get_loaders(args)
 
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    model = VAE(args).to(device)
+    model = torch.nn.DataParallel(VAE(args), device_ids=args.device_ids).to(device)
     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=args.betas)
     scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lambda epoch: 0.95)
     fixed_z = torch.randn(args.sample_size, args.z_dim).to(device)
@@ -70,4 +71,4 @@
     torch.save({
         'model': model.state_dict(),
         'optimizer': optimizer.state_dict(),
-    }, f"{opt.checkpoint_dir}/VAE.pth")
+    }, f"{args.checkpoint_dir}/VAE.pth")
diff --git a/vae/model.py b/vae/model.py
index e913fb4..8c92c5e 100644
--- a/vae/model.py
+++ b/vae/model.py
@@ -5,6 +5,7 @@
 
 class VAE(nn.Module):
     def __init__(self, args):
+        super().__init__()
         self.args = args
         self.encoder = nn.Sequential(
             ConvNormAct(args.img_channels, args.model_dim, 'down'),
@@ -16,16 +17,11 @@ def __init__(self, args):
             nn.Flatten(),
             nn.Linear((args.img_size // (2**4))**2 * args.model_dim * 8, args.z_dim * 2)
         )
-        self.decoder = nn.Sequential(
+        self.projector = nn.Sequential(
             nn.Linear(args.z_dim, (args.img_size // (2**4))**2 * args.model_dim * 8),
             nn.BatchNorm1d((args.img_size // (2**4))**2 * args.model_dim * 8),
-            nn.ReLU(),
-            Reshape(
-                args.batch_size,
-                args.model_dim * 8,
-                args.img_size // (2**4),
-                args.img_size // (2**4)
-            ),
+            nn.ReLU())
+        self.decoder = nn.Sequential(
             ConvNormAct(args.model_dim * 8, args.model_dim * 4, 'up'),
             ConvNormAct(args.model_dim * 4, args.model_dim * 2, 'up'),
             ConvNormAct(args.model_dim * 2, args.model_dim, 'up'),
@@ -35,17 +31,26 @@ def __init__(self, args):
 
     def reparameterize(self, mu, logvar):
         batch_size = mu.size(0)
-        z = torch.randn(batch_size, self.args.z_dim) * torch.sqrt(torch.exp(logvar)) + mu
+        z = torch.randn(
+            batch_size,
+            self.args.z_dim,
+            device=mu.device) * torch.sqrt(torch.exp(logvar)) + mu
         return z
 
     def forward(self, x):
+        batch_size = x.size(0)
         z = self.encoder(x)
         z = z.view(-1, 2, self.args.z_dim)
         mu, logvar = z[:, 0, :], z[:, 1, :]
         z = self.reparameterize(mu, logvar)
+        z = self.projector(z).view(batch_size, self.args.model_dim * 8,
+                                   self.args.img_size // (2**4), self.args.img_size // (2**4))
         return self.decoder(z), mu, logvar
 
     def sample(self, z=None, num_samples=50):
         if z is None:
             z = torch.randn(num_samples, self.args.z_dim)
+        num_samples = z.size(0)
+        z = self.projector(z).view(num_samples, self.args.model_dim * 8,
+                                   self.args.img_size // (2**4), self.args.img_size // (2**4))
         return self.decoder(z)
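
For reference, below is a minimal smoke-test sketch (not part of the patch) showing how the patched pieces fit together. It assumes the repository root is on PYTHONPATH, and the Namespace fields are hypothetical stand-ins for the argparse defaults in vae/main.py, since the actual values of model_dim, img_size, and z_dim are not visible in this diff.

# Illustrative sketch only, not part of the patch.
from argparse import Namespace

import torch

from vae.loss import VAELoss
from vae.model import VAE

# Hypothetical hyperparameters in the style of vae/main.py; adjust to the repo's real defaults.
args = Namespace(img_channels=3, img_size=64, model_dim=64, z_dim=128)

model = VAE(args)                # super().__init__() now runs, so submodules and parameters register
criterion = VAELoss(recon='l1')  # new signature: 'l2' selects nn.MSELoss, anything else falls back to nn.L1Loss
                                 # (only construction is shown here; its forward is not part of this diff)

x = torch.randn(8, args.img_channels, args.img_size, args.img_size)
recon_x, mu, logvar = model(x)   # z is now routed through self.projector and reshaped before the conv decoder

samples = model.sample(num_samples=4)  # z ~ N(0, I) -> projector -> decoder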