diff --git a/sagan/model.py b/sagan/model.py
index 2717b14..d0774e6 100644
--- a/sagan/model.py
+++ b/sagan/model.py
@@ -22,21 +22,21 @@ def __init__(self, img_channels, h_dim, img_size):
         initialize_modules(self, init_type='ortho')
 
     def forward(self, x):
-        x = self.disc(x)
-        x = x.view(-1, self.in_features)
-        return self.fc(x)
+        x = self.disc(x)  # (bs, ch, h, w) -> (bs, h_dim*8, 1, 1)
+        x = x.squeeze()  # (bs, h_dim*8, 1, 1) -> (bs, h_dim*8)
+        return self.fc(x)  # (bs, h_dim*8) -> (bs, 1)
 
 
 class Generator(nn.Module):
     def __init__(self, h_dim, z_dim, img_channels, img_size):
         super().__init__()
-        self.min_hw = (img_size // (2 ** 6)) ** 2
+        self.min_hw = (img_size // (2 ** 6))
         self.h_dim = h_dim
-        self.project = SN_Linear(in_features=z_dim, out_features=h_dim*8 * self.min_hw ** 2, bias=False)
+        self.project = SN_Linear(in_features=z_dim, out_features=h_dim*8 * self.min_hw**2, bias=False)
         self.gen = nn.Sequential(
             nn.BatchNorm2d(h_dim*8, momentum=0.9),
             nn.ReLU(),
-            ConvNormAct(h_dim*8, h_dim*8, 'sn', None, activation='relu', normalization='bn'),
+            ConvNormAct(h_dim*8, h_dim*8, 'sn', 'up', activation='relu', normalization='bn'),
             ConvNormAct(h_dim*8, h_dim*8, 'sn', 'up', activation='relu', normalization='bn'),
             ConvNormAct(h_dim*8, h_dim*4, 'sn', 'up', activation='relu', normalization='bn'),
             ConvNormAct(h_dim*4, h_dim*2, 'sn', 'up', activation='relu', normalization='bn'),
@@ -50,6 +50,10 @@ def __init__(self, h_dim, z_dim, img_channels, img_size):
 
     def forward(self, x):
         batch_size = x.size(0)
-        x = self.project(x)
-        x = x.view(batch_size, self.h_dim*8, self.min_hw, self.min_hw)
-        return self.gen(x)
+        x = self.project(x)  # (bs, z_dim) -> (bs, h_dim*8*min_hw**2)
+        x = x.view(
+            batch_size,
+            self.h_dim*8,
+            self.min_hw,
+            self.min_hw)  # (bs, h_dim*8*min_hw**2) -> (bs, h_dim*8, min_hw, min_hw)
+        return self.gen(x)  # (bs, h_dim*8, min_hw, min_hw) -> (bs, ch, h, w)
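The shape annotations above can be sanity-checked with a quick smoke test. The sketch below is hypothetical and not part of the patch: it assumes `Generator` is importable from `sagan.model` with the constructor signature shown in the diff, and that the stages below the visible hunk complete the full 2**6 upsampling, so `min_hw = img_size // 2**6` is scaled back to `img_size` at the output.

```python
# Hypothetical smoke test for the shape annotations in this patch; assumes
# Generator is importable from sagan.model with the signature shown above.
import torch
from sagan.model import Generator

h_dim, z_dim, img_channels, img_size = 64, 128, 3, 64  # min_hw = 64 // 2**6 = 1
gen = Generator(h_dim=h_dim, z_dim=z_dim, img_channels=img_channels, img_size=img_size)
gen.eval()  # use BatchNorm running stats for a deterministic check

z = torch.randn(4, z_dim)  # (bs, z_dim)
with torch.no_grad():
    fake = gen(z)          # expected: (bs, ch, h, w)

assert fake.shape == (4, img_channels, img_size, img_size), fake.shape
```

A separate caveat worth flagging in review: `x.squeeze()` in the discriminator's forward removes every size-1 dimension, so a batch of size 1 comes out as a 1-D tensor; `x.flatten(1)` would preserve the batch dimension in that edge case.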