projected_model.py
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: fs_model_fix_idnorm_donggp_saveoptim copy.py
# Created Date: Wednesday January 12th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 21st April 2022 8:13:37 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import torch
import torch.nn as nn

from .base_model import BaseModel
from .fs_networks_fix import Generator_Adain_Upsample

from pg_modules.projected_discriminator import ProjectedDiscriminator


def compute_grad2(d_out, x_in):
    # R1 regularizer: per-sample squared L2 norm of the gradient of the
    # discriminator output with respect to its input.
    batch_size = x_in.size(0)
    grad_dout = torch.autograd.grad(
        outputs=d_out.sum(), inputs=x_in,
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    grad_dout2 = grad_dout.pow(2)
    assert grad_dout2.size() == x_in.size()
    reg = grad_dout2.view(batch_size, -1).sum(1)
    return reg
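
# Usage sketch (an assumption based on typical projected-GAN training loops,
# not code from this file): compute_grad2 is normally applied to the
# discriminator's output on real images during the D step, e.g.
#
#     img_real.requires_grad_()
#     real_pred = netD(img_real, None)
#     loss_r1 = compute_grad2(real_pred, img_real).mean() * lambda_r1
#
# where netD, img_real and lambda_r1 are illustrative names.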

class fsModel(BaseModel):
    def name(self):
        return 'fsModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        # if opt.resize_or_crop != 'none' or not opt.isTrain:  # when training at full res this causes OOM
        self.isTrain = opt.isTrain

        # Generator network
        self.netG = Generator_Adain_Upsample(input_nc=3, output_nc=3, latent_size=512,
                                             n_blocks=9, deep=opt.Gdeep)
        self.netG.cuda()

        # Id network (ArcFace identity encoder, kept frozen and in eval mode)
        netArc_checkpoint = opt.Arc_path
        netArc_checkpoint = torch.load(netArc_checkpoint, map_location=torch.device("cpu"))
        self.netArc = netArc_checkpoint['model'].module
        self.netArc = self.netArc.cuda()
        self.netArc.eval()
        self.netArc.requires_grad_(False)
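
        # Note (an assumption drawn from the SimSwap training script, not from
        # this file): self.netArc is only used to extract identity embeddings,
        # typically by resizing face crops to 112x112 and L2-normalizing the
        # output, e.g.
        #     id_112 = F.interpolate(img_id, size=(112, 112), mode='bicubic')
        #     latent_id = F.normalize(self.netArc(id_112), p=2, dim=1)
        # (F stands for torch.nn.functional, which this module does not import
        # itself.)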

        if not self.isTrain:
            pretrained_path = opt.checkpoints_dir
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            return

        self.netD = ProjectedDiscriminator(diffaug=False, interp224=False, **{})
        # self.netD.feature_network.requires_grad_(False)
        self.netD.cuda()

        if self.isTrain:
            # define loss functions
            self.criterionFeat = nn.L1Loss()
            self.criterionRec = nn.L1Loss()

            # initialize optimizers
            # optimizer G
            params = list(self.netG.parameters())
            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.99), eps=1e-8)

            # optimizer D
            params = list(self.netD.parameters())
            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.99), eps=1e-8)

        # load networks
        if opt.continue_train:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            # print(pretrained_path)
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
            self.load_optim(self.optimizer_G, 'G', opt.which_epoch, pretrained_path)
            self.load_optim(self.optimizer_D, 'D', opt.which_epoch, pretrained_path)
        torch.cuda.empty_cache()
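
    # Training overview (an assumption based on the accompanying SimSwap
    # training script; this class only holds the networks, losses and
    # optimizers): a typical iteration generates img_fake = netG(img_target,
    # latent_id), scores real and fake images with netD, and combines a GAN
    # loss, a feature-matching L1 loss (criterionFeat), a reconstruction L1
    # loss (criterionRec) when source and target share the same identity, an
    # identity loss built from cosin_metric below, and the R1 penalty from
    # compute_grad2 on real images for the discriminator step.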

    def cosin_metric(self, x1, x2):
        # return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))
        return torch.sum(x1 * x2, dim=1) / (torch.norm(x1, dim=1) * torch.norm(x2, dim=1))
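
    # Usage sketch (assumed from the training script, not defined here): the
    # identity loss is typically
    #     loss_G_ID = (1 - self.cosin_metric(latent_fake, latent_id)).mean()
    # where both arguments are L2-normalized ArcFace embeddings, so the
    # expression is simply their cosine similarity per sample.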

    def save(self, which_epoch):
        self.save_network(self.netG, 'G', which_epoch)
        self.save_network(self.netD, 'D', which_epoch)
        self.save_optim(self.optimizer_G, 'G', which_epoch)
        self.save_optim(self.optimizer_D, 'D', which_epoch)
        '''if self.gen_features:
            self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)'''

    def update_fixed_params(self):
        # after fixing the global generator for a number of iterations, also start finetuning it
        # Note: gen_features and netE are never set in this file; this method
        # appears to be carried over from the pix2pixHD-style template.
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())
        self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
        if self.opt.verbose:
            print('------------ Now also finetuning global generator -----------')

    def update_learning_rate(self):
        # linear decay: subtract opt.lr / opt.niter_decay on every call.
        # self.old_lr is expected to be set externally (e.g. to opt.lr);
        # it is not initialized in initialize().
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
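
# ---------------------------------------------------------------------------
# Minimal, self-contained sanity check added for illustration (not part of the
# original SimSwap file). It exercises compute_grad2 with a toy linear
# "discriminator" on CPU; fsModel itself is not constructed here because it
# requires a CUDA device, an ArcFace checkpoint and the full training options.
# Because this module uses relative imports, run it with `python -m` from the
# repository root (the package name is assumed to be `models`).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    toy_d = nn.Linear(3 * 8 * 8, 1)                  # stand-in discriminator
    x = torch.randn(4, 3, 8, 8, requires_grad=True)  # toy batch of images
    out = toy_d(x.view(4, -1))                       # one score per sample
    r1 = compute_grad2(out, x)                       # per-sample R1 penalty
    print(r1.shape)                                  # -> torch.Size([4])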