diff --git a/moco/main.py b/moco/main.py index a5de9eb..d9ea0ed 100644 --- a/moco/main.py +++ b/moco/main.py @@ -73,7 +73,7 @@ test_data = CIFAR10(root=args.data_root, train=False, transform=test_transform, download=True) test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False, num_workers=28) - Path(args.check_point.split('/')[1]).mkdir(parents=True, exist_ok=True) + Path('/'.join(args.check_point.split('/')[:-1])).mkdir(parents=True, exist_ok=True) Path(args.logs_root).mkdir(parents=True, exist_ok=True) f_q = torch.nn.DataParallel(MoCo(args), device_ids=[0, 1]).to(device) diff --git a/simsiam/data.py b/simsiam/data.py index 92dc7a2..34125a8 100644 --- a/simsiam/data.py +++ b/simsiam/data.py @@ -8,6 +8,7 @@ class CIFAR10Pairs(CIFAR10): + """Outputs two versions of same image through two different transforms""" def __getitem__(self, index): img = self.data[index] img = Image.fromarray(img) diff --git a/simsiam/main.py b/simsiam/main.py index cd19c71..d9c2588 100644 --- a/simsiam/main.py +++ b/simsiam/main.py @@ -26,6 +26,8 @@ parser.add_argument('--batch_size', default=128, type=int, metavar='N', help='mini-batch size') parser.add_argument('--wd', default=5e-4, type=float, metavar='W', help='weight decay') parser.add_argument('--momentum', default=0.9, type=float, help='momentum for optimizer') +parser.add_argument('--symmetric', action="store_true", default=True, help='loss function is symmetric') +parser.add_argument('--device_ids', type=list, default=[0, 1], help='List of GPU devices') # simsiam model configs parser.add_argument('-a', '--backbone', default='resnet18') @@ -42,12 +44,6 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -def cosine_loss(p, z): - z = z.detach() - p = F.normalize(p, dim=1) - z = F.normalize(z, dim=1) - return -(p @ z.T).mean() - if __name__ == '__main__': """https://github.com/facebookresearch/moco""" @@ -77,8 +73,8 @@ def cosine_loss(p, z): test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False, num_workers=28) writer = SummaryWriter(args.logs_root) - model = SimSiam(args).to(device) - Path(args.check_point.split('/')[1]).mkdir(parents=True, exist_ok=True) + model = torch.nn.DataParallel(SimSiam(args), device_ids=args.device_ids).to(device) + Path('/'.join(args.check_point.split('/')[:-1])).mkdir(parents=True, exist_ok=True) Path(args.logs_root).mkdir(parents=True, exist_ok=True) optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, @@ -92,8 +88,10 @@ def cosine_loss(p, z): for x1, x2 in train_loader: x1, x2 = x1.to(device), x2.to(device) z1, z2, p1, p2 = model(x1, x2) - # symmetric loss - loss = (cosine_loss(p1, z2) + cosine_loss(p2, z1)) / 2 + if args.symmetric: + loss = (model.module.cosine_loss(p1, z2) + model.module.cosine_loss(p2, z1)) / 2 + else: + loss = model.module.cosine_loss(p1, z2) train_losses.append(loss.item()) optimizer.zero_grad() loss.backward() @@ -134,7 +132,7 @@ def cosine_loss(p, z): f'Epoch {epoch + 1}/{args.epochs}, \ Train Loss: {sum(train_losses) / len(train_losses):.3f}, \ Top Acc @ 1: {top1acc:.3f}, \ - Learning Rate: {scheduler.get_last_lr()}' + Learning Rate: {scheduler.get_last_lr()[0]}' ) torch.save(model.state_dict(), args.check_point) scheduler.step() diff --git a/simsiam/model.py b/simsiam/model.py index 3370e11..9ed8be4 100644 --- a/simsiam/model.py +++ b/simsiam/model.py @@ -1,6 +1,8 @@ """https://github.com/facebookresearch/moco""" +import torch from torch import nn +from torch.nn import functional as F import torchvision @@ -49,3 +51,8 @@ def forward(self, x1, x2=None, istrain=True): return z1, z2, p1, p2 else: return self.encoder(x1) + + def cosine_loss(self, p, z): + p = F.normalize(p, dim=1) + z = F.normalize(z, dim=1).detach() + return -torch.einsum('ij,ij->i', p, z).mean() diff --git a/vae/main.py b/vae/main.py index 5a71eb4..bbec4bc 100644 --- a/vae/main.py +++ b/vae/main.py @@ -14,29 +14,40 @@ parser = argparse.ArgumentParser() -# training + +# image settings parser.add_argument('--img_channels', type=int, default=3, help='Numer of channels for images') -parser.add_argument('--model_dim', type=float, default=128, help='model dimensions multiplier') -parser.add_argument('--z_dim', type=float, default=100, help='dimension of random noise latent vector') parser.add_argument('--img_size', type=int, default=64, help='H, W of the input images') parser.add_argument('--crop_size', type=int, default=128, help='H, W of the input images') + +# model params +parser.add_argument('--z_dim', type=float, default=100, help='dimension of random noise latent vector') parser.add_argument('--n_res_blocks', type=int, default=1, help='Number of ResNet Blocks for generators') -parser.add_argument('--lr', type=float, default=0.0005, help='Learning rate for generators') -parser.add_argument('--betas', type=tuple, default=(0.5, 0.999), help='Betas for Adam optimizer') +parser.add_argument('--model_dim', type=float, default=128, help='model dimensions multiplier') + +# loss fn parser.add_argument('--beta', type=float, default=1., help='Beta hyperparam for KLD Loss') parser.add_argument('--recon', type=str, default='bce', help='Reconstruction loss type [bce, l2]') + +# training hyperparams +parser.add_argument('--device_ids', type=list, default=[0, 1], help='List of GPU devices') +parser.add_argument('--lr', type=float, default=0.0005, help='Learning rate for generators') +parser.add_argument('--betas', type=tuple, default=(0.5, 0.999), help='Betas for Adam optimizer') parser.add_argument('--n_epochs', type=int, default=200, help='Number of epochs') parser.add_argument('--batch_size', type=int, default=512, help='Batch size') -parser.add_argument('--sample_size', type=int, default=64, help='Size of sampled images') + +# logging parser.add_argument('--log_dir', type=str, default='vae/logs', help='Path to where log files will be saved') parser.add_argument('--data_path', type=str, default='data/img_align_celeba', help='Path to where image data is located') -parser.add_argument('--device_ids', type=list, default=[0, 1], help='List of GPU devices') +parser.add_argument('--sample_path', type=str, default='vae/samples', help='Path to where samples are saved') parser.add_argument('--img_ext', type=str, default='.jpg', help='Image extentions') parser.add_argument('--checkpoint_dir', type=str, default='vae/model_weights', help='Path to where model weights will be saved') -# for sampler +# for sampling +parser.add_argument('--sample_size', type=int, default=64, help='Size of sampled images') parser.add_argument('--sample', action="store_true", default=False, help='Sample from VAE') parser.add_argument('--walk', action="store_true", default=False, help='Walk through a feature & sample') + args = parser.parse_args() @@ -48,19 +59,27 @@ loader = get_loaders(args) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + # initialize model, instantiate opt & scheduler & loss fn model = torch.nn.DataParallel(VAE(args), device_ids=args.device_ids).to(device) model.apply(initialize_modules) optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=args.betas) scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lambda epoch: 0.995) - fixed_z = torch.randn(args.sample_size, args.z_dim).to(device) criterion = VAELoss(args) + + # fixed z to see how model changes on the same latent vectors + fixed_z = torch.randn(args.sample_size, args.z_dim).to(device) + pbar = tqdm(range(args.n_epochs)) for epoch in pbar: losses, kdls, rls = [], [], [] model.train() for img in loader: x = img.to(device) + + # x_hat for recon loss, mu & logvar for kdl loss x_hat, mu, logvar = model(x) + + # return kdl & recon loss for logging purposes loss, recon_loss, kld_loss = criterion(x, x_hat, mu, logvar) losses.append(loss.item()) kdls.append(kld_loss.item()) @@ -77,9 +96,12 @@ writer.add_scalar('KLD Loss', sum(kdls) / len(kdls), global_step=epoch) writer.add_scalar('Reconstruction Loss', sum(rls) / len(rls), global_step=epoch) + # decode fixed z latent vectors model.eval() with torch.no_grad(): sampled_images = model.module.sample(fixed_z) + + # log images and losses & save model parameters writer.add_image('Fixed Generated Images', torchvision.utils.make_grid(sampled_images), global_step=epoch) writer.add_image('Reconstructed Images', torchvision.utils.make_grid(x_hat.detach()), global_step=epoch) writer.add_image('Original Images', torchvision.utils.make_grid(x.detach()), global_step=epoch) diff --git a/vae/sample.py b/vae/sample.py index 2df126c..1fceb27 100644 --- a/vae/sample.py +++ b/vae/sample.py @@ -10,22 +10,19 @@ class Sampler: - def __init__(self, sample_path='vae/samples', ext='.jpg'): + def __init__(self): self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') self.vae = torch.nn.DataParallel(VAE(args), device_ids=args.device_ids).to(self.device) self.vae.load_state_dict(torch.load(f"{args.checkpoint_dir}/VAE.pth")['model']) self.vae.eval() - Path(sample_path).mkdir(parents=True, exist_ok=True) - - self.sample_path = sample_path - self.ext = ext + Path(args.sample_path).mkdir(parents=True, exist_ok=True) def sample(self): with torch.no_grad(): samples = self.vae.module.sample(num_samples=args.sample_size) torchvision.utils.save_image( samples, - self.sample_path + f'/sample_{int(datetime.now().timestamp()*1e6)}' + self.ext) + args.sample_path + f'/sample_{int(datetime.now().timestamp()*1e6)}' + args.img_ext) def generate_walk_z(self): z = torch.randn(args.z_dim, device=self.device) @@ -40,7 +37,7 @@ def walk(self): samples = self.vae.module.sample(z=z) torchvision.utils.save_image( samples, - self.sample_path + f'/walk_{int(datetime.now().timestamp()*1e6)}' + self.ext) + args.sample_path + f'/walk_{int(datetime.now().timestamp()*1e6)}' + args.img_ext) if __name__ == '__main__':