Skip to content

Commit

Permalink
init changed to ortho, following biggan paper
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrewzh112 committed Dec 14, 2020
1 parent 39807d2 commit 395ab49
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
21 changes: 15 additions & 6 deletions networks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,23 @@
import torch


def initialize_modules(model, nonlinearity='leaky_relu'):
def initialize_modules(model, nonlinearity='leaky_relu', init_type='kaiming'):
for m in model.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
nn.init.kaiming_normal_(
m.weight,
mode='fan_out',
nonlinearity=nonlinearity
)
if init_type == 'kaiming':
nn.init.kaiming_normal_(
m.weight,
mode='fan_out',
nonlinearity=nonlinearity
)
elif init_type == 'normal':
nn.init.normal_(m.weight, 0.0, 0.02)
elif init_type == 'ortho':
nn.init.orthogonal_(m.weight)
elif init_type in ['glorot', 'xavier']:
nn.init.xavier_uniform_(m.weight)
else:
print('unrecognized init type, using default PyTorch initialization scheme...')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.Linear)):
nn.init.normal_(m.weight, 0.0, 0.02)
if m.bias is not None:
Expand Down
4 changes: 2 additions & 2 deletions sagan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, img_channels, h_dim, img_size):
)
self.in_features = h_dim*8
self.fc = SN_Linear(in_features=self.in_features, out_features=1)
initialize_modules(self)
initialize_modules(self, init_type='ortho')

def forward(self, x):
x = self.disc(x)
Expand All @@ -46,7 +46,7 @@ def __init__(self, h_dim, z_dim, img_channels, img_size):
stride=2, padding=1),
nn.Tanh()
)
initialize_modules(self)
initialize_modules(self, init_type='ortho')

def forward(self, x):
batch_size = x.size(0)
Expand Down

0 comments on commit 395ab49

Please sign in to comment.