forked from pytorch/opacus
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdcgan.py
356 lines (313 loc) · 10.8 KB
/
dcgan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Runs DCGAN training with differential privacy.
"""
from __future__ import print_function
import argparse
import os
import random
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from opacus import PrivacyEngine
from tqdm import tqdm
# Command-line configuration. ArgumentDefaultsHelpFormatter appends each
# option's default value to its --help text.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--data-root", required=True, help="path to dataset")
parser.add_argument(
    "--workers", type=int, help="number of data loading workers", default=2
)
parser.add_argument("--batch-size", type=int, default=64, help="input batch size")
parser.add_argument(
    "--imageSize",
    type=int,
    default=64,
    help="the height / width of the input image to network",
)
parser.add_argument("--nz", type=int, default=100, help="size of the latent z vector")
parser.add_argument(
    "--ngf", type=int, default=128, help="base number of generator feature maps"
)
parser.add_argument(
    "--ndf", type=int, default=128, help="base number of discriminator feature maps"
)
parser.add_argument(
    "--epochs", type=int, default=25, help="number of epochs to train for"
)
parser.add_argument(
    "--lr", type=float, default=0.0002, help="learning rate, default=0.0002"
)
parser.add_argument(
    "--beta1", type=float, default=0.5, help="beta1 for adam. default=0.5"
)
parser.add_argument("--ngpu", type=int, default=1, help="number of GPUs to use")
parser.add_argument("--netG", default="", help="path to netG (to continue training)")
parser.add_argument("--netD", default="", help="path to netD (to continue training)")
parser.add_argument(
    "--outf", default=".", help="folder to output images and model checkpoints"
)
parser.add_argument("--manualSeed", type=int, help="manual seed")
parser.add_argument(
    "--target-digit",
    type=int,
    default=8,
    # MNIST labels are 0-9; any other value would filter the training set
    # down to zero samples, so reject it at parse time.
    choices=range(10),
    help="the target digit(0~9) for MNIST training",
)
parser.add_argument(
    "--device",
    type=str,
    default="cuda",
    help="torch device to train on, e.g. 'cuda', 'cuda:1' or 'cpu'",
)
parser.add_argument(
    "--disable-dp",
    action="store_true",
    default=False,
    help="Disable privacy training and just train with vanilla Adam",
)
parser.add_argument(
    "--secure-rng",
    action="store_true",
    default=False,
    help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
)
# NOTE(review): --n-runs is not read anywhere in this script.
parser.add_argument(
    "-r",
    "--n-runs",
    type=int,
    default=1,
    metavar="R",
    help="number of runs to average on (default: 1)",
)
parser.add_argument(
    "--sigma",
    type=float,
    default=1.0,
    metavar="S",
    help="Noise multiplier (default 1.0)",
)
parser.add_argument(
    "-c",
    "--max-per-sample-grad_norm",
    type=float,
    default=1.0,
    metavar="C",
    help="Clip per-sample gradients to this norm (default 1.0)",
)
parser.add_argument(
    "--delta",
    type=float,
    default=1e-5,
    metavar="D",
    help="Target delta (default: 1e-5)",
)
opt = parser.parse_args()
# Output directory for sample images and checkpoints. exist_ok makes an
# existing directory a no-op, while real failures (e.g. permissions) surface
# here instead of being silently swallowed.
os.makedirs(opt.outf, exist_ok=True)

# Seed Python's and torch's RNGs. When --manualSeed is absent a seed is drawn
# and printed so the run can be reproduced.
if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
cudnn.benchmark = True

try:
    dataset = dset.MNIST(
        root=opt.data_root,
        download=True,
        transform=transforms.Compose(
            [
                transforms.Resize(opt.imageSize),
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ]
        ),
    )
except ValueError:
    # Re-raise: continuing would leave `dataset` undefined and fail later
    # with an unrelated NameError.
    print("Cannot load dataset")
    raise

# Train on a single class: keep only samples labelled --target-digit.
idx = dataset.targets == opt.target_digit
dataset.targets = dataset.targets[idx]
dataset.data = dataset.data[idx]
if len(dataset) == 0:
    raise SystemExit(f"No MNIST samples found for target digit {opt.target_digit}")
nc = 1  # number of image channels fed to the networks (MNIST is grayscale)

dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=int(opt.workers),
    batch_size=opt.batch_size,
    # Reshuffle every epoch for non-private training. With DP enabled this
    # loader is handed to privacy_engine.make_private(), which returns its
    # own loader (presumably with Poisson sampling — confirm).
    shuffle=True,
)
device = torch.device(opt.device)
ngpu = int(opt.ngpu)
nz = int(opt.nz)
ngf = int(opt.ngf)
ndf = int(opt.ndf)
# custom weights initialization called on netG and netD
def weights_init(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find("BatchNorm") != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
class Generator(nn.Module):
    """DCGAN generator mapping latent vectors to ``nc``-channel images.

    Layer widths are read from the module-level ``nz``, ``ngf`` and ``nc`` at
    construction time. Following the shape comments below, an input of shape
    (N, nz, 1, 1) is upsampled to (N, nc, 64, 64); the final Tanh keeps pixel
    values in (-1, 1).

    Args:
        ngpu: number of GPUs; when > 1 and the input is on CUDA, ``forward``
            runs ``self.main`` through ``nn.parallel.data_parallel``.
    """

    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        # Each GroupNorm normalises the ngf-multiple channels emitted by the
        # ConvTranspose2d right before it (sizing them with ndf broke any run
        # where --ngf != --ndf).
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.GroupNorm(min(32, ngf * 8), ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.GroupNorm(min(32, ngf * 4), ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.GroupNorm(min(32, ngf * 2), ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.GroupNorm(min(32, ngf), ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        """Generate images from latent noise ``input``."""
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output
# Build the generator on the training device, apply the custom weight init,
# and optionally resume from a saved state_dict given by --netG.
netG = Generator(ngpu)
netG = netG.to(device)
netG.apply(weights_init)
if opt.netG != "":
    # map_location remaps tensors saved on another device (e.g. a GPU
    # checkpoint resumed with --device cpu) onto the training device.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    netG.load_state_dict(torch.load(opt.netG, map_location=device))
class Discriminator(nn.Module):
    """DCGAN discriminator scoring ``nc``-channel images as real (1) or fake (0).

    Channel widths come from the module-level ``nc`` and ``ndf``. Four stride-2
    convolutions halve the spatial size (64 -> 4 per the shape comments), a
    final 4x4 convolution reduces to one logit per image, and a Sigmoid maps it
    to a probability.

    Args:
        ngpu: number of GPUs; when > 1 and the input is on CUDA, ``forward``
            runs ``self.main`` through ``nn.parallel.data_parallel``.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu
        # input is (nc) x 64 x 64 -> (ndf) x 32 x 32; no normalisation on the first layer
        stack = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Three downsampling stages, each doubling the channel count:
        # ndf -> 2*ndf (16x16), 2*ndf -> 4*ndf (8x8), 4*ndf -> 8*ndf (4x4).
        for width_mult in (1, 2, 4):
            in_channels = ndf * width_mult
            out_channels = in_channels * 2
            stack += [
                nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False),
                nn.GroupNorm(min(32, out_channels), out_channels),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # (ndf*8) x 4 x 4 -> one probability per image
        stack += [
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*stack)

    def forward(self, input):
        """Return a flat tensor of real/fake probabilities for ``input``."""
        multi_gpu = input.is_cuda and self.ngpu > 1
        if not multi_gpu:
            scores = self.main(input)
        else:
            scores = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        # For 64x64 inputs the scores are (N, 1, 1, 1); flatten to (N,).
        return scores.view(-1, 1).squeeze(1)
# Build the discriminator on the training device, apply the custom weight
# init, and optionally resume from a saved state_dict given by --netD.
netD = Discriminator(ngpu)
netD = netD.to(device)
netD.apply(weights_init)
if opt.netD != "":
    # map_location remaps tensors saved on another device (e.g. a GPU
    # checkpoint resumed with --device cpu) onto the training device.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    netD.load_state_dict(torch.load(opt.netD, map_location=device))
# Binary cross-entropy on the discriminator's sigmoid outputs, with targets
# REAL_LABEL / FAKE_LABEL.
criterion = nn.BCELoss()
# One fixed latent batch, reused every time sample images are written so the
# saved samples are comparable over training.
FIXED_NOISE = torch.randn(opt.batch_size, nz, 1, 1, device=device)
REAL_LABEL, FAKE_LABEL = 1.0, 0.0

# setup optimizers: both networks use Adam with the same lr and betas
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

if not opt.disable_dp:
    # Only the discriminator is fed real data, so only netD, its optimizer and
    # the real-data loader are wrapped; the generator trains with plain Adam.
    privacy_engine = PrivacyEngine(secure_mode=opt.secure_rng)
    netD, optimizerD, dataloader = privacy_engine.make_private(
        module=netD,
        optimizer=optimizerD,
        data_loader=dataloader,
        noise_multiplier=opt.sigma,
        max_grad_norm=opt.max_per_sample_grad_norm,
    )
# Main training loop. Every batch takes two discriminator steps (fake batch,
# then real batch) followed by one generator step; sample images are written
# every 100 batches and checkpoints at the end of every epoch.
for epoch in range(opt.epochs):
    data_bar = tqdm(dataloader)
    for i, data in enumerate(data_bar, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        optimizerD.zero_grad()
        real_data = data[0].to(device)
        batch_size = real_data.size(0)

        # train with fake
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label_fake = torch.full((batch_size,), FAKE_LABEL, device=device)
        # detach() keeps this discriminator loss from backpropagating into netG.
        output = netD(fake.detach())
        errD_fake = criterion(output, label_fake)
        errD_fake.backward()
        optimizerD.step()
        optimizerD.zero_grad()
        # Mean D(G(z)) as scored before the discriminator's fake-batch step.
        D_G_z1 = output.mean().item()

        # train with real
        label_true = torch.full((batch_size,), REAL_LABEL, device=device)
        output = netD(real_data)
        errD_real = criterion(output, label_true)
        errD_real.backward()
        optimizerD.step()
        D_x = output.mean().item()
        errD = errD_real + errD_fake

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        optimizerG.zero_grad()
        # errG.backward() also writes gradients into netD; start that pass from
        # cleared discriminator gradients.
        optimizerD.zero_grad()
        label_g = torch.full((batch_size,), REAL_LABEL, device=device)
        output_g = netD(fake)
        errG = criterion(output_g, label_g)
        errG.backward()
        # Mean D(G(z)) after both discriminator steps, from the generator pass.
        D_G_z2 = output_g.mean().item()
        optimizerG.step()
        data_bar.set_description(
            f"epoch: {epoch}, Loss_D: {errD.item()} "
            f"Loss_G: {errG.item()} D(x): {D_x} "
            f"D(G(z)): {D_G_z1}/{D_G_z2}"
        )

        if i % 100 == 0:
            vutils.save_image(
                real_data, "%s/real_samples.png" % opt.outf, normalize=True
            )
            # Sampling only: no autograd graph is needed.
            with torch.no_grad():
                fake = netG(FIXED_NOISE)
            vutils.save_image(
                fake.detach(),
                "%s/fake_samples_epoch_%03d.png" % (opt.outf, epoch),
                normalize=True,
            )

    if not opt.disable_dp:
        # Report privacy spent once per epoch instead of once per batch, so the
        # output does not fight the tqdm progress bar. %g keeps small deltas
        # such as 1e-05 readable (%.2f printed them as 0.00).
        # NOTE(review): get_epsilon() is assumed to be available on every Opacus
        # accountant, unlike get_privacy_spent() — confirm for the installed version.
        epsilon = privacy_engine.accountant.get_epsilon(delta=opt.delta)
        print("(ε = %.2f, δ = %g) after epoch %d" % (epsilon, opt.delta, epoch))

    # do checkpointing
    torch.save(netG.state_dict(), "%s/netG_epoch_%d.pth" % (opt.outf, epoch))
    # make_private() replaced netD with a wrapper; save the wrapped model so the
    # checkpoint reloads into a plain Discriminator via --netD.
    # NOTE(review): assumes Opacus's GradSampleModule exposes the wrapped model
    # as `_module` (Opacus 1.x); without DP getattr falls back to netD itself.
    netD_to_save = getattr(netD, "_module", netD)
    torch.save(netD_to_save.state_dict(), "%s/netD_epoch_%d.pth" % (opt.outf, epoch))