@@ -14,6 +14,7 @@ def __init__(self, img_channels, h_dim, img_size):
14
14
SA_Conv2d (h_dim * 4 ),
15
15
ConvNormAct (h_dim * 4 , h_dim * 8 , 'sn' , 'down' , activation = 'lrelu' , normalization = 'bn' ),
16
16
ConvNormAct (h_dim * 8 , h_dim * 8 , 'sn' , 'down' , activation = 'lrelu' , normalization = 'bn' ),
17
+ ConvNormAct (h_dim * 8 , h_dim * 8 , 'sn' , 'down' , activation = 'lrelu' , normalization = 'bn' ),
17
18
nn .AdaptiveAvgPool2d (1 ),
18
19
)
19
20
self .in_features = h_dim * 8
@@ -29,13 +30,14 @@ def forward(self, x):
29
30
class Generator (nn .Module ):
30
31
def __init__ (self , h_dim , z_dim , img_channels , img_size ):
31
32
super ().__init__ ()
32
- self .min_hw = (img_size // (2 ** 5 )) ** 2
33
+ self .min_hw = (img_size // (2 ** 6 )) ** 2
33
34
self .h_dim = h_dim
34
35
self .project = SN_Linear (in_features = z_dim , out_features = h_dim * 8 * self .min_hw ** 2 , bias = False )
35
36
self .gen = nn .Sequential (
36
37
nn .BatchNorm2d (h_dim * 8 , momentum = 0.9 ),
37
38
nn .ReLU (),
38
39
ConvNormAct (h_dim * 8 , h_dim * 8 , 'sn' , 'up' , activation = 'relu' , normalization = 'bn' ),
40
+ ConvNormAct (h_dim * 8 , h_dim * 8 , 'sn' , 'up' , activation = 'relu' , normalization = 'bn' ),
39
41
ConvNormAct (h_dim * 8 , h_dim * 4 , 'sn' , 'up' , activation = 'relu' , normalization = 'bn' ),
40
42
ConvNormAct (h_dim * 4 , h_dim * 2 , 'sn' , 'up' , activation = 'relu' , normalization = 'bn' ),
41
43
SA_Conv2d (h_dim * 2 ),