Skip to content

Commit 69bff56

Browse files
author
ItsArghya
committed
Created using Colaboratory
1 parent 9aa9899 commit 69bff56

File tree

1 file changed

+224
-0
lines changed

1 file changed

+224
-0
lines changed

main.ipynb

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
{
2+
"nbformat": 4,
3+
"nbformat_minor": 0,
4+
"metadata": {
5+
"colab": {
6+
"name": "Untitled5.ipynb",
7+
"version": "0.3.2",
8+
"provenance": [],
9+
"include_colab_link": true
10+
},
11+
"kernelspec": {
12+
"name": "python3",
13+
"display_name": "Python 3"
14+
}
15+
},
16+
"cells": [
17+
{
18+
"cell_type": "markdown",
19+
"metadata": {
20+
"id": "view-in-github",
21+
"colab_type": "text"
22+
},
23+
"source": [
24+
"<a href=\"https://colab.research.google.com/github/ArghyaPal/Adversarial-Data-Programming/blob/master/main.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
25+
]
26+
},
27+
{
28+
"cell_type": "code",
29+
"metadata": {
30+
"id": "WAMXDQvIclNx",
31+
"colab_type": "code",
32+
"colab": {}
33+
},
34+
"source": [
35+
# Standard-library imports.
import argparse
import itertools
import math
import os

# Third-party imports.
import numpy as np
import scipy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.utils import save_image

# Local import: supplies Generator and Discriminator (and possibly more).
# NOTE(review): the wildcard import hides which names actually come from
# `model`; prefer `from model import Generator, Discriminator` once the
# module's full public API is confirmed.
from model import *
54+
],
55+
"execution_count": 0,
56+
"outputs": []
57+
},
58+
{
59+
"cell_type": "code",
60+
"metadata": {
61+
"id": "bM5UAfXMcpiQ",
62+
"colab_type": "code",
63+
"colab": {}
64+
},
65+
"source": [
66+
# Directory where sample_image() writes generated image grids.
os.makedirs("images", exist_ok=True)

# Hyper-parameters, exposed as CLI flags so the same code also works when
# run as a plain script outside the notebook.
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=32, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image samples")
# parse_known_args() tolerates the extra arguments the Jupyter/Colab kernel
# injects into sys.argv (e.g. "-f <connection file>"), which would make a
# plain parse_args() abort with "unrecognized arguments".
opt, _ = parser.parse_known_args()
print(opt)
81+
],
82+
"execution_count": 0,
83+
"outputs": []
84+
},
85+
{
86+
"cell_type": "code",
87+
"metadata": {
88+
"id": "UQsuGDgBcteY",
89+
"colab_type": "code",
90+
"colab": {}
91+
},
92+
"source": [
93+
# Shape of a single image tensor: (channels, height, width).
img_shape = (opt.channels, opt.img_size, opt.img_size)

cuda = True if torch.cuda.is_available() else False


# Least-squares adversarial loss: MSE between the discriminator's validity
# score and the 1.0 (real) / 0.0 (fake) target.
adversarial_loss = torch.nn.MSELoss()

# Initialize generator and discriminator (architectures come from model.py).
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Configure data loader.
# NOTE(review): MNIST is single-channel and Normalize is given one mean/std,
# but --channels defaults to 3; confirm the channel count the model expects.
os.makedirs("./data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "./data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# Tensor constructors that land on the GPU when one is available.
FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor


def sample_image(n_row, batches_done):
    """Save an n_row x n_row grid of generated images to images/<batches_done>.png.

    The generator produces its own labels alongside the images (it is called
    as ``generator(z)`` and returns ``gen_imgs, gen_labels``), so no class
    labels are fed in — only latent noise is sampled.  The dead ``labels``
    array from the original (built but never used, a leftover from a CGAN
    template where ``generator(z, labels)`` was called) has been removed.
    """
    # Sample noise
    z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))
    gen_imgs, gen_labels = generator(z)
    save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)
142+
],
143+
"execution_count": 0,
144+
"outputs": []
145+
},
146+
{
147+
"cell_type": "code",
148+
"metadata": {
149+
"id": "jLtn-jdwcxFY",
150+
"colab_type": "code",
151+
"colab": {}
152+
},
153+
"source": [
154+
# ----------
#  Training
# ----------
# Alternates one generator update and one discriminator update (LSGAN/MSE
# objective) per mini-batch, printing losses and periodically dumping a
# 10x10 sample grid via sample_image().

for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):

        batch_size = imgs.shape[0]

        # Adversarial ground truths.  (Original used `Tensor`, which is
        # never defined in this notebook; the constructors defined above
        # are FloatTensor/LongTensor.)
        valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
        fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)

        # Configure input.  (Original later read an undefined `real_imgs`;
        # bind it here from the mini-batch.)
        real_imgs = Variable(imgs.type(FloatTensor))
        labels = Variable(labels.type(LongTensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))

        # Generate a batch of images together with their inferred labels.
        # (Original called `coupled_generators` / `coupled_discriminators`,
        # leftovers from the CoGAN template; only `generator` and
        # `discriminator` are instantiated in this notebook.)
        gen_imgs, gen_labels = generator(z)

        # Generator tries to make the discriminator score fakes as real.
        validity = discriminator(gen_imgs, gen_labels)
        g_loss = adversarial_loss(validity, valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Loss for real images
        validity_real = discriminator(real_imgs, labels)
        d_real_loss = adversarial_loss(validity_real, valid)

        # Loss for fake images; detach so the discriminator step does not
        # backpropagate into the generator.
        # NOTE(review): gen_labels is not detached — if the generator's
        # label head requires grad this may touch an already-freed graph;
        # confirm against model.py.
        validity_fake = discriminator(gen_imgs.detach(), gen_labels)
        d_fake_loss = adversarial_loss(validity_fake, fake)

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            sample_image(n_row=10, batches_done=batches_done)
219+
],
220+
"execution_count": 0,
221+
"outputs": []
222+
}
223+
]
224+
}

0 commit comments

Comments
 (0)