@@ -7,6 +7,7 @@
 from pathlib import Path
 from datetime import datetime
 
+from networks.utils import initialize_modules
 from vae.data import get_loaders
 from vae.model import VAE
 from vae.loss import VAELoss
@@ -25,7 +26,7 @@
 parser.add_argument('--recon', type=str, default='bce', help='Reconstruction loss type [bce, l2]')
 parser.add_argument('--n_epochs', type=int, default=200, help='Number of epochs')
 parser.add_argument('--batch_size', type=int, default=512, help='Batch size')
-parser.add_argument('--sample_size', type=int, default=32, help='Size of sampled images')
+parser.add_argument('--sample_size', type=int, default=64, help='Size of sampled images')
 parser.add_argument('--log_dir', type=str, default='vae/logs', help='Path to where log files will be saved')
 parser.add_argument('--data_path', type=str, default='data/img_align_celeba', help='Path to where image data is located')
 parser.add_argument('--device_ids', type=list, default=[0, 1], help='List of GPU devices')
@@ -43,19 +44,22 @@
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 model = torch.nn.DataParallel(VAE(args), device_ids=args.device_ids).to(device)
+model.apply(initialize_modules)
 optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=args.betas)
 scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lambda epoch: 0.995)
 fixed_z = torch.randn(args.sample_size, args.z_dim).to(device)
 criterion = VAELoss(args)
 pbar = tqdm(range(args.n_epochs))
 for epoch in pbar:
-    losses = []
+    losses, kdls, rls = [], [], []
     model.train()
     for img in loader:
         x = img.to(device)
         x_hat, mu, logvar = model(x)
         loss, recon_loss, kld_loss = criterion(x, x_hat, mu, logvar)
         losses.append(loss.item())
+        kdls.append(kld_loss.item())
+        rls.append(recon_loss.item())
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
@@ -65,6 +69,9 @@
 
     # logging & generating imgs from fixed vector
     writer.add_scalar('Loss', sum(losses) / len(losses), global_step=epoch)
+    writer.add_scalar('KLD Loss', sum(kdls) / len(kdls), global_step=epoch)
+    writer.add_scalar('Reconstruction Loss', sum(rls) / len(rls), global_step=epoch)
+
    model.eval()
     with torch.no_grad():
         sampled_images = model.module.sample(fixed_z)
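For context, initialize_modules itself is not part of this diff; it lives in networks.utils and is applied to every submodule through model.apply(...). A minimal sketch of what such a helper usually does follows; the choice of Kaiming-normal initialization and the handled layer types are assumptions for illustration, not the repository's actual implementation.

import torch.nn as nn

def initialize_modules(m: nn.Module) -> None:
    # Hypothetical weight-init helper with the signature expected by model.apply():
    # it is called once per submodule. The real networks.utils.initialize_modules
    # is not shown in this diff and may differ.
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)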
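The new logging splits the objective into its two terms, which relies on the criterion returning (total, reconstruction, KL) in that order. As a reference for that contract, here is a sketch of a loss with the same call signature, assuming the standard closed-form KL to a unit Gaussian and the bce/l2 reconstruction options exposed by --recon; the actual vae.loss.VAELoss may weight or reduce the terms differently.

import torch
import torch.nn.functional as F

class VAELossSketch(torch.nn.Module):
    # Hypothetical stand-in for vae.loss.VAELoss; same call signature and return
    # order (total, reconstruction, KL) as used in the training loop above.
    def __init__(self, args):
        super().__init__()
        self.recon = args.recon  # 'bce' or 'l2'

    def forward(self, x, x_hat, mu, logvar):
        if self.recon == 'bce':
            recon = F.binary_cross_entropy(x_hat, x, reduction='sum') / x.size(0)
        else:  # 'l2'
            recon = F.mse_loss(x_hat, x, reduction='sum') / x.size(0)
        # Closed-form KL(q(z|x) || N(0, I)), averaged over the batch
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)
        return recon + kld, recon, kld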