Commit

changed model output size
Andrewzh112 committed Dec 14, 2020
1 parent e4826c8 commit 9e9a79e
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions sagan/model.py
@@ -22,21 +22,21 @@ def __init__(self, img_channels, h_dim, img_size):
         initialize_modules(self, init_type='ortho')
 
     def forward(self, x):
-        x = self.disc(x)
-        x = x.view(-1, self.in_features)
-        return self.fc(x)
+        x = self.disc(x)  # (bs, ch, h, w) -> (bs, h_dim*8, 1, 1)
+        x = x.squeeze()  # (bs, h_dim*8, 1, 1) -> (bs, h_dim*8)
+        return self.fc(x)  # (bs, h_dim*8) -> (bs, 1)
 
 
 class Generator(nn.Module):
     def __init__(self, h_dim, z_dim, img_channels, img_size):
         super().__init__()
-        self.min_hw = (img_size // (2 ** 6)) ** 2
+        self.min_hw = (img_size // (2 ** 6))
         self.h_dim = h_dim
-        self.project = SN_Linear(in_features=z_dim, out_features=h_dim*8 * self.min_hw ** 2, bias=False)
+        self.project = SN_Linear(in_features=z_dim, out_features=h_dim*8 * self.min_hw**2, bias=False)
         self.gen = nn.Sequential(
             nn.BatchNorm2d(h_dim*8, momentum=0.9),
             nn.ReLU(),
-            ConvNormAct(h_dim*8, h_dim*8, 'sn', None, activation='relu', normalization='bn'),
+            ConvNormAct(h_dim*8, h_dim*8, 'sn', 'up', activation='relu', normalization='bn'),
             ConvNormAct(h_dim*8, h_dim*8, 'sn', 'up', activation='relu', normalization='bn'),
             ConvNormAct(h_dim*8, h_dim*4, 'sn', 'up', activation='relu', normalization='bn'),
             ConvNormAct(h_dim*4, h_dim*2, 'sn', 'up', activation='relu', normalization='bn'),
@@ -50,6 +50,10 @@ def __init__(self, h_dim, z_dim, img_channels, img_size):
 
     def forward(self, x):
         batch_size = x.size(0)
-        x = self.project(x)
-        x = x.view(batch_size, self.h_dim*8, self.min_hw, self.min_hw)
-        return self.gen(x)
+        x = self.project(x)  # (bs, z_dim) -> (bs, h_dim*8*hw_min**2)
+        x = x.view(
+            batch_size,
+            self.h_dim*8,
+            self.min_hw,
+            self.min_hw)  # (bs, h_dim*8*hw_min**2) -> (bs, h_dim*8, hw_min, hw_min)
+        return self.gen(x)  # (bs, h_dim*8, hw_min, hw_min) -> (bs, ch, w, h)
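
For context, below is a minimal shape sketch of the reshaping logic this commit settles on. It is not the repository's code: SN_Linear is approximated with nn.utils.spectral_norm around a plain nn.Linear, the ConvNormAct(..., 'up', ...) blocks are approximated with nearest-neighbour upsampling, and img_size=64, h_dim=64, z_dim=128, bs=4 are illustrative values; the 2 ** 6 factor in min_hw suggests six spatial doublings back up to img_size.

    import torch
    import torch.nn as nn

    # Illustrative hyperparameters (assumptions, not taken from the repo's config).
    img_size, h_dim, z_dim, bs = 64, 64, 128, 4

    # Generator side: project the latent, then reshape to a (min_hw x min_hw) feature map.
    min_hw = img_size // (2 ** 6)                    # 64 // 64 = 1
    project = nn.utils.spectral_norm(                # stand-in for SN_Linear
        nn.Linear(z_dim, h_dim * 8 * min_hw ** 2, bias=False))

    z = torch.randn(bs, z_dim)
    x = project(z)                                   # (bs, z_dim) -> (bs, h_dim*8*min_hw**2)
    x = x.view(bs, h_dim * 8, min_hw, min_hw)        # -> (bs, h_dim*8, min_hw, min_hw)

    # Six doublings (implied by the 2 ** 6 factor) recover the full resolution;
    # nearest-neighbour upsampling stands in for the 'up' ConvNormAct blocks.
    for _ in range(6):
        x = nn.functional.interpolate(x, scale_factor=2)
    assert x.shape[-2:] == (img_size, img_size)

    # Discriminator side: the (bs, h_dim*8, 1, 1) feature map is squeezed before
    # the final linear layer, replacing the old view(-1, self.in_features).
    feat = torch.randn(bs, h_dim * 8, 1, 1)
    fc = nn.utils.spectral_norm(nn.Linear(h_dim * 8, 1))
    out = fc(feat.squeeze())                         # (bs, h_dim*8) -> (bs, 1)
    print(out.shape)                                 # torch.Size([4, 1])

Note that tensor.squeeze() with no dim argument also removes a batch dimension of 1, so the shapes above assume a batch size greater than one.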
