1
+ {
2
+ "nbformat" : 4 ,
3
+ "nbformat_minor" : 0 ,
4
+ "metadata" : {
5
+ "colab" : {
6
+ "name" : " Untitled5.ipynb" ,
7
+ "version" : " 0.3.2" ,
8
+ "provenance" : [],
9
+ "include_colab_link" : true
10
+ },
11
+ "kernelspec" : {
12
+ "name" : " python3" ,
13
+ "display_name" : " Python 3"
14
+ }
15
+ },
16
+ "cells" : [
17
+ {
18
+ "cell_type" : " markdown" ,
19
+ "metadata" : {
20
+ "id" : " view-in-github" ,
21
+ "colab_type" : " text"
22
+ },
23
+ "source" : [
24
+ " <a href=\" https://colab.research.google.com/github/ArghyaPal/Adversarial-Data-Programming/blob/master/main.ipynb\" target=\" _parent\" ><img src=\" https://colab.research.google.com/assets/colab-badge.svg\" alt=\" Open In Colab\" /></a>"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type" : " code" ,
29
+ "metadata" : {
30
+ "id" : " WAMXDQvIclNx" ,
31
+ "colab_type" : " code" ,
32
+ "colab" : {}
33
+ },
34
+ "source" : [
35
+ " import argparse\n " ,
36
+ " import os\n " ,
37
+ " import numpy as np\n " ,
38
+ " import math\n " ,
39
+ " import scipy\n " ,
40
+ " import itertools\n " ,
41
+ " \n " ,
42
+ " import torchvision.transforms as transforms\n " ,
43
+ " from torchvision.utils import save_image\n " ,
44
+ " \n " ,
45
+ " from torch.utils.data import DataLoader\n " ,
46
+ " from torchvision import datasets\n " ,
47
+ " from torch.autograd import Variable\n " ,
48
+ " \n " ,
49
+ " import torch.nn as nn\n " ,
50
+ " import torch.nn.functional as F\n " ,
51
+ " import torch\n " ,
52
+ " \n " ,
53
+ " from model import *"
54
+ ],
55
+ "execution_count" : 0 ,
56
+ "outputs" : []
57
+ },
58
+ {
59
+ "cell_type" : " code" ,
60
+ "metadata" : {
61
+ "id" : " bM5UAfXMcpiQ" ,
62
+ "colab_type" : " code" ,
63
+ "colab" : {}
64
+ },
65
+ "source" : [
66
+ " os.makedirs(\" images\" , exist_ok=True)\n " ,
67
+ " \n " ,
68
+ " parser = argparse.ArgumentParser()\n " ,
69
+ " parser.add_argument(\" --n_epochs\" , type=int, default=200, help=\" number of epochs of training\" )\n " ,
70
+ " parser.add_argument(\" --batch_size\" , type=int, default=32, help=\" size of the batches\" )\n " ,
71
+ " parser.add_argument(\" --lr\" , type=float, default=0.0002, help=\" adam: learning rate\" )\n " ,
72
+ " parser.add_argument(\" --b1\" , type=float, default=0.5, help=\" adam: decay of first order momentum of gradient\" )\n " ,
73
+ " parser.add_argument(\" --b2\" , type=float, default=0.999, help=\" adam: decay of first order momentum of gradient\" )\n " ,
74
+ " parser.add_argument(\" --n_cpu\" , type=int, default=8, help=\" number of cpu threads to use during batch generation\" )\n " ,
75
+ " parser.add_argument(\" --latent_dim\" , type=int, default=100, help=\" dimensionality of the latent space\" )\n " ,
76
+ " parser.add_argument(\" --img_size\" , type=int, default=32, help=\" size of each image dimension\" )\n " ,
77
+ " parser.add_argument(\" --channels\" , type=int, default=3, help=\" number of image channels\" )\n " ,
78
+ " parser.add_argument(\" --sample_interval\" , type=int, default=400, help=\" interval betwen image samples\" )\n " ,
79
+ " opt = parser.parse_args()\n " ,
80
+ " print(opt) "
81
+ ],
82
+ "execution_count" : 0 ,
83
+ "outputs" : []
84
+ },
85
+ {
86
+ "cell_type" : " code" ,
87
+ "metadata" : {
88
+ "id" : " UQsuGDgBcteY" ,
89
+ "colab_type" : " code" ,
90
+ "colab" : {}
91
+ },
92
+ "source" : [
93
+ " img_shape = (opt.channels, opt.img_size, opt.img_size)\n " ,
94
+ " \n " ,
95
+ " cuda = True if torch.cuda.is_available() else False\n " ,
96
+ " \n " ,
97
+ " \n " ,
98
+ " # Loss functions\n " ,
99
+ " adversarial_loss = torch.nn.MSELoss()\n " ,
100
+ " \n " ,
101
+ " # Initialize generator and discriminator\n " ,
102
+ " generator = Generator()\n " ,
103
+ " discriminator = Discriminator()\n " ,
104
+ " \n " ,
105
+ " if cuda:\n " ,
106
+ " generator.cuda()\n " ,
107
+ " discriminator.cuda()\n " ,
108
+ " adversarial_loss.cuda()\n " ,
109
+ " \n " ,
110
+ " # Configure data loader\n " ,
111
+ " os.makedirs(\" ./data/mnist\" , exist_ok=True)\n " ,
112
+ " dataloader = torch.utils.data.DataLoader(\n " ,
113
+ " datasets.MNIST(\n " ,
114
+ " \" ./data/mnist\" ,\n " ,
115
+ " train=True,\n " ,
116
+ " download=True,\n " ,
117
+ " transform=transforms.Compose(\n " ,
118
+ " [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]\n " ,
119
+ " ),\n " ,
120
+ " ),\n " ,
121
+ " batch_size=opt.batch_size,\n " ,
122
+ " shuffle=True,\n " ,
123
+ " )\n " ,
124
+ " \n " ,
125
+ " # Optimizers\n " ,
126
+ " optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))\n " ,
127
+ " optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))\n " ,
128
+ " \n " ,
129
+ " FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor\n " ,
130
+ " LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor\n " ,
131
+ " \n " ,
132
+ " \n " ,
133
+ " def sample_image(n_row, batches_done):\n " ,
134
+ " \"\"\" Saves a grid of generated digits ranging from 0 to n_classes\"\"\"\n " ,
135
+ " # Sample noise\n " ,
136
+ " z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))\n " ,
137
+ " # Get labels ranging from 0 to n_classes for n rows\n " ,
138
+ " labels = np.array([num for _ in range(n_row) for num in range(n_row)])\n " ,
139
+ " labels = Variable(LongTensor(labels))\n " ,
140
+ " gen_imgs, gen_labels = generator(z)\n " ,
141
+ " save_image(gen_imgs.data, \" images/%d.png\" % batches_done, nrow=n_row, normalize=True)"
142
+ ],
143
+ "execution_count" : 0 ,
144
+ "outputs" : []
145
+ },
146
+ {
147
+ "cell_type" : " code" ,
148
+ "metadata" : {
149
+ "id" : " jLtn-jdwcxFY" ,
150
+ "colab_type" : " code" ,
151
+ "colab" : {}
152
+ },
153
+ "source" : [
154
+ " \n " ,
155
+ " # ----------\n " ,
156
+ " # Training\n " ,
157
+ " # ----------\n " ,
158
+ " \n " ,
159
+ " for epoch in range(opt.n_epochs):\n " ,
160
+ " for i, (imgs, labels) in enumerate(dataloader):\n " ,
161
+ " \n " ,
162
+ " batch_size = imgs.shape[0]\n " ,
163
+ " \n " ,
164
+ " # Adversarial ground truths\n " ,
165
+ " valid = Variable(Tensor(batch_size, 1).fill_(1.0), requires_grad=False)\n " ,
166
+ " fake = Variable(Tensor(batch_size, 1).fill_(0.0), requires_grad=False)\n " ,
167
+ " \n " ,
168
+ " # Configure input\n " ,
169
+ " imgs = Variable(imgs.type(Tensor))\n " ,
170
+ " \n " ,
171
+ " # ------------------\n " ,
172
+ " # Train Generators\n " ,
173
+ " # ------------------\n " ,
174
+ " \n " ,
175
+ " optimizer_G.zero_grad()\n " ,
176
+ " \n " ,
177
+ " # Sample noise as generator input\n " ,
178
+ " z = Variable(Tensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))\n " ,
179
+ " \n " ,
180
+ " # Generate a batch of images\n " ,
181
+ " gen_imgs, gen_labels = coupled_generators(z)\n " ,
182
+ " \n " ,
183
+ " # Determine validity of generated images\n " ,
184
+ " validity = coupled_discriminators(gen_imgs, gen_labels)\n " ,
185
+ " \n " ,
186
+ " g_loss = adversarial_loss(validity, valid)\n " ,
187
+ " \n " ,
188
+ " g_loss.backward()\n " ,
189
+ " optimizer_G.step()\n " ,
190
+ " \n " ,
191
+ " # ----------------------\n " ,
192
+ " # Train Discriminators\n " ,
193
+ " # ----------------------\n " ,
194
+ " \n " ,
195
+ " optimizer_D.zero_grad()\n " ,
196
+ " \n " ,
197
+ " # Loss for real images\n " ,
198
+ " validity_real = discriminator(real_imgs, labels)\n " ,
199
+ " d_real_loss = adversarial_loss(validity_real, valid)\n " ,
200
+ " \n " ,
201
+ " # Loss for fake images\n " ,
202
+ " validity_fake = discriminator(gen_imgs.detach(), gen_labels)\n " ,
203
+ " d_fake_loss = adversarial_loss(validity_fake, fake)\n " ,
204
+ " \n " ,
205
+ " # Total discriminator loss\n " ,
206
+ " d_loss = (d_real_loss + d_fake_loss) / 2\n " ,
207
+ " \n " ,
208
+ " d_loss.backward()\n " ,
209
+ " optimizer_D.step()\n " ,
210
+ " \n " ,
211
+ " print(\n " ,
212
+ " \" [Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]\"\n " ,
213
+ " % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())\n " ,
214
+ " )\n " ,
215
+ " \n " ,
216
+ " batches_done = epoch * len(dataloader) + i\n " ,
217
+ " if batches_done % opt.sample_interval == 0:\n " ,
218
+ " sample_image(n_row=10, batches_done=batches_done)"
219
+ ],
220
+ "execution_count" : 0 ,
221
+ "outputs" : []
222
+ }
223
+ ]
224
+ }
0 commit comments