Skip to content

Commit

Permalink
revise the loss weights
Browse files Browse the repository at this point in the history
  • Loading branch information
MingtaoGuo authored Oct 19, 2022
1 parent 4540572 commit c16fb7c
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions train_simple.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
'''
This is a simplified training code of StyleSwap. It achieves comparable performance as in the paper.
@Created by rosinality and yangxy
@Modified by Mingtao Guo (gmt798714378@hotmail.com)
'''
import argparse
Expand Down Expand Up @@ -151,7 +149,7 @@ def train(args, loader, generator, discriminator, arcface, vgg19, g_optim, d_opt
requires_grad(discriminator, True)

with torch.no_grad():
z_id = arcface(F.interpolate(source, [143, 143], mode="bilinear")[..., 15:127, 15:127])
z_id = arcface(F.interpolate(source, [112, 112], mode="bilinear"))
fake_img, _, _ = generator(target, z_id.detach())
fake_pred, _ = discriminator(fake_img.detach())

Expand Down Expand Up @@ -188,7 +186,7 @@ def train(args, loader, generator, discriminator, arcface, vgg19, g_optim, d_opt
# ------------- adv loss -------------
adv_loss = g_nonsaturating_loss(fake_pred)
# ------------- id loss --------------
fake_z_id = arcface(F.interpolate(fake_img, [143, 143], mode="bilinear")[..., 15:127, 15:127])
fake_z_id = arcface(F.interpolate(fake_img, [112, 112], mode="bilinear"))
id_loss = (1 - torch.cosine_similarity(fake_z_id, z_id.detach())).mean()
# -------feature matching loss -------
fm_loss = 0
Expand All @@ -203,11 +201,10 @@ def train(args, loader, generator, discriminator, arcface, vgg19, g_optim, d_opt
rec_loss = rec_pix + rec_lpip
# ------------ mask loss -------------
bce_loss = F.binary_cross_entropy(fake_mask, mask).mean()
# bce_loss = F.l1_loss(fake_mask, mask).mean()
# ----------- total loss -------------
g_loss = adv_loss + 20 * id_loss + bce_loss + 10 * fm_loss + 10 * rec_loss
g_loss = adv_loss + 20 * id_loss + bce_loss + 100 * fm_loss + 100 * rec_loss
loss_dict['g'] = g_loss
loss_dict['adv'] = adv_loss
loss_dict['adv'] = adv_loss
loss_dict['id'] = id_loss
loss_dict['fm'] = fm_loss
loss_dict['rec'] = rec_loss
Expand Down Expand Up @@ -276,7 +273,9 @@ def train(args, loader, generator, discriminator, arcface, vgg19, g_optim, d_opt
g_ema.eval()
sample, _, fake_mask = g_ema(target, z_id)
sample = sample * fake_mask + target * (1 - fake_mask)
sample = torch.cat((target, sample, (torch.cat([fake_mask, fake_mask, fake_mask], dim=1)-0.5)/0.5, source), 0)
fake_mask = (torch.cat([fake_mask, fake_mask, fake_mask], dim=1) - 0.5) / 0.5
mask = (torch.cat([mask, mask, mask], dim=1) - 0.5) / 0.5
sample = torch.cat((target, source, sample, fake_mask, mask), 0)
utils.save_image(
sample,
f'{args.sample}/{str(i).zfill(6)}.png',
Expand All @@ -302,8 +301,8 @@ def train(args, loader, generator, discriminator, arcface, vgg19, g_optim, d_opt

parser = argparse.ArgumentParser()

parser.add_argument('--img_path', type=str, default="/data1/GMT/Dataset/FFHQ256/")
parser.add_argument('--mask_path', type=str, default="/data1/GMT/Dataset/FFHQ256parsing/")
parser.add_argument('--img_path', type=str, default="/data1/GMT/Dataset/FFHQ256std_ldmk/")
parser.add_argument('--mask_path', type=str, default="/data1/GMT/Dataset/FFHQ256std_ldmk_mask/")
parser.add_argument('--base_dir', type=str, default='./')
parser.add_argument('--arcface', type=str, default='saved_models/backbone.pth')
parser.add_argument('--iter', type=int, default=4000000)
Expand Down

0 comments on commit c16fb7c

Please sign in to comment.