From 395ab49dba38705ea3fd4f2e91246f26143c1bf8 Mon Sep 17 00:00:00 2001
From: Andrew Zhao
Date: Mon, 14 Dec 2020 21:38:38 +0800
Subject: [PATCH] init changed to ortho, following biggan paper

---
 networks/utils.py | 21 +++++++++++++++------
 sagan/model.py    |  4 ++--
 2 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/networks/utils.py b/networks/utils.py
index 07a22fe..d2037d1 100644
--- a/networks/utils.py
+++ b/networks/utils.py
@@ -2,14 +2,23 @@
 import torch
 
 
-def initialize_modules(model, nonlinearity='leaky_relu'):
+def initialize_modules(model, nonlinearity='leaky_relu', init_type='kaiming'):
     for m in model.modules():
         if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
-            nn.init.kaiming_normal_(
-                m.weight,
-                mode='fan_out',
-                nonlinearity=nonlinearity
-            )
+            if init_type == 'kaiming':
+                nn.init.kaiming_normal_(
+                    m.weight,
+                    mode='fan_out',
+                    nonlinearity=nonlinearity
+                )
+            elif init_type == 'normal':
+                nn.init.normal_(m.weight, 0.0, 0.02)
+            elif init_type == 'ortho':
+                nn.init.orthogonal_(m.weight)
+            elif init_type in ['glorot', 'xavier']:
+                nn.init.xavier_uniform_(m.weight)
+            else:
+                print('unrecognized init type, using default PyTorch initialization scheme...')
         elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.Linear)):
             nn.init.normal_(m.weight, 0.0, 0.02)
             if m.bias is not None:
diff --git a/sagan/model.py b/sagan/model.py
index 7ff5c13..e3767b9 100644
--- a/sagan/model.py
+++ b/sagan/model.py
@@ -19,7 +19,7 @@ def __init__(self, img_channels, h_dim, img_size):
         )
         self.in_features = h_dim*8
         self.fc = SN_Linear(in_features=self.in_features, out_features=1)
-        initialize_modules(self)
+        initialize_modules(self, init_type='ortho')
 
     def forward(self, x):
         x = self.disc(x)
@@ -46,7 +46,7 @@ def __init__(self, h_dim, z_dim, img_channels, img_size):
                                stride=2, padding=1),
             nn.Tanh()
         )
-        initialize_modules(self)
+        initialize_modules(self, init_type='ortho')
 
     def forward(self, x):
         batch_size = x.size(0)
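
Usage note (not part of the patch): a minimal sketch of calling the extended initialize_modules with the new init_type argument. The Net class, its layer sizes, and the import path are assumptions made for illustration; only initialize_modules, init_type, and the supported keys ('kaiming', 'normal', 'ortho', 'glorot'/'xavier') come from the diff above.

    import torch.nn as nn

    # Assumes the repository root is on PYTHONPATH so networks.utils resolves.
    from networks.utils import initialize_modules

    class Net(nn.Module):
        # Hypothetical module, only for demonstrating the initializer.
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
            self.bn = nn.BatchNorm2d(64)

        def forward(self, x):
            return self.bn(self.conv(x))

    net = Net()
    # Orthogonal init for conv / conv-transpose weights, matching the
    # BigGAN-style default this commit switches the SAGAN models to.
    initialize_modules(net, init_type='ortho')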