# fs_model.py (forked from neuralchen/SimSwap)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from torch.autograd import Variable
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
from .fs_networks import Generator_Adain_Upsample, Discriminator


class SpecificNorm(nn.Module):
    def __init__(self, epsilon=1e-8):
        """
        @notice: avoid in-place ops.
        https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
        """
        super(SpecificNorm, self).__init__()
        # ImageNet channel-wise mean/std, shaped [1, 3, 1, 1] for NCHW broadcasting
        self.mean = np.array([0.485, 0.456, 0.406])
        self.mean = torch.from_numpy(self.mean).float().cuda()
        self.mean = self.mean.view([1, 3, 1, 1])

        self.std = np.array([0.229, 0.224, 0.225])
        self.std = torch.from_numpy(self.std).float().cuda()
        self.std = self.std.view([1, 3, 1, 1])

    def forward(self, x):
        mean = self.mean.expand([1, 3, x.shape[2], x.shape[3]])
        std = self.std.expand([1, 3, x.shape[2], x.shape[3]])
        x = (x - mean) / std
        return x
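
# A minimal usage sketch for SpecificNorm (illustrative only; assumes a CUDA
# device, since the buffers above are created with .cuda()). It applies the
# standard ImageNet normalization to an NCHW batch before the ArcFace pass:
#
#   sp_norm = SpecificNorm()
#   imgs = torch.rand(4, 3, 112, 112).cuda()   # values in [0, 1]
#   normed = sp_norm(imgs)                     # (x - mean) / std per channel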


class fsModel(BaseModel):
    def name(self):
        return 'fsModel'

    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
        # flags select which of the eight losses are active; the GAN feature
        # matching and VGG losses can be disabled from the options
        flags = (True, use_gan_feat_loss, use_vgg_loss, True, True, True, True, True)

        def loss_filter(g_gan, g_gan_feat, g_vgg, g_id, g_rec, d_gp, d_real, d_fake):
            return [l for (l, f) in zip((g_gan, g_gan_feat, g_vgg, g_id, g_rec, d_gp, d_real, d_fake), flags) if f]
        return loss_filter
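
    # Illustrative note: loss_filter drops entries whose flag is False, so the
    # loss names and values below stay aligned. For example, with --no_vgg_loss:
    #
    #   f = self.init_loss_filter(use_gan_feat_loss=True, use_vgg_loss=False)
    #   f('G_GAN', 'G_GAN_Feat', 'G_VGG', 'G_ID', 'G_Rec', 'D_GP', 'D_real', 'D_fake')
    #   # -> ['G_GAN', 'G_GAN_Feat', 'G_ID', 'G_Rec', 'D_GP', 'D_real', 'D_fake']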

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != 'none' or not opt.isTrain:  # when training at full res this causes OOM
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        device = torch.device("cuda:0")

        # Generator network
        self.netG = Generator_Adain_Upsample(input_nc=3, output_nc=3, latent_size=512, n_blocks=9, deep=False)
        self.netG.to(device)

        # Id network
        netArc_checkpoint = opt.Arc_path
        netArc_checkpoint = torch.load(netArc_checkpoint)
        self.netArc = netArc_checkpoint['model'].module
        self.netArc = self.netArc.to(device)
        self.netArc.eval()

        if not self.isTrain:
            pretrained_path = ''  # fall back to the default checkpoints dir inside load_network
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            return

        # Discriminator networks
        use_sigmoid = opt.gan_mode == 'original'
        self.netD1 = Discriminator(input_nc=3, use_sigmoid=use_sigmoid)
        self.netD2 = Discriminator(input_nc=3, use_sigmoid=use_sigmoid)
        self.netD1.to(device)
        self.netD2.to(device)

        self.spNorm = SpecificNorm()
        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

        # load networks
        if opt.continue_train or opt.load_pretrain:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            self.load_network(self.netD1, 'D1', opt.which_epoch, pretrained_path)
            self.load_network(self.netD2, 'D2', opt.which_epoch, pretrained_path)

        if self.isTrain:
            # define loss functions
            self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)

            self.criterionGAN = networks.GANLoss(opt.gan_mode, tensor=self.Tensor, opt=self.opt)
            self.criterionFeat = nn.L1Loss()
            self.criterionRec = nn.L1Loss()

            # names so we can break out each loss when logging
            self.loss_names = self.loss_filter('G_GAN', 'G_GAN_Feat', 'G_VGG', 'G_ID', 'G_Rec', 'D_GP',
                                               'D_real', 'D_fake')

            # initialize optimizers
            # optimizer G
            params = list(self.netG.parameters())
            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

            # optimizer D
            params = list(self.netD1.parameters()) + list(self.netD2.parameters())
            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
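
    # Hedged construction sketch (TrainOptions and the exact option names follow
    # the pix2pixHD-style options this repo uses; treat them as assumptions):
    #
    #   opt = TrainOptions().parse()   # supplies Arc_path, gan_mode, lr, beta1, ...
    #   model = fsModel()
    #   model.initialize(opt)          # builds netG, netArc, and netD1/netD2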

    def _gradient_penalty_D(self, netD, img_att, img_fake):
        # interpolate sample
        bs = img_fake.shape[0]
        alpha = torch.rand(bs, 1, 1, 1).expand_as(img_fake).cuda()
        interpolated = Variable(alpha * img_att + (1 - alpha) * img_fake, requires_grad=True)
        pred_interpolated = netD.forward(interpolated)
        pred_interpolated = pred_interpolated[-1]

        # compute gradients
        grad = torch.autograd.grad(outputs=pred_interpolated,
                                   inputs=interpolated,
                                   grad_outputs=torch.ones(pred_interpolated.size()).cuda(),
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        # penalize gradients
        grad = grad.view(grad.size(0), -1)
        grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
        loss_d_gp = torch.mean((grad_l2norm - 1) ** 2)
        return loss_d_gp
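
    # This is the standard WGAN-GP penalty (Gulrajani et al., 2017). In forward()
    # below, loss_D_GP is left at 0 and this helper is never called, so wiring it
    # in is up to the training script. A hedged sketch of one way to use it:
    #
    #   loss_D_GP = self._gradient_penalty_D(self.netD1, img_att, img_fake.detach())
    #   loss_D = loss_D_real + loss_D_fake + 10.0 * loss_D_GP   # 10.0 is the usual weight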

    def cosine_metric(self, x1, x2):
        # row-wise cosine similarity between two batches of embeddings
        return torch.sum(x1 * x2, dim=1) / (torch.norm(x1, dim=1) * torch.norm(x2, dim=1))
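
    # Illustrative example: for L2-normalized ArcFace embeddings this reduces to
    # a per-row dot product.
    #
    #   a = F.normalize(torch.randn(4, 512), dim=1)
    #   b = F.normalize(torch.randn(4, 512), dim=1)
    #   sim = self.cosine_metric(a, b)   # shape (4,), values in [-1, 1]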

    def forward(self, img_id, img_att, latent_id, latent_att, for_G=False):
        loss_D_fake, loss_D_real, loss_D_GP = 0, 0, 0
        loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_G_ID, loss_G_Rec = 0, 0, 0, 0, 0

        img_fake = self.netG.forward(img_att, latent_id)
        if not self.isTrain:
            return img_fake

        img_fake_downsample = self.downsample(img_fake)
        img_att_downsample = self.downsample(img_att)

        # D_Fake
        fea1_fake = self.netD1.forward(img_fake.detach())
        fea2_fake = self.netD2.forward(img_fake_downsample.detach())
        pred_fake = [fea1_fake, fea2_fake]
        loss_D_fake = self.criterionGAN(pred_fake, False, for_discriminator=True)

        # D_Real
        fea1_real = self.netD1.forward(img_att)
        fea2_real = self.netD2.forward(img_att_downsample)
        pred_real = [fea1_real, fea2_real]
        fea_real = [fea1_real, fea2_real]
        loss_D_real = self.criterionGAN(pred_real, True, for_discriminator=True)

        # D_GP (the gradient penalty is not applied here; see _gradient_penalty_D above)
        loss_D_GP = 0

        # G_GAN
        fea1_fake = self.netD1.forward(img_fake)
        fea2_fake = self.netD2.forward(img_fake_downsample)
        pred_fake = [fea1_fake, fea2_fake]
        fea_fake = [fea1_fake, fea2_fake]
        loss_G_GAN = self.criterionGAN(pred_fake, True, for_discriminator=False)

        # GAN feature matching loss
        n_layers_D = 4
        num_D = 2
        if not self.opt.no_ganFeat_loss:
            feat_weights = 4.0 / (n_layers_D + 1)
            D_weights = 1.0 / num_D
            for i in range(num_D):
                for j in range(0, len(fea_fake[i]) - 1):
                    loss_G_GAN_Feat += D_weights * feat_weights * \
                        self.criterionFeat(fea_fake[i][j],
                                           fea_real[i][j].detach()) * self.opt.lambda_feat

        # G_ID
        img_fake_down = F.interpolate(img_fake, scale_factor=0.5)
        img_fake_down = self.spNorm(img_fake_down)
        latent_fake = self.netArc(img_fake_down)
        loss_G_ID = (1 - self.cosine_metric(latent_fake, latent_id))

        # G_Rec
        loss_G_Rec = self.criterionRec(img_fake, img_att) * self.opt.lambda_rec

        # only return the fake image when necessary, to save bandwidth
        return [self.loss_filter(loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_G_ID, loss_G_Rec, loss_D_GP, loss_D_real, loss_D_fake),
                img_fake]
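
    # A hedged sketch of how a training script might consume the return value
    # (the dict(zip(...)) pattern follows pix2pixHD; the exact weighting of
    # loss_G is an assumption, not something this file prescribes):
    #
    #   losses, img_fake = model.forward(img_id, img_att, latent_id, latent_att)
    #   loss_dict = dict(zip(model.loss_names, losses))
    #   loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat', 0) \
    #            + loss_dict['G_ID'].mean() + loss_dict['G_Rec']
    #   model.optimizer_G.zero_grad(); loss_G.backward(); model.optimizer_G.step()
    #   # the D step is analogous, from loss_dict['D_real'] and loss_dict['D_fake']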

    def save(self, which_epoch):
        self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
        self.save_network(self.netD1, 'D1', which_epoch, self.gpu_ids)
        self.save_network(self.netD2, 'D2', which_epoch, self.gpu_ids)
        '''if self.gen_features:
            self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)'''

    def update_fixed_params(self):
        # after fixing the global generator for a number of iterations, also start finetuning it
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())
        self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
        if self.opt.verbose:
            print('------------ Now also finetuning global generator -----------')

    def update_learning_rate(self):
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
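
# Linear decay illustration (assumes pix2pixHD-style defaults opt.lr=0.0002 and
# opt.niter_decay=100; the real values come from the training options): each
# call to update_learning_rate() subtracts lr/niter_decay, so the rate reaches
# zero after exactly niter_decay calls:
#
#   0.0002 -> 0.000198 -> 0.000196 -> ... -> 0.0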