Commit

added initialization for vae
zsdonghao committed Dec 3, 2020
1 parent 32a8b37 commit e0baeff
Showing 2 changed files with 23 additions and 2 deletions.
14 changes: 14 additions & 0 deletions networks/utils.py
@@ -0,0 +1,14 @@
+from torch import nn
+
+
+def initialize_modules(model, nonlinearity='leaky_relu'):
+    for m in model.modules():
+        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
+            nn.init.kaiming_normal_(
+                m.weight,
+                mode='fan_out',
+                nonlinearity=nonlinearity
+            )
+        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.Linear)):
+            nn.init.normal_(m.weight, 0.0, 0.02)
+            nn.init.constant_(m.bias, 0)
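
For reference, a minimal sketch (not part of this commit) of how the new helper might be exercised on its own: Conv2d/ConvTranspose2d weights get Kaiming-normal initialization (fan_out, leaky_relu by default), while BatchNorm2d/GroupNorm/Linear weights are drawn from N(0, 0.02) with zero bias. The toy module below is hypothetical; the commit itself applies the helper via model.apply(initialize_modules) in vae/main.py, which also works because the function walks m.modules() itself.

from torch import nn

from networks.utils import initialize_modules

# Hypothetical stand-in for a small decoder head, used only to illustrate the helper.
toy = nn.Sequential(
    nn.ConvTranspose2d(16, 8, kernel_size=4, stride=2, padding=1),
    nn.BatchNorm2d(8),
    nn.LeakyReLU(0.2),
    nn.Conv2d(8, 3, kernel_size=3, padding=1),
)

# Direct call on the top-level module; initialize_modules iterates toy.modules() itself.
initialize_modules(toy, nonlinearity='leaky_relu')

# BatchNorm weights should now be roughly N(0, 0.02) (noisy, since there are only 8 channels).
print(toy[1].weight.std())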
11 changes: 9 additions & 2 deletions vae/main.py
@@ -7,6 +7,7 @@
 from pathlib import Path
 from datetime import datetime
 
+from networks.utils import initialize_modules
 from vae.data import get_loaders
 from vae.model import VAE
 from vae.loss import VAELoss
@@ -25,7 +26,7 @@
 parser.add_argument('--recon', type=str, default='bce', help='Reconstruction loss type [bce, l2]')
 parser.add_argument('--n_epochs', type=int, default=200, help='Number of epochs')
 parser.add_argument('--batch_size', type=int, default=512, help='Batch size')
-parser.add_argument('--sample_size', type=int, default=32, help='Size of sampled images')
+parser.add_argument('--sample_size', type=int, default=64, help='Size of sampled images')
 parser.add_argument('--log_dir', type=str, default='vae/logs', help='Path to where log files will be saved')
 parser.add_argument('--data_path', type=str, default='data/img_align_celeba', help='Path to where image data is located')
 parser.add_argument('--device_ids', type=list, default=[0, 1], help='List of GPU devices')
@@ -43,19 +44,22 @@
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 model = torch.nn.DataParallel(VAE(args), device_ids=args.device_ids).to(device)
+model.apply(initialize_modules)
 optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=args.betas)
 scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lambda epoch: 0.995)
 fixed_z = torch.randn(args.sample_size, args.z_dim).to(device)
 criterion = VAELoss(args)
 pbar = tqdm(range(args.n_epochs))
 for epoch in pbar:
-    losses = []
+    losses, kdls, rls = [], [], []
     model.train()
     for img in loader:
         x = img.to(device)
         x_hat, mu, logvar = model(x)
         loss, recon_loss, kld_loss = criterion(x, x_hat, mu, logvar)
         losses.append(loss.item())
+        kdls.append(kld_loss.item())
+        rls.append(recon_loss.item())
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
@@ -65,6 +69,9 @@
 
     # logging & generating imgs from fixed vector
    writer.add_scalar('Loss', sum(losses) / len(losses), global_step=epoch)
+    writer.add_scalar('KLD Loss', sum(kdls) / len(kdls), global_step=epoch)
+    writer.add_scalar('Reconstruction Loss', sum(rls) / len(rls), global_step=epoch)
+
     model.eval()
     with torch.no_grad():
         sampled_images = model.module.sample(fixed_z)
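
The criterion unpacked in the loop (loss, recon_loss, kld_loss) lives in vae/loss.py, which this commit does not touch. Purely as a sketch, assuming the standard ELBO with the reconstruction term selected by --recon and the closed-form diagonal-Gaussian KL, a VAELoss-like module could look as follows; the class name VAELossSketch and the per-batch averaging are assumptions, not the repository's actual implementation.

import torch
import torch.nn.functional as F
from torch import nn


class VAELossSketch(nn.Module):
    """Hypothetical stand-in for vae.loss.VAELoss; the real file is not shown in this diff."""

    def __init__(self, args):
        super().__init__()
        self.recon = args.recon  # 'bce' or 'l2', matching the --recon flag

    def forward(self, x, x_hat, mu, logvar):
        # Reconstruction term, summed over pixels and averaged over the batch.
        if self.recon == 'bce':
            recon_loss = F.binary_cross_entropy(x_hat, x, reduction='sum') / x.size(0)
        else:
            recon_loss = F.mse_loss(x_hat, x, reduction='sum') / x.size(0)
        # KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior, averaged over the batch.
        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)
        loss = recon_loss + kld_loss
        return loss, recon_loss, kld_loss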
