Skip to content

Commit e0baeff

Browse files
committed
added initialization for vae
1 parent 32a8b37 commit e0baeff

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

networks/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from torch import nn
2+
3+
4+
def initialize_modules(model, nonlinearity='leaky_relu'):
5+
for m in model.modules():
6+
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
7+
nn.init.kaiming_normal_(
8+
m.weight,
9+
mode='fan_out',
10+
nonlinearity=nonlinearity
11+
)
12+
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.Linear)):
13+
nn.init.normal_(m.weight, 0.0, 0.02)
14+
nn.init.constant_(m.bias, 0)

vae/main.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pathlib import Path
88
from datetime import datetime
99

10+
from networks.utils import initialize_modules
1011
from vae.data import get_loaders
1112
from vae.model import VAE
1213
from vae.loss import VAELoss
@@ -25,7 +26,7 @@
2526
parser.add_argument('--recon', type=str, default='bce', help='Reconstruction loss type [bce, l2]')
2627
parser.add_argument('--n_epochs', type=int, default=200, help='Number of epochs')
2728
parser.add_argument('--batch_size', type=int, default=512, help='Batch size')
28-
parser.add_argument('--sample_size', type=int, default=32, help='Size of sampled images')
29+
parser.add_argument('--sample_size', type=int, default=64, help='Size of sampled images')
2930
parser.add_argument('--log_dir', type=str, default='vae/logs', help='Path to where log files will be saved')
3031
parser.add_argument('--data_path', type=str, default='data/img_align_celeba', help='Path to where image data is located')
3132
parser.add_argument('--device_ids', type=list, default=[0, 1], help='List of GPU devices')
@@ -43,19 +44,22 @@
4344
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
4445

4546
model = torch.nn.DataParallel(VAE(args), device_ids=args.device_ids).to(device)
47+
model.apply(initialize_modules)
4648
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=args.betas)
4749
scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lambda epoch: 0.995)
4850
fixed_z = torch.randn(args.sample_size, args.z_dim).to(device)
4951
criterion = VAELoss(args)
5052
pbar = tqdm(range(args.n_epochs))
5153
for epoch in pbar:
52-
losses = []
54+
losses, kdls, rls = [], [], []
5355
model.train()
5456
for img in loader:
5557
x = img.to(device)
5658
x_hat, mu, logvar = model(x)
5759
loss, recon_loss, kld_loss = criterion(x, x_hat, mu, logvar)
5860
losses.append(loss.item())
61+
kdls.append(kld_loss.item())
62+
rls.append(recon_loss.item())
5963
optimizer.zero_grad()
6064
loss.backward()
6165
optimizer.step()
@@ -65,6 +69,9 @@
6569

6670
# logging & generating imgs from fixed vector
6771
writer.add_scalar('Loss', sum(losses) / len(losses), global_step=epoch)
72+
writer.add_scalar('KLD Loss', sum(kdls) / len(kdls), global_step=epoch)
73+
writer.add_scalar('Reconstruction Loss', sum(rls) / len(rls), global_step=epoch)
74+
6875
model.eval()
6976
with torch.no_grad():
7077
sampled_images = model.module.sample(fixed_z)

0 commit comments

Comments
 (0)