Skip to content

Commit 4bdc736

Browse files
committed
Add GAN HW
1 parent 612b5dd commit 4bdc736

File tree

1 file changed

+319
-0
lines changed

1 file changed

+319
-0
lines changed

13-GANs/GAN_HW.ipynb

+319
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,319 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"<a href=\"https://colab.research.google.com/github/HSE-LAMBDA/MLDM-2022/blob/main/13-GANs/GAN_homework.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
8+
]
9+
},
10+
{
11+
"cell_type": "markdown",
12+
"metadata": {},
13+
"source": [
14+
"### WGAN\n",
15+
"\n",
16+
"* Modify snippets below and implement [Wasserstein GAN](https://arxiv.org/abs/1701.07875) ([From GAN to WGAN\n",
17+
"](https://lilianweng.github.io/posts/2017-08-20-gan/)) with weight clipping. (2 points)\n",
18+
"\n",
19+
"* Replace weight clipping with [gradient penalty](https://arxiv.org/pdf/1704.00028v3.pdf). (2 points)\n",
20+
"\n",
21+
"* Add labels into WGAN, performing [conditional generation](https://arxiv.org/pdf/1411.1784.pdf). (2 points) \n",
22+
"\n",
23+
"Write a report about experiments and results, add plots and visualizations."
24+
]
25+
},
26+
{
27+
"cell_type": "code",
28+
"execution_count": null,
29+
"metadata": {},
30+
"outputs": [],
31+
"source": [
32+
"import torch\n",
33+
"import torch.nn as nn\n",
34+
"import torch.nn.functional as F\n",
35+
"import torch.optim as optim\n",
36+
"from torch.utils.data import DataLoader, Dataset\n",
37+
"\n",
38+
"import torchvision\n",
39+
"import matplotlib.pyplot as plt\n",
40+
"import numpy as np\n",
41+
"\n",
42+
"from torch.autograd import Variable"
43+
]
44+
},
45+
{
46+
"cell_type": "markdown",
47+
"metadata": {},
48+
"source": [
49+
"### Creating config object (argparse workaround)"
50+
]
51+
},
52+
{
53+
"cell_type": "code",
54+
"execution_count": null,
55+
"metadata": {},
56+
"outputs": [],
57+
"source": [
58+
"class Config:\n",
59+
" pass\n",
60+
"\n",
61+
"config = Config()\n",
62+
"config.mnist_path = None\n",
63+
"config.batch_size = 16\n",
64+
"config.num_workers = 3\n",
65+
"config.num_epochs = 10\n",
66+
"config.noise_size = 50\n",
67+
"config.print_freq = 100\n"
68+
]
69+
},
70+
{
71+
"cell_type": "markdown",
72+
"metadata": {},
73+
"source": [
74+
"### Create dataloder"
75+
]
76+
},
77+
{
78+
"cell_type": "code",
79+
"execution_count": null,
80+
"metadata": {},
81+
"outputs": [],
82+
"source": [
83+
"train = torchvision.datasets.FashionMNIST(\"fashion_mnist\", train=True, transform=torchvision.transforms.ToTensor(), download=True)"
84+
]
85+
},
86+
{
87+
"cell_type": "code",
88+
"execution_count": null,
89+
"metadata": {},
90+
"outputs": [],
91+
"source": [
92+
"dataloader = DataLoader(train, batch_size=16, shuffle=True)"
93+
]
94+
},
95+
{
96+
"cell_type": "code",
97+
"execution_count": null,
98+
"metadata": {},
99+
"outputs": [],
100+
"source": [
101+
"len(dataloader)"
102+
]
103+
},
104+
{
105+
"cell_type": "code",
106+
"execution_count": null,
107+
"metadata": {},
108+
"outputs": [],
109+
"source": [
110+
"for image, cat in dataloader:\n",
111+
" break"
112+
]
113+
},
114+
{
115+
"cell_type": "code",
116+
"execution_count": null,
117+
"metadata": {
118+
"scrolled": true
119+
},
120+
"outputs": [],
121+
"source": [
122+
"image.size()"
123+
]
124+
},
125+
{
126+
"cell_type": "markdown",
127+
"metadata": {},
128+
"source": [
129+
"### Create generator and discriminator"
130+
]
131+
},
132+
{
133+
"cell_type": "code",
134+
"execution_count": null,
135+
"metadata": {},
136+
"outputs": [],
137+
"source": [
138+
"class Generator(nn.Module):\n",
139+
" def __init__(self):\n",
140+
" super(Generator, self).__init__()\n",
141+
" self.model = nn.Sequential( \n",
142+
" nn.Linear(config.noise_size, 200),\n",
143+
" nn.ReLU(inplace=True),\n",
144+
" nn.Linear(200, 28*28),\n",
145+
" nn.Sigmoid())\n",
146+
" \n",
147+
" def forward(self, x):\n",
148+
" return self.model(x)\n",
149+
" \n",
150+
"class Discriminator(nn.Module):\n",
151+
" def __init__(self):\n",
152+
" super(Discriminator, self).__init__()\n",
153+
" self.model = nn.Sequential(\n",
154+
" nn.Linear(28*28, 200),\n",
155+
" nn.ReLU(inplace=True),\n",
156+
" nn.Linear(200, 50),\n",
157+
" nn.ReLU(inplace=True),\n",
158+
" nn.Linear(50, 1), \n",
159+
" nn.Sigmoid())\n",
160+
" def forward(self, x):\n",
161+
" return self.model(x)"
162+
]
163+
},
164+
{
165+
"cell_type": "code",
166+
"execution_count": null,
167+
"metadata": {},
168+
"outputs": [],
169+
"source": [
170+
"generator = Generator()\n",
171+
"discriminator = Discriminator()"
172+
]
173+
},
174+
{
175+
"cell_type": "markdown",
176+
"metadata": {},
177+
"source": [
178+
"### Create optimizers and loss"
179+
]
180+
},
181+
{
182+
"cell_type": "code",
183+
"execution_count": null,
184+
"metadata": {},
185+
"outputs": [],
186+
"source": [
187+
"optim_G = optim.Adam(params=generator.parameters(), lr=0.0001)\n",
188+
"optim_D = optim.Adam(params=discriminator.parameters(), lr=0.0001)\n",
189+
"\n",
190+
"criterion = nn.BCELoss()"
191+
]
192+
},
193+
{
194+
"cell_type": "markdown",
195+
"metadata": {},
196+
"source": [
197+
"### Create necessary variables"
198+
]
199+
},
200+
{
201+
"cell_type": "code",
202+
"execution_count": null,
203+
"metadata": {},
204+
"outputs": [],
205+
"source": [
206+
"input = Variable(torch.FloatTensor(config.batch_size, 28*28))\n",
207+
"noise = Variable(torch.FloatTensor(config.batch_size, config.noise_size))\n",
208+
"fixed_noise = Variable(torch.FloatTensor(config.batch_size, config.noise_size).normal_(0, 1))\n",
209+
"label = Variable(torch.FloatTensor(config.batch_size))\n",
210+
"real_label = 1\n",
211+
"fake_label = 0"
212+
]
213+
},
214+
{
215+
"cell_type": "markdown",
216+
"metadata": {},
217+
"source": [
218+
"### GAN"
219+
]
220+
},
221+
{
222+
"cell_type": "code",
223+
"execution_count": null,
224+
"metadata": {
225+
"scrolled": true
226+
},
227+
"outputs": [],
228+
"source": [
229+
"ERRD_x = np.zeros(config.num_epochs)\n",
230+
"ERRD_z = np.zeros(config.num_epochs)\n",
231+
"ERRG = np.zeros(config.num_epochs)\n",
232+
"N = len(dataloader)\n",
233+
"\n",
234+
"for epoch in range(config.num_epochs):\n",
235+
" for iteration, (images, cat) in enumerate(dataloader):\n",
236+
" ####### \n",
237+
" # Discriminator stage: maximize log(D(x)) + log(1 - D(G(z))) \n",
238+
" #######\n",
239+
" discriminator.zero_grad()\n",
240+
" \n",
241+
" # real\n",
242+
" label.data.fill_(real_label)\n",
243+
" input_data = images.view(images.shape[0], -1)\n",
244+
" output = discriminator(input_data)\n",
245+
" errD_x = criterion(output, label)\n",
246+
" ERRD_x[epoch] += errD_x.item()\n",
247+
" errD_x.backward()\n",
248+
" \n",
249+
" # fake \n",
250+
" noise.data.normal_(0, 1)\n",
251+
" fake = generator(noise)\n",
252+
" label.data.fill_(fake_label)\n",
253+
" output = discriminator(fake.detach())\n",
254+
" errD_z = criterion(output, label)\n",
255+
" ERRD_z[epoch] += errD_z.item()\n",
256+
" errD_z.backward()\n",
257+
" \n",
258+
" optim_D.step()\n",
259+
" \n",
260+
" ####### \n",
261+
" # Generator stage: maximize log(D(G(x))\n",
262+
" #######\n",
263+
" generator.zero_grad()\n",
264+
" label.data.fill_(real_label)\n",
265+
" output = discriminator(fake)\n",
266+
" errG = criterion(output, label)\n",
267+
" ERRG[epoch] += errG.item()\n",
268+
" errG.backward()\n",
269+
" \n",
270+
" optim_G.step()\n",
271+
" \n",
272+
" if (iteration+1) % config.print_freq == 0:\n",
273+
" print('Epoch:{} Iter: {} errD_x: {:.2f} errD_z: {:.2f} errG: {:.2f}'.format(epoch+1,\n",
274+
" iteration+1, \n",
275+
" errD_x.item(),\n",
276+
" errD_z.item(), \n",
277+
" errG.item()))"
278+
]
279+
},
280+
{
281+
"cell_type": "code",
282+
"execution_count": null,
283+
"metadata": {},
284+
"outputs": [],
285+
"source": [
286+
"noise.data.normal_(0, 1)\n",
287+
"fake = generator(noise)\n",
288+
"\n",
289+
"plt.figure(figsize=(6, 7))\n",
290+
"for i in range(16):\n",
291+
" plt.subplot(4, 4, i + 1)\n",
292+
" plt.imshow(fake[i].detach().numpy().reshape(28, 28), cmap=plt.cm.Greys_r)\n",
293+
" plt.axis('off')"
294+
]
295+
}
296+
],
297+
"metadata": {
298+
"anaconda-cloud": {},
299+
"kernelspec": {
300+
"display_name": "Python 3 (ipykernel)",
301+
"language": "python",
302+
"name": "python3"
303+
},
304+
"language_info": {
305+
"codemirror_mode": {
306+
"name": "ipython",
307+
"version": 3
308+
},
309+
"file_extension": ".py",
310+
"mimetype": "text/x-python",
311+
"name": "python",
312+
"nbconvert_exporter": "python",
313+
"pygments_lexer": "ipython3",
314+
"version": "3.8.11"
315+
}
316+
},
317+
"nbformat": 4,
318+
"nbformat_minor": 1
319+
}

0 commit comments

Comments
 (0)