Skip to content

Commit 39807d2

Browse files
committed
increased model capacity
1 parent c7d4a19 commit 39807d2

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

sagan/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def __init__(self, img_channels, h_dim, img_size):
1414
SA_Conv2d(h_dim*4),
1515
ConvNormAct(h_dim*4, h_dim*8, 'sn', 'down', activation='lrelu', normalization='bn'),
1616
ConvNormAct(h_dim*8, h_dim*8, 'sn', 'down', activation='lrelu', normalization='bn'),
17+
ConvNormAct(h_dim*8, h_dim*8, 'sn', 'down', activation='lrelu', normalization='bn'),
1718
nn.AdaptiveAvgPool2d(1),
1819
)
1920
self.in_features = h_dim*8
@@ -29,13 +30,14 @@ def forward(self, x):
2930
class Generator(nn.Module):
3031
def __init__(self, h_dim, z_dim, img_channels, img_size):
3132
super().__init__()
32-
self.min_hw = (img_size // (2 ** 5)) ** 2
33+
self.min_hw = (img_size // (2 ** 6)) ** 2
3334
self.h_dim = h_dim
3435
self.project = SN_Linear(in_features=z_dim, out_features=h_dim*8 * self.min_hw ** 2, bias=False)
3536
self.gen = nn.Sequential(
3637
nn.BatchNorm2d(h_dim*8, momentum=0.9),
3738
nn.ReLU(),
3839
ConvNormAct(h_dim*8, h_dim*8, 'sn', 'up', activation='relu', normalization='bn'),
40+
ConvNormAct(h_dim*8, h_dim*8, 'sn', 'up', activation='relu', normalization='bn'),
3941
ConvNormAct(h_dim*8, h_dim*4, 'sn', 'up', activation='relu', normalization='bn'),
4042
ConvNormAct(h_dim*4, h_dim*2, 'sn', 'up', activation='relu', normalization='bn'),
4143
SA_Conv2d(h_dim*2),

0 commit comments

Comments
 (0)