Skip to content

Commit

Permalink
modified model architecture for sagan
Browse files Browse the repository at this point in the history
  • Loading branch information
zsdonghao committed Dec 11, 2020
1 parent 42c7832 commit 1c11aaa
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 16 deletions.
17 changes: 10 additions & 7 deletions sagan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,31 @@ def __init__(self, img_channels, h_dim, img_size):
ConvNormAct(h_dim, h_dim*2, 'down', activation='lrelu', normalization='bn'),
ConvNormAct(h_dim*2, h_dim*4, 'down', activation='lrelu', normalization='bn'),
ConvNormAct(h_dim*4, h_dim*8, 'down', activation='lrelu', normalization='bn'),
nn.Flatten(),
nn.Linear(h_dim*8 * (img_size // (2 ** 4)) ** 2, 1)
nn.AdaptiveAvgPool2d(1),
)
self.in_features = h_dim*8
self.fc = nn.Linear(self.in_features, 1)
initialize_modules(self)

def forward(self, x):
return self.disc(x)
x = self.disc(x)
x = x.view(-1, self.in_features)
return self.fc(x)


class Generator(nn.Module):
def __init__(self, h_dim, z_dim, img_channels, img_size):
super().__init__()
self.s16 = img_size // 16
self.min_hw = (img_size // (2 ** 5)) ** 2
self.h_dim = h_dim
self.project = nn.Linear(z_dim, h_dim*8*self.s16*self.s16)
self.project = nn.Linear(z_dim, h_dim*8 * self.min_hw ** 2)
self.gen = nn.Sequential(
nn.BatchNorm2d(h_dim*8, momentum=0.9),
nn.ReLU(),
ConvNormAct(h_dim*8, h_dim*4, 'up', activation='relu', normalization='bn'),
ConvNormAct(h_dim*4, h_dim*2, 'up', activation='relu', normalization='bn'),
SA_Conv2d(h_dim*2),
ConvNormAct(h_dim*2, h_dim, 'up', activation='relu', normalization='bn'),
SA_Conv2d(h_dim),
nn.ConvTranspose2d(h_dim, img_channels, 4, 2, 1),
nn.Sigmoid()
)
Expand All @@ -41,5 +44,5 @@ def __init__(self, h_dim, z_dim, img_channels, img_size):
def forward(self, x):
batch_size = x.size(0)
x = self.project(x)
x = x.view(batch_size, self.h_dim*8, self.s16, self.s16)
x = x.view(batch_size, self.h_dim*8, self.min_hw, self.min_hw)
return self.gen(x)
18 changes: 9 additions & 9 deletions sagan/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

# logging parameters
parser.add_argument('--data_path', type=str, default='data/img_align_celeba', help='Path to where image data is located')
parser.add_argument('--cpt_interval', type=int, default=100, help='Checkpoint interval')
parser.add_argument('--cpt_interval', type=int, default=500, help='Checkpoint interval')
parser.add_argument('--save_local_samples', action="store_true", default=False, help='Whether to save samples locally')
parser.add_argument('--sample_size', type=int, default=64, help='Numbers of images to log')
parser.add_argument('--checkpoint_dir', type=str, default='sagan/checkpoint', help='Path to where model weights will be saved')
Expand All @@ -48,15 +48,15 @@


def train():
writer = SummaryWriter(opt.log_dir + f'/{int(datetime.now().timestamp()*1e6)}')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# creating dirs if needed
Path(opt.checkpoint_dir).mkdir(parents=True, exist_ok=True)
Path(opt.log_dir).mkdir(parents=True, exist_ok=True)
if opt.save_local_samples:
Path(opt.sample_dir).mkdir(parents=True, exist_ok=True)

writer = SummaryWriter(opt.log_dir + f'/{int(datetime.now().timestamp()*1e6)}')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

G = torch.nn.DataParallel(Generator(opt.h_dim, opt.z_dim, opt.img_channels, opt.img_size), device_ids=opt.devices).to(device)
D = torch.nn.DataParallel(Discriminator(opt.img_channels, opt.h_dim, opt.img_size), device_ids=opt.devices).to(device)

Expand Down Expand Up @@ -119,7 +119,7 @@ def train():
'Batch ID': batch_idx})

# tensorboard logging samples, not logging first iteration
if batch_idx % opt.cpt_interval == 0 and batch_idx != 0 and epoch != 0:
if batch_idx % opt.cpt_interval == 0:
ckpt_iter += 1
G.eval()
# generate image from fixed noise vector
Expand All @@ -132,6 +132,10 @@ def train():

# save sample and loss to tensorboard
writer.add_image('Generated Images', torchvision.utils.make_grid(samples), global_step=ckpt_iter)
writer.add_scalars("Train Losses", {
"Discriminator Loss": sum(d_losses) / len(d_losses),
"Generator Loss": sum(g_losses) / len(g_losses)
}, global_step=ckpt_iter)

# resetting
G.train()
Expand All @@ -142,10 +146,6 @@ def train():
Discriminator loss: {sum(d_losses) / len(d_losses):.3f}, \
Generator Loss: {sum(g_losses) / len(g_losses):.3f}'
)
writer.add_scalars("Train Losses", {
"Discriminator Loss": sum(d_losses) / len(d_losses),
"Generator Loss": sum(g_losses) / len(g_losses)
}, global_step=epoch)
torch.save({
'D': D.state_dict(),
'G': G.state_dict(),
Expand Down

0 comments on commit 1c11aaa

Please sign in to comment.