# pix2pix_model.py (forked from sangwoomo/instagan)
import torch
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks


class Pix2PixModel(BaseModel):
    def name(self):
        return 'Pix2PixModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        # changing the default values to match the pix2pix paper
        # (https://phillipi.github.io/pix2pix/)
        parser.set_defaults(norm='batch', netG='unet_256')
        parser.set_defaults(dataset_mode='aligned')
        if is_train:
            parser.set_defaults(pool_size=0, no_lsgan=True)
            parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')

        return parser
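
    # A hedged usage sketch: with the defaults set above, a typical training
    # invocation (assuming the parent repository's train.py entry point and a
    # facades-style aligned dataset; both are assumptions, not part of this
    # file) would look like:
    #
    #   python train.py --dataroot ./datasets/facades --name facades_pix2pix \
    #       --model pix2pix --direction BtoA
    #
    # This selects a batch-norm U-Net generator, the aligned dataset mode, the
    # vanilla (sigmoid) GAN loss, and no image pool, as configured above.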

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ['G', 'D']
        else:  # at test time, only load G
            self.model_names = ['G']
        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                      not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
                                          opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']
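
    # A minimal sketch of the batch dict set_input expects, as produced by an
    # aligned dataset loader (the shapes and paths below are illustrative
    # assumptions):
    #
    #   import torch
    #   batch = {
    #       'A': torch.randn(1, 3, 256, 256),   # source-domain image(s), NCHW
    #       'B': torch.randn(1, 3, 256, 256),   # paired target-domain image(s)
    #       'A_paths': ['datasets/facades/train/1.jpg'],
    #       'B_paths': ['datasets/facades/train/1.jpg'],
    #   }
    #   model.set_input(batch)  # with direction 'AtoB', real_A <- batch['A']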

    def forward(self):
        self.fake_B = self.netG(self.real_A)

    def backward_D(self):
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)
        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)
        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()
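
    # Note on backward_D above: fake_AB is detached so no gradient flows back
    # into the generator, and the summed loss is halved; per the pix2pix paper,
    # dividing the D objective by 2 slows how fast D learns relative to G.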

    def backward_G(self):
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)
        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()
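
    # backward_G implements the paper's full objective for the generator:
    #   G* = arg min_G max_D L_cGAN(G, D) + lambda * L_L1(G),  lambda_L1 = 100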

    def optimize_parameters(self):
        self.forward()
        # update D
        self.set_requires_grad(self.netD, True)
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()
        # update G
        self.set_requires_grad(self.netD, False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
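
# A hedged end-to-end sketch of the training loop this model plugs into. The
# plumbing below (TrainOptions, CreateDataLoader, create_model) assumes the
# surrounding repository's layout and is illustrative, not part of this module:
#
#   from options.train_options import TrainOptions
#   from data import CreateDataLoader
#   from models import create_model
#
#   opt = TrainOptions().parse()
#   dataset = CreateDataLoader(opt).load_data()
#   model = create_model(opt)
#   model.setup(opt)                     # register schedulers, load weights
#   for data in dataset:
#       model.set_input(data)            # unpack {'A', 'B', 'A_paths', 'B_paths'}
#       model.optimize_parameters()      # one D update, then one G update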