Skip to content

Commit

Permalink
increased model capacity
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrewzh112 committed Dec 14, 2020
1 parent c7d4a19 commit 39807d2
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion sagan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def __init__(self, img_channels, h_dim, img_size):
SA_Conv2d(h_dim*4),
ConvNormAct(h_dim*4, h_dim*8, 'sn', 'down', activation='lrelu', normalization='bn'),
ConvNormAct(h_dim*8, h_dim*8, 'sn', 'down', activation='lrelu', normalization='bn'),
ConvNormAct(h_dim*8, h_dim*8, 'sn', 'down', activation='lrelu', normalization='bn'),
nn.AdaptiveAvgPool2d(1),
)
self.in_features = h_dim*8
Expand All @@ -29,13 +30,14 @@ def forward(self, x):
class Generator(nn.Module):
def __init__(self, h_dim, z_dim, img_channels, img_size):
super().__init__()
self.min_hw = (img_size // (2 ** 5)) ** 2
self.min_hw = (img_size // (2 ** 6)) ** 2
self.h_dim = h_dim
self.project = SN_Linear(in_features=z_dim, out_features=h_dim*8 * self.min_hw ** 2, bias=False)
self.gen = nn.Sequential(
nn.BatchNorm2d(h_dim*8, momentum=0.9),
nn.ReLU(),
ConvNormAct(h_dim*8, h_dim*8, 'sn', 'up', activation='relu', normalization='bn'),
ConvNormAct(h_dim*8, h_dim*8, 'sn', 'up', activation='relu', normalization='bn'),
ConvNormAct(h_dim*8, h_dim*4, 'sn', 'up', activation='relu', normalization='bn'),
ConvNormAct(h_dim*4, h_dim*2, 'sn', 'up', activation='relu', normalization='bn'),
SA_Conv2d(h_dim*2),
Expand Down

0 comments on commit 39807d2

Please sign in to comment.