-
Notifications
You must be signed in to change notification settings - Fork 5
/
pggan.py
443 lines (386 loc) · 19.8 KB
/
pggan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
# -*- coding: utf-8 -*-
from __future__ import print_function
import os, sys
from math import floor, ceil
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
import torchvision.transforms as transforms
from network import *
from config import config
import dataloader as dl
import tf_recorder as tensorboard
import utils as utils
__author__ = 'Rahul Bhalley'
class PGGAN:
    '''Trainer for a progressively growing GAN.

    Owns the generator/discriminator pair, their optimizers, the data loader
    and the schedule that raises the working resolution `2 ** floor(resl)`
    while fading new layers in and flushing the fade-in blocks afterwards.
    '''

    # Loss formulations understood by `train_step`.
    _GAN_TYPES = ('standard', 'wgan', 'wgan-gp', 'lsgan', 'began',
                  'cgan', 'acgan', 'infogan')

    def __init__(self, config):
        '''Select the device, build both networks and prepare training state.

        Args:
            config: settings object. This class reads (at least) `nz`,
                `optimizer`, `lr`, `lr_decay`, `beta1`, `beta2`, `eps_drift`,
                `smoothing`, `max_resl`, `trns_tick`, `stab_tick`, `TICK`,
                `flag_add_noise`, `flag_add_drift`, `gan_type`, `random_seed`,
                `n_gpu`, `use_tb` and `save_img_every` from it.
        '''
        self.config = config
        self.use_cuda = False
        self.use_mps = False
        # Device preference order: MPS, then CUDA, then CPU.
        if torch.backends.mps.is_available():
            self.use_mps = True
            self.device = torch.device("mps")
            torch.set_default_device(self.device)
        elif torch.cuda.is_available():
            self.use_cuda = True
            self.device = torch.device("cuda")
            # NOTE(review): `set_default_tensor_type` is deprecated in recent
            # PyTorch releases; kept to preserve the existing defaults.
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            self.device = torch.device("cpu")
            torch.set_default_tensor_type('torch.FloatTensor')
        print(f"Training on device: {self.device}")

        self.nz = config.nz
        self.optimizer = config.optimizer
        self.resl = 2  # we start with resolution 2^2 = 4
        self.lr = config.lr
        self.eps_drift = config.eps_drift
        self.smoothing = config.smoothing
        self.max_resl = config.max_resl
        self.trns_tick = config.trns_tick
        self.stab_tick = config.stab_tick
        self.TICK = config.TICK
        self.global_iter = 0
        self.global_tick = 0
        self.kimgs = 0
        self.stack = 0
        self.epoch = 0
        self.fadein = {'gen': None, 'dis': None}
        self.complete = {'gen': 0, 'dis': 0}
        self.phase = 'init'
        self.flag_flush_gen = False
        self.flag_flush_dis = False
        self.flag_add_noise = self.config.flag_add_noise
        self.flag_add_drift = self.config.flag_add_drift
        self.gan_type = config.gan_type
        self.gp_lambda = 10  # Gradient penalty lambda for WGAN-GP
        # BEGAN equilibrium term; it was read by `train_step` without ever
        # being initialised. The BEGAN paper starts it at 0.
        self.k = 0.0

        # Network settings
        self.G = Generator(config)
        print('Generator architecture:\n{}'.format(self.G.model))
        self.D = Discriminator(config)
        print('Discriminator architecture:\n{}'.format(self.D.model))
        self.criterion = nn.MSELoss()
        if self.use_cuda or self.use_mps:
            self.criterion = self.criterion.to(self.device)
        if self.use_cuda:
            torch.cuda.manual_seed(config.random_seed)
            if config.n_gpu == 1:
                self.G = nn.DataParallel(self.G).to(self.device)
                self.D = nn.DataParallel(self.D).to(self.device)
            else:
                gpus = list(range(config.n_gpu))
                self.G = nn.DataParallel(self.G, device_ids=gpus).to(self.device)
                self.D = nn.DataParallel(self.D, device_ids=gpus).to(self.device)

        # Define tensors, ship model to device, and get dataloader
        self.renew_everything()

        # Tensorboard
        self.use_tb = config.use_tb
        if self.use_tb:
            self.tb = tensorboard.tf_recorder()

    @staticmethod
    def _unwrap(net):
        '''Return the network behind an `nn.DataParallel` wrapper, if any.

        `nn.DataParallel` does not forward custom attributes such as `model`,
        `grow_network` or `flush_network`, so they must be reached through
        `.module` once the networks have been wrapped on CUDA.
        '''
        return net.module if isinstance(net, nn.DataParallel) else net

    def resl_scheduler(self):
        '''
        This method will schedule image resolution `self.resl` progressively
        It should be called every iteration to ensure real value is updated properly
        Step 1. `trns_tick` -> transition in generator
        Step 2. `stab_tick` -> stabilize
        Step 3: `trns_tick` -> transition in discriminator
        Step 4: `stab_tick` -> stabilize
        '''
        if floor(self.resl) != 2:
            self.trns_tick = self.config.trns_tick
            self.stab_tick = self.config.stab_tick
        self.batchsize = self.loader.batchsize
        delta = 1.0 / (2 * self.trns_tick + 2 * self.stab_tick)
        d_alpha = 1.0 * self.batchsize / self.trns_tick / self.TICK

        # Boundaries of the four phases inside one resolution level, expressed
        # as thresholds on the fractional part of `resl`.
        gen_trns_end = self.trns_tick * delta
        gen_stab_end = (self.trns_tick + self.stab_tick) * delta
        dis_trns_end = (self.stab_tick + self.trns_tick * 2) * delta

        frac = self.resl % 1.0
        # Update `alpha` if fade-in layer exists in `Generator`
        if self.fadein['gen'] is not None:
            if frac < gen_trns_end:
                self.fadein['gen'].update_alpha(d_alpha)
                self.complete['gen'] = self.fadein['gen'].alpha * 100
                self.phase = 'gtrns'
            elif frac >= gen_trns_end and frac < gen_stab_end:
                self.phase = 'gstab'
        # Update `alpha` if fade-in layer exists in `Discriminator`
        if self.fadein['dis'] is not None:
            if frac >= gen_stab_end and frac < dis_trns_end:
                self.fadein['dis'].update_alpha(d_alpha)
                self.complete['dis'] = self.fadein['dis'].alpha * 100
                self.phase = 'dtrns'
            elif frac >= dis_trns_end and self.phase != 'final':
                self.phase = 'dstab'

        prev_kimgs = self.kimgs
        self.kimgs = self.kimgs + self.batchsize
        # A tick boundary was crossed when the image counter wrapped modulo TICK.
        if not (self.kimgs % self.TICK < prev_kimgs % self.TICK):
            return
        self.global_tick = self.global_tick + 1

        # Increase `resl` linearly every tick and
        # grow the network architecture
        prev_resl = floor(self.resl)
        self.resl = self.resl + delta
        self.resl = max(2, min(10.5, self.resl))  # clamping , range: 4 ~ 1024
        frac = self.resl % 1.0
        cur_resl = floor(self.resl)

        #
        # Flush the networks
        #
        if self.flag_flush_gen and frac >= gen_stab_end and prev_resl != 2:
            if self.fadein['gen'] is not None:
                self.fadein['gen'].update_alpha(d_alpha)
                self.complete['gen'] = self.fadein['gen'].alpha * 100
            self.flag_flush_gen = False
            generator = self._unwrap(self.G)
            generator.flush_network()  # flush Generator
            print('Generator flushed:\n{}'.format(generator.model))
            self.fadein['gen'] = None
            self.complete['gen'] = 0.0
            self.phase = 'dtrns'
        elif self.flag_flush_dis and cur_resl != prev_resl and prev_resl != 2:
            if self.fadein['dis'] is not None:
                self.fadein['dis'].update_alpha(d_alpha)
                self.complete['dis'] = self.fadein['dis'].alpha * 100
            self.flag_flush_dis = False
            discriminator = self._unwrap(self.D)
            discriminator.flush_network()  # flush Discriminator
            print('Discriminator flushed:\n{}'.format(discriminator.model))
            self.fadein['dis'] = None
            self.complete['dis'] = 0.0
            if cur_resl < self.max_resl and self.phase != 'final':
                self.phase = 'gtrns'

        #
        # Grow the networks
        #
        if cur_resl != prev_resl and cur_resl < self.max_resl + 1:
            self.lr = self.lr * float(self.config.lr_decay)
            self._unwrap(self.G).grow_network(cur_resl)
            self._unwrap(self.D).grow_network(cur_resl)
            # New layers mean new parameters: rebuild loader, buffers, optimizers.
            self.renew_everything()
            self.fadein['gen'] = self._unwrap(self.G).model.fadein_block
            self.fadein['dis'] = self._unwrap(self.D).model.fadein_block
            self.flag_flush_gen = True
            self.flag_flush_dis = True

        if cur_resl >= self.max_resl and frac >= dis_trns_end:
            self.phase = 'final'
            # Pin `resl` so the schedule stays in the last phase from now on.
            self.resl = self.max_resl + dis_trns_end

    def renew_everything(self):
        '''Renew the dataloader, the working tensors and both optimizers.

        Called once at construction and again every time the networks grow.

        Raises:
            ValueError: if `config.optimizer` is not `'adam'`, the only
                optimizer this trainer knows how to build.
        '''
        self.loader = dl.dataloader(self.config, self.device)
        self.loader.renew(min(floor(self.resl), self.max_resl))
        batch = self.loader.batchsize
        size = self.loader.imsize

        # Define tensors
        self.z = torch.FloatTensor(batch, self.nz).to(self.device)
        self.x = torch.FloatTensor(batch, 3, size, size).to(self.device)
        self.x_tilde = torch.FloatTensor(batch, 3, size, size).to(self.device)
        self.real_label = torch.FloatTensor(batch).fill_(1).to(self.device)
        self.fake_label = torch.FloatTensor(batch).fill_(0).to(self.device)

        # Enable device
        if self.use_cuda:
            # Use the instance's config rather than the module-level global.
            torch.cuda.manual_seed(self.config.random_seed)

        # Wrapping `autograd.Variable`
        self.x = self.x.requires_grad_()
        self.x_tilde = self.x_tilde.requires_grad_()
        self.z = self.z.requires_grad_()
        self.real_label = self.real_label.requires_grad_()
        self.fake_label = self.fake_label.requires_grad_()

        # Ship new model to device
        self.G = self.G.to(self.device)
        self.D = self.D.to(self.device)

        # Setup the optimizer
        if self.optimizer != 'adam':
            # Previously the optimizers were silently left undefined and the
            # first `train_step` failed with an AttributeError.
            raise ValueError('unsupported optimizer: {!r}'.format(self.optimizer))
        if self.gan_type == 'wgan' or self.gan_type == 'wgan-gp':
            self.opt_g = Adam(self.G.parameters(), lr=self.lr, betas=(0.0, 0.9))
            self.opt_d = Adam(self.D.parameters(), lr=self.lr, betas=(0.0, 0.9))
        else:
            betas = (self.config.beta1, self.config.beta2)
            self.opt_g = Adam(self.G.parameters(), lr=self.lr, betas=betas, weight_decay=0.0)
            self.opt_d = Adam(self.D.parameters(), lr=self.lr, betas=betas, weight_decay=0.0)

    def feed_interpolated_input(self, x):
        '''Blend real images with a down/up-sampled copy during generator fade-in.

        Outside the `'gtrns'` phase the batch is returned unchanged (only
        moved to `self.device`).

        Args:
            x: batch of real images; the code maps it from [-1, 1] to [0, 1]
                before the PIL round trip, so that range is assumed.

        Returns:
            The (possibly interpolated) batch on `self.device`.
        '''
        if self.phase == 'gtrns' and floor(self.resl) > 2 and floor(self.resl) <= self.max_resl:
            alpha = self.complete['gen'] / 100.0
            level = floor(self.resl)
            # Nearest-neighbour (interpolation=0) shrink to the previous
            # resolution and back up to the current one.
            transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize(size=int(pow(2, level - 1)), interpolation=0),
                transforms.Resize(size=int(pow(2, level)), interpolation=0),
                transforms.ToTensor(),
            ])
            x_low = x.clone().add(1).mul(0.5)
            for i in range(x_low.size(0)):
                x_low[i] = transform(x_low[i]).mul(2).add(-1)
            x = torch.add(x.mul(alpha), x_low.mul(1 - alpha))  # interpolated_x
        # `.to` returns the same tensor when it already lives on the device.
        return x.to(self.device)

    def add_noise(self, x):
        '''Add Gaussian noise whose strength follows the discriminator output.

        The strength is `0.2 * max(0, d - 0.5) ** 2`, where `d` is a running
        average (decay 0.9) of the mean of `self.fx_tilde`; the first call
        only initialises `d` to 0, so it adds zero-strength noise.

        Args:
            x: input batch.

        Returns:
            `x` itself when `flag_add_noise` is off, otherwise `x + noise`.
        '''
        if not self.flag_add_noise:
            return x
        if hasattr(self, '_d_'):
            self._d_ = self._d_ * 0.9 + torch.mean(self.fx_tilde).item() * 0.1
        else:
            self._d_ = 0.0
        strength = 0.2 * max(0, self._d_ - 0.5) ** 2
        z = np.random.randn(*x.size()).astype(np.float32) * strength
        z = torch.from_numpy(z).to(self.device)
        return x + z

    def compute_gradient_penalty(self, real_samples, fake_samples):
        '''Return the WGAN-GP penalty `mean((||grad D(x_hat)||_2 - 1) ** 2)`.

        `x_hat` is a per-sample random interpolation between the real and the
        fake batch.

        Args:
            real_samples: batch of real images.
            fake_samples: batch of generated images of the same shape.

        Returns:
            Scalar tensor, differentiable w.r.t. the discriminator parameters.
        '''
        alpha = torch.rand(real_samples.size(0), 1, 1, 1).to(self.device)
        interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
        d_interpolates = self.D(interpolates)
        grad_outputs = torch.ones(d_interpolates.size()).to(self.device)
        # `create_graph=True` keeps the penalty itself differentiable. The
        # deprecated (and ignored) `only_inputs` argument is no longer passed.
        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=grad_outputs,
            create_graph=True,
            retain_graph=True,
        )[0]
        gradients = gradients.reshape(gradients.size(0), -1)
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    def _discriminator_loss(self):
        '''Run D on real/fake batches and return `(loss_d, fake_labels)`.

        `fake_labels` is the class-label tensor drawn for the fake batch for
        `'cgan'`/`'acgan'` and `None` for every other GAN type.
        '''
        if self.gan_type in ('cgan', 'acgan'):
            batch = self.loader.batchsize
            # NOTE(review): labels for the real batch are drawn at random here
            # rather than taken from the dataset - confirm this is intended.
            real_labels = torch.LongTensor(batch).random_(0, self.config.n_classes).to(self.device)
            fake_labels = torch.LongTensor(batch).random_(0, self.config.n_classes).to(self.device)
            self.fx, real_aux = self.D(self.x, real_labels)
            self.fx_tilde, fake_aux = self.D(self.x_tilde.detach(), fake_labels)
            loss_d = self.criterion(self.fx, self.real_label) + self.criterion(self.fx_tilde, self.fake_label)
            if self.gan_type == 'acgan':
                # `nn.functional` is used because `F` is not imported by this file.
                loss_d = loss_d + nn.functional.cross_entropy(real_aux, real_labels) \
                    + nn.functional.cross_entropy(fake_aux, fake_labels)
            return loss_d, fake_labels

        if self.gan_type == 'infogan':
            self.fx, _ = self.D(self.x)
            self.fx_tilde, _ = self.D(self.x_tilde.detach())
            loss_d = self.criterion(self.fx, self.real_label) + self.criterion(self.fx_tilde, self.fake_label)
            return loss_d, None

        # Remaining types share a single unconditional forward pass.
        self.fx = self.D(self.x)
        self.fx_tilde = self.D(self.x_tilde.detach())
        if self.gan_type == 'standard':
            loss_d = self.criterion(self.fx, self.real_label) + self.criterion(self.fx_tilde, self.fake_label)
        elif self.gan_type == 'wgan':
            loss_d = torch.mean(self.fx_tilde) - torch.mean(self.fx)
            if self.flag_add_drift:
                loss_d = loss_d + self.config.lambda_drift * torch.mean(self.fx ** 2)
        elif self.gan_type == 'wgan-gp':
            loss_d = torch.mean(self.fx_tilde) - torch.mean(self.fx)
            gradient_penalty = self.compute_gradient_penalty(self.x.data, self.x_tilde.data)
            # Fall back to the instance default when the config has no weight.
            lambda_gp = getattr(self.config, 'lambda_gp', self.gp_lambda)
            loss_d = loss_d + lambda_gp * gradient_penalty
        elif self.gan_type == 'lsgan':
            loss_d = 0.5 * torch.mean((self.fx - 1) ** 2) + 0.5 * torch.mean(self.fx_tilde ** 2)
        else:  # 'began': D acts as an auto-encoder, losses are reconstruction errors
            loss_real = torch.mean(torch.abs(self.x - self.fx))
            loss_fake = torch.mean(torch.abs(self.x_tilde.detach() - self.fx_tilde))
            loss_d = loss_real - self.k * loss_fake
            # BEGAN balance update: k <- k + lambda_k * (gamma * L(x) - L(G(z))),
            # clamped to [0, 1].
            self.k = self.k + self.config.lambda_k * (self.config.gamma * loss_real.item() - loss_fake.item())
            self.k = max(min(1, self.k), 0)
        return loss_d, None

    def _generator_loss(self, fake_labels):
        '''Run D on the (non-detached) fake batch and return the generator loss.

        Args:
            fake_labels: labels returned by `_discriminator_loss`; only used
                for `'cgan'`/`'acgan'`.
        '''
        if self.gan_type in ('cgan', 'acgan'):
            self.fx_tilde, fake_aux = self.D(self.x_tilde, fake_labels)
            loss_g = self.criterion(self.fx_tilde, self.real_label.detach())
            if self.gan_type == 'acgan':
                loss_g = loss_g + nn.functional.cross_entropy(fake_aux, fake_labels)
            return loss_g

        if self.gan_type == 'infogan':
            self.fx_tilde, recon_c = self.D(self.x_tilde)
            loss_g = self.criterion(self.fx_tilde, self.real_label.detach())
            # NOTE(review): `self.c` (the latent code) is not assigned anywhere
            # in this class - it must be provided before using 'infogan'.
            loss_info = nn.functional.mse_loss(recon_c, self.c)
            return loss_g + self.config.lambda_info * loss_info

        self.fx_tilde = self.D(self.x_tilde)
        if self.gan_type == 'standard':
            return self.criterion(self.fx_tilde, self.real_label.detach())
        if self.gan_type in ('wgan', 'wgan-gp'):
            return -torch.mean(self.fx_tilde)
        if self.gan_type == 'lsgan':
            return 0.5 * torch.mean((self.fx_tilde - 1) ** 2)
        # 'began'
        return torch.mean(torch.abs(self.x_tilde - self.fx_tilde))

    def train_step(self):
        '''Perform one discriminator update followed by one generator update.

        Returns:
            Tuple `(loss_d, loss_g)` of Python floats.

        Raises:
            ValueError: if `self.gan_type` is not one of `_GAN_TYPES`.
        '''
        if self.gan_type not in self._GAN_TYPES:
            # Previously this surfaced as an UnboundLocalError on `loss_d`.
            raise ValueError('unsupported gan_type: {!r}'.format(self.gan_type))

        self.G.zero_grad()
        self.D.zero_grad()

        # Update discriminator
        self.x.data = self.feed_interpolated_input(self.loader.get_batch())
        if self.flag_add_noise:
            self.x = self.add_noise(self.x)
        self.z.data.resize_(self.loader.batchsize, self.nz).normal_(0.0, 1.0)
        self.x_tilde = self.G(self.z)
        loss_d, fake_labels = self._discriminator_loss()
        loss_d.backward()
        self.opt_d.step()
        if self.gan_type == 'wgan':
            # Weight clipping as used by the original WGAN formulation.
            for p in self.D.parameters():
                p.data.clamp_(-0.01, 0.01)

        # Update generator
        self.G.zero_grad()
        loss_g = self._generator_loss(fake_labels)
        loss_g.backward()
        self.opt_g.step()
        return loss_d.item(), loss_g.item()

    def train(self):
        '''Run the full progressive training loop.

        Every iteration advances the resolution schedule, performs one
        training step, logs the losses, attempts a snapshot and periodically
        writes preview images (and TensorBoard summaries when enabled).
        '''
        # noise for test
        self.z_test = torch.FloatTensor(self.loader.batchsize, self.nz).to(self.device)
        self.z_test = self.z_test.requires_grad_(False)
        self.z_test.normal_(0.0, 1.0)

        for _step in range(0, self.max_resl + 1 + 5):
            images_per_level = (self.trns_tick * 2 + self.stab_tick * 2) * self.TICK
            # `_` instead of `iter`, which shadowed the builtin.
            for _ in tqdm(range(0, images_per_level, self.loader.batchsize)):
                self.global_iter = self.global_iter + 1
                self.stack = self.stack + self.loader.batchsize
                dataset_size = ceil(len(self.loader.dataset))
                if self.stack > dataset_size:
                    self.epoch = self.epoch + 1
                    self.stack = int(self.stack % dataset_size)

                # Resolution scheduler
                self.resl_scheduler()

                # Train step
                loss_d, loss_g = self.train_step()

                # Log information
                log_msg = ' [E:{0}][T:{1}][{2:6}/{3:6}] errD: {4:.4f} | errG: {5:.4f} | [lr:{11:.5f}][cur:{6:.3f}][resl:{7:4}][{8}][{9:.1f}%][{10:.1f}%]'.format(
                    self.epoch, self.global_tick, self.stack, len(self.loader.dataset), loss_d, loss_g, self.resl, int(pow(2, floor(self.resl))), self.phase, self.complete['gen'], self.complete['dis'], self.lr)
                tqdm.write(log_msg)

                # Save the model
                self.snapshot('./repo/model')

                # Save the image grid
                if self.global_iter % self.config.save_img_every == 0:
                    # Preview only: no autograd graph is needed.
                    with torch.no_grad():
                        x_test = self.G(self.z_test)
                    index = int(self.global_iter / self.config.save_img_every)
                    level = int(floor(self.resl))
                    os.makedirs('repo/save/grid', exist_ok=True)
                    utils.save_image_grid(x_test.data, 'repo/save/grid/{}_{}_G{}_D{}.jpg'.format(index, self.phase, self.complete['gen'], self.complete['dis']))
                    os.makedirs('repo/save/resl_{}'.format(level), exist_ok=True)
                    utils.save_image_single(x_test.data, 'repo/save/resl_{}/{}_{}_G{}_D{}.jpg'.format(level, index, self.phase, self.complete['gen'], self.complete['dis']))

                # Tensorboard visualization
                if self.use_tb:
                    with torch.no_grad():
                        x_test = self.G(self.z_test)
                    self.tb.add_scalar('data/loss_g', loss_g, self.global_iter)
                    self.tb.add_scalar('data/loss_d', loss_d, self.global_iter)
                    self.tb.add_scalar('tick/lr', self.lr, self.global_iter)
                    self.tb.add_scalar('tick/cur_resl', int(pow(2, floor(self.resl))), self.global_iter)
                    self.tb.add_image_grid('grid/x_test', 4, utils.adjust_dyn_range(x_test.data.float(), [-1, 1], [0, 1]), self.global_iter)
                    self.tb.add_image_grid('grid/x_tilde', 4, utils.adjust_dyn_range(self.x_tilde.data.float(), [-1, 1], [0, 1]), self.global_iter)
                    self.tb.add_image_grid('grid/x_intp', 4, utils.adjust_dyn_range(self.x.data.float(), [-1, 1], [0, 1]), self.global_iter)

    def get_state(self, target):
        '''Return a checkpoint dictionary for one of the two networks.

        Args:
            target: `'gen'` for the generator or `'dis'` for the discriminator.

        Returns:
            Dict with keys `'resl'`, `'state_dict'` and `'optimizer'`.

        Raises:
            ValueError: for any other `target` (previously `None` was
                returned silently and would have been written to disk).
        '''
        if target == 'gen':
            net, opt = self.G, self.opt_g
        elif target == 'dis':
            net, opt = self.D, self.opt_d
        else:
            raise ValueError("target must be 'gen' or 'dis', got {!r}".format(target))
        return {
            'resl': self.resl,
            'state_dict': net.state_dict(),
            'optimizer': opt.state_dict(),
        }

    def snapshot(self, path):
        '''Write generator/discriminator checkpoints into `path`.

        A checkpoint pair is written only when the tick counter is a multiple
        of 50, the phase is `'gstab'`, `'dstab'` or `'final'`, and the
        discriminator file for this resolution/tick does not exist yet (so
        repeated calls within the same tick do not rewrite it).

        Args:
            path: target directory; created if missing.
        '''
        os.makedirs(path, exist_ok=True)

        # Save every 50 ticks if the network is in a stable phase
        if self.global_tick % 50 != 0:
            return
        if self.phase not in ('gstab', 'dstab', 'final'):
            return
        level = int(floor(self.resl))
        ndis = 'dis_R{}_T{}.pth.tar'.format(level, self.global_tick)
        ngen = 'gen_R{}_T{}.pth.tar'.format(level, self.global_tick)
        save_path = os.path.join(path, ndis)
        if not os.path.exists(save_path):
            torch.save(self.get_state('dis'), save_path)
            save_path = os.path.join(path, ngen)
            torch.save(self.get_state('gen'), save_path)
            print('[snapshot] model saved @ {}'.format(path))

    def evaluate(self):
        '''Placeholder; evaluation is not implemented.'''
        pass

    def test_growth(self):
        '''Debug helper: grow the generator to level 3, flush it and print it.'''
        generator = self._unwrap(self.G)
        generator.grow_network(3)
        generator.flush_network()
        print(generator.model)
# Script entry: echo every configuration entry, then train a PGGAN with it.
print('Configuration')
settings = vars(config)
for k in settings:
    v = settings[k]
    print('{}: {}'.format(k, v))

# Let cuDNN benchmark and pick convolution algorithms (boosts the speed).
torch.backends.cudnn.benchmark = True

pggan = PGGAN(config)
pggan.train()