From 7644bdacf49ccb36bca219b5f7e7860cbf9d2428 Mon Sep 17 00:00:00 2001
From: shlear <116897538+shlear@users.noreply.github.com>
Date: Mon, 30 Jan 2023 09:57:08 +0300
Subject: [PATCH] Created with Colaboratory
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
13-GANs/GAN_HW.ipynb | 3815 ++++++++++++++++++++++++++++++++++++++----
1 file changed, 3499 insertions(+), 316 deletions(-)
diff --git a/13-GANs/GAN_HW.ipynb b/13-GANs/GAN_HW.ipynb
index c4bef55..a6d0c20 100644
--- a/13-GANs/GAN_HW.ipynb
+++ b/13-GANs/GAN_HW.ipynb
@@ -1,318 +1,3501 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- ""
- ]
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "p0Kyda8zUdro"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Eaip0Gm0Udrq"
+ },
+ "source": [
+        "### WGAN\n",
+        "\n",
+        "* Modify the snippets below and implement [Wasserstein GAN](https://arxiv.org/abs/1701.07875) with weight clipping. (2 points)\n",
+        "\n",
+        "* Replace weight clipping with the [gradient penalty](https://arxiv.org/pdf/1704.00028v3.pdf). (2 points)\n",
+        "\n",
+        "* Add labels into the WGAN, performing [conditional generation](https://arxiv.org/pdf/1411.1784.pdf). (2 points)\n",
+        "\n",
+        "Write a report on the experiments and results; add plots and visualizations. Sketches for the gradient-penalty and conditional variants are given after the training loop below."
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "id": "aBack-CyUdrq"
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "import torch.optim as optim\n",
+ "from torch.utils.data import DataLoader, Dataset\n",
+ "\n",
+ "import torchvision\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "\n",
+ "from torch.autograd import Variable"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "SD7I7vZ8Udrr"
+ },
+ "source": [
+ "### Creating config object (argparse workaround)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "id": "e0Yq4aMxUdrr"
+ },
+ "outputs": [],
+ "source": [
+ "class Config:\n",
+ " pass\n",
+ "\n",
+ "config = Config()\n",
+ "config.mnist_path = None\n",
+ "config.batch_size = 16\n",
+ "config.num_workers = 3\n",
+ "config.num_epochs = 10\n",
+ "config.noise_size = 50\n",
+ "config.print_freq = 100\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "2CGmqp35Udrr"
+ },
+ "source": [
+        "### Create dataloader"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "id": "jtK105iuUdrr",
+ "outputId": "9e43719a-08a1-4bfc-e426-f89fd4019cf4",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+        "height": 451
+ }
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz\n",
+ "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to fashion_mnist/FashionMNIST/raw/train-images-idx3-ubyte.gz\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Extracting fashion_mnist/FashionMNIST/raw/train-images-idx3-ubyte.gz to fashion_mnist/FashionMNIST/raw\n",
+ "\n",
+ "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz\n",
+ "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to fashion_mnist/FashionMNIST/raw/train-labels-idx1-ubyte.gz\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Extracting fashion_mnist/FashionMNIST/raw/train-labels-idx1-ubyte.gz to fashion_mnist/FashionMNIST/raw\n",
+ "\n",
+ "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz\n",
+ "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to fashion_mnist/FashionMNIST/raw/t10k-images-idx3-ubyte.gz\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Extracting fashion_mnist/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to fashion_mnist/FashionMNIST/raw\n",
+ "\n",
+ "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz\n",
+ "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to fashion_mnist/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Extracting fashion_mnist/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to fashion_mnist/FashionMNIST/raw\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "train = torchvision.datasets.FashionMNIST(\"fashion_mnist\", train=True, transform=torchvision.transforms.ToTensor(), download=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "id": "ghEDU08xUdrs"
+ },
+ "outputs": [],
+ "source": [
+        "# reuse the values from the config object above\n",
+        "dataloader = DataLoader(train, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "id": "AlyUytopUdrs",
+ "outputId": "b999ea90-080e-4d4b-c872-61a4775fcbe9",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ }
+ },
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "3750"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 5
+ }
+ ],
+ "source": [
+ "len(dataloader)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "id": "ai4GQlhUUdrs"
+ },
+ "outputs": [],
+ "source": [
+        "# grab a single batch to inspect shapes\n",
+        "for image, cat in dataloader:\n",
+ " break"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "scrolled": true,
+ "id": "AHjavRq0Udrs",
+ "outputId": "10060bac-561d-4557-cbe8-430a56226331",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ }
+ },
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "torch.Size([16, 1, 28, 28])"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 7
+ }
+ ],
+ "source": [
+ "image.size()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ltR36cKpUdrs"
+ },
+ "source": [
+ "### Create generator and discriminator"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "id": "FTxWWtRZUdrs"
+ },
+ "outputs": [],
+ "source": [
+        "class Generator(nn.Module):\n",
+        "    def __init__(self):\n",
+        "        super(Generator, self).__init__()\n",
+        "        # maps a noise vector to a flat 28*28 image in [0, 1]\n",
+        "        self.model = nn.Sequential(\n",
+        "            nn.Linear(config.noise_size, 200),\n",
+        "            nn.ReLU(inplace=True),\n",
+        "            nn.Linear(200, 28*28),\n",
+        "            nn.Sigmoid())\n",
+        "\n",
+        "    def forward(self, x):\n",
+        "        return self.model(x)\n",
+        "\n",
+        "class Discriminator(nn.Module):\n",
+        "    def __init__(self):\n",
+        "        super(Discriminator, self).__init__()\n",
+        "        # modification1: the WGAN critic has no sigmoid in the last layer\n",
+        "        # (classification -> regression; it outputs an unbounded score)\n",
+        "        self.model = nn.Sequential(\n",
+        "            nn.Linear(28*28, 200),\n",
+        "            nn.ReLU(inplace=True),\n",
+        "            nn.Linear(200, 50),\n",
+        "            nn.ReLU(inplace=True),\n",
+        "            nn.Linear(50, 1))\n",
+        "\n",
+        "    def forward(self, x):\n",
+        "        return self.model(x)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "id": "sdSFPJRGUdrt"
+ },
+ "outputs": [],
+ "source": [
+ "generator = Generator()\n",
+ "discriminator = Discriminator()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "p6mS17QjUdrt"
+ },
+ "source": [
+ "### Create optimizers and loss"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "id": "1i4GrgO1Udrt"
+ },
+ "outputs": [],
+ "source": [
+        "# modification4: no momentum-based optimizer; use RMSprop (or SGD) instead of Adam\n",
+        "# optim_G = optim.Adam(params=generator.parameters(), lr=0.0001)\n",
+        "# optim_D = optim.Adam(params=discriminator.parameters(), lr=0.0001)\n",
+        "optim_G = optim.RMSprop(params=generator.parameters(), lr=0.00005)\n",
+        "optim_D = optim.RMSprop(params=discriminator.parameters(), lr=0.00005)\n",
+        "\n",
+        "# modification2: no log in the loss, so no BCE criterion\n",
+        "# criterion = nn.BCELoss()"
+ ]
+ },
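+      {
+       "cell_type": "markdown",
+       "metadata": {},
+       "source": [
+        "For reference, the objective behind these choices, following the linked WGAN paper: the critic $f_w$ is trained to maximize $\\mathbb{E}_{x \\sim p_{\\mathrm{data}}}[f_w(x)] - \\mathbb{E}_{z \\sim p(z)}[f_w(g_\\theta(z))]$ while its weights are kept in $[-c, c]$ by clipping ($c = 0.01$ below), and the generator minimizes $-\\mathbb{E}_{z \\sim p(z)}[f_w(g_\\theta(z))]$. There is no $\\log$ in either term, which is why `nn.BCELoss` is commented out above."
+       ]
+      },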
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ohwrLM7tUdrt"
+ },
+ "source": [
+ "### Create necessary variables"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "id": "kYfOh6i4Udrt",
+ "outputId": "5a9c7c42-4b40-454f-f16f-c2661d25e19f",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ }
+ },
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "Discriminator(\n",
+ " (model): Sequential(\n",
+ " (0): Linear(in_features=784, out_features=200, bias=True)\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Linear(in_features=200, out_features=50, bias=True)\n",
+ " (3): ReLU(inplace=True)\n",
+ " (4): Linear(in_features=50, out_features=1, bias=True)\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 11
+ }
+ ],
+ "source": [
+        "from torch.utils.tensorboard import SummaryWriter\n",
+        "\n",
+        "input = Variable(torch.FloatTensor(config.batch_size, 28*28))  # unused below; kept from the template\n",
+        "noise = Variable(torch.FloatTensor(config.batch_size, config.noise_size))\n",
+        "fixed_noise = Variable(torch.FloatTensor(config.batch_size, config.noise_size).normal_(0, 1))\n",
+        "# the labels below are only needed by the original (non-Wasserstein) GAN criterion\n",
+        "label = Variable(torch.FloatTensor(config.batch_size))\n",
+        "real_label = 1\n",
+        "fake_label = 0\n",
+ "\n",
+ "# for tensorboard plotting\n",
+        "writer_real = SummaryWriter(\"logs/real\")\n",
+        "writer_fake = SummaryWriter(\"logs/fake\")\n",
+ "step = 0\n",
+ "\n",
+ "generator.train()\n",
+ "discriminator.train()"
+ ]
+ },
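+      {
+       "cell_type": "markdown",
+       "metadata": {},
+       "source": [
+        "Note: `Variable` has been merged into `Tensor` since PyTorch 0.4, so the wrappers above are no-ops. A Variable-free equivalent of the same buffers (same shapes and initialization; the `_t` names are only to avoid clobbering the ones above) would be:"
+       ]
+      },
+      {
+       "cell_type": "code",
+       "execution_count": null,
+       "metadata": {},
+       "outputs": [],
+       "source": [
+        "# modern, Variable-free equivalents of the buffers above\n",
+        "noise_t = torch.empty(config.batch_size, config.noise_size)\n",
+        "fixed_noise_t = torch.randn(config.batch_size, config.noise_size)\n",
+        "label_t = torch.empty(config.batch_size)  # only used by the non-Wasserstein criterion"
+       ]
+      },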
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "WiHkx6rOUdrt"
+ },
+ "source": [
+        "### WGAN training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "scrolled": true,
+ "id": "zrHq0fJiUdrt",
+ "outputId": "2f6894e2-cc4e-4fde-e20e-d9744964d099",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ }
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Epoch [1/10] Batch 100/3750 Loss D: -0.0229, loss G: -0.4882\n",
+ "Epoch [1/10] Batch 200/3750 Loss D: -0.0356, loss G: -0.4740\n",
+ "Epoch [1/10] Batch 300/3750 Loss D: -0.0353, loss G: -0.4621\n",
+ "Epoch [1/10] Batch 400/3750 Loss D: -0.0285, loss G: -0.4560\n",
+ "Epoch [1/10] Batch 500/3750 Loss D: -0.0326, loss G: -0.4509\n",
+ "Epoch [1/10] Batch 600/3750 Loss D: -0.0302, loss G: -0.4530\n",
+ "Epoch [1/10] Batch 700/3750 Loss D: -0.0251, loss G: -0.4669\n",
+ "Epoch [1/10] Batch 800/3750 Loss D: -0.0194, loss G: -0.4841\n",
+ "Epoch [1/10] Batch 900/3750 Loss D: -0.0180, loss G: -0.4940\n",
+ "Epoch [1/10] Batch 1000/3750 Loss D: -0.0293, loss G: -0.4924\n",
+ "Epoch [1/10] Batch 1100/3750 Loss D: -0.0286, loss G: -0.4947\n",
+ "Epoch [1/10] Batch 1200/3750 Loss D: -0.0306, loss G: -0.4999\n",
+ "Epoch [1/10] Batch 1300/3750 Loss D: -0.0346, loss G: -0.5027\n",
+ "Epoch [1/10] Batch 1400/3750 Loss D: -0.0189, loss G: -0.5093\n",
+ "Epoch [1/10] Batch 1500/3750 Loss D: -0.0214, loss G: -0.4972\n",
+ "Epoch [1/10] Batch 1600/3750 Loss D: -0.0171, loss G: -0.5068\n",
+ "Epoch [1/10] Batch 1700/3750 Loss D: -0.0230, loss G: -0.5046\n",
+ "Epoch [1/10] Batch 1800/3750 Loss D: -0.0096, loss G: -0.5098\n",
+ "Epoch [1/10] Batch 1900/3750 Loss D: -0.0156, loss G: -0.5040\n",
+ "Epoch [1/10] Batch 2000/3750 Loss D: -0.0223, loss G: -0.4945\n",
+ "Epoch [1/10] Batch 2100/3750 Loss D: -0.0109, loss G: -0.5107\n",
+ "Epoch [1/10] Batch 2200/3750 Loss D: -0.0045, loss G: -0.5080\n",
+ "Epoch [1/10] Batch 2300/3750 Loss D: -0.0066, loss G: -0.5000\n",
+ "Epoch [1/10] Batch 2400/3750 Loss D: -0.0109, loss G: -0.5068\n",
+ "Epoch [1/10] Batch 2500/3750 Loss D: -0.0110, loss G: -0.4950\n",
+ "Epoch [1/10] Batch 2600/3750 Loss D: -0.0054, loss G: -0.5053\n",
+ "Epoch [1/10] Batch 2700/3750 Loss D: -0.0052, loss G: -0.5074\n",
+ "Epoch [1/10] Batch 2800/3750 Loss D: -0.0072, loss G: -0.4935\n",
+ "Epoch [1/10] Batch 2900/3750 Loss D: -0.0040, loss G: -0.4998\n",
+ "Epoch [1/10] Batch 3000/3750 Loss D: -0.0053, loss G: -0.4981\n",
+ "Epoch [1/10] Batch 3100/3750 Loss D: -0.0104, loss G: -0.4924\n",
+ "Epoch [1/10] Batch 3200/3750 Loss D: -0.0045, loss G: -0.4999\n",
+ "Epoch [1/10] Batch 3300/3750 Loss D: -0.0080, loss G: -0.5013\n",
+ "Epoch [1/10] Batch 3400/3750 Loss D: -0.0088, loss G: -0.4971\n",
+ "Epoch [1/10] Batch 3500/3750 Loss D: -0.0090, loss G: -0.4973\n",
+ "Epoch [1/10] Batch 3600/3750 Loss D: -0.0073, loss G: -0.4991\n",
+ "Epoch [1/10] Batch 3700/3750 Loss D: -0.0097, loss G: -0.4935\n",
+ "Epoch [2/10] Batch 100/3750 Loss D: -0.0104, loss G: -0.5051\n",
+ "Epoch [2/10] Batch 200/3750 Loss D: -0.0113, loss G: -0.4903\n",
+ "Epoch [2/10] Batch 300/3750 Loss D: -0.0048, loss G: -0.5027\n",
+ "Epoch [2/10] Batch 400/3750 Loss D: -0.0075, loss G: -0.4984\n",
+ "Epoch [2/10] Batch 500/3750 Loss D: -0.0089, loss G: -0.4876\n",
+ "Epoch [2/10] Batch 600/3750 Loss D: -0.0092, loss G: -0.5133\n",
+ "Epoch [2/10] Batch 700/3750 Loss D: -0.0111, loss G: -0.5033\n",
+ "Epoch [2/10] Batch 800/3750 Loss D: -0.0101, loss G: -0.4886\n",
+ "Epoch [2/10] Batch 900/3750 Loss D: -0.0080, loss G: -0.5036\n",
+ "Epoch [2/10] Batch 1000/3750 Loss D: -0.0106, loss G: -0.5092\n",
+ "Epoch [2/10] Batch 1100/3750 Loss D: -0.0092, loss G: -0.5003\n",
+ "Epoch [2/10] Batch 1200/3750 Loss D: -0.0037, loss G: -0.5029\n",
+ "Epoch [2/10] Batch 1300/3750 Loss D: -0.0092, loss G: -0.5003\n",
+ "Epoch [2/10] Batch 1400/3750 Loss D: -0.0096, loss G: -0.4990\n",
+ "Epoch [2/10] Batch 1500/3750 Loss D: -0.0071, loss G: -0.4943\n",
+ "Epoch [2/10] Batch 1600/3750 Loss D: -0.0065, loss G: -0.5044\n",
+ "Epoch [2/10] Batch 1700/3750 Loss D: -0.0022, loss G: -0.4976\n",
+ "Epoch [2/10] Batch 1800/3750 Loss D: -0.0055, loss G: -0.4997\n",
+ "Epoch [2/10] Batch 1900/3750 Loss D: -0.0024, loss G: -0.5051\n",
+ "Epoch [2/10] Batch 2000/3750 Loss D: -0.0041, loss G: -0.4902\n",
+ "Epoch [2/10] Batch 2100/3750 Loss D: 0.0026, loss G: -0.5093\n",
+ "Epoch [2/10] Batch 2200/3750 Loss D: -0.0048, loss G: -0.4983\n",
+ "Epoch [2/10] Batch 2300/3750 Loss D: -0.0085, loss G: -0.5017\n",
+ "Epoch [2/10] Batch 2400/3750 Loss D: -0.0039, loss G: -0.5015\n",
+ "Epoch [2/10] Batch 2500/3750 Loss D: -0.0025, loss G: -0.5044\n",
+ "Epoch [2/10] Batch 2600/3750 Loss D: 0.0008, loss G: -0.4867\n",
+ "Epoch [2/10] Batch 2700/3750 Loss D: -0.0019, loss G: -0.4906\n",
+ "Epoch [2/10] Batch 2800/3750 Loss D: -0.0019, loss G: -0.5185\n",
+ "Epoch [2/10] Batch 2900/3750 Loss D: -0.0088, loss G: -0.4897\n",
+ "Epoch [2/10] Batch 3000/3750 Loss D: -0.0064, loss G: -0.4981\n",
+ "Epoch [2/10] Batch 3100/3750 Loss D: -0.0044, loss G: -0.5055\n",
+ "Epoch [2/10] Batch 3200/3750 Loss D: -0.0015, loss G: -0.4973\n",
+ "Epoch [2/10] Batch 3300/3750 Loss D: -0.0077, loss G: -0.5113\n",
+ "Epoch [2/10] Batch 3400/3750 Loss D: -0.0034, loss G: -0.4894\n",
+ "Epoch [2/10] Batch 3500/3750 Loss D: -0.0031, loss G: -0.4915\n",
+ "Epoch [2/10] Batch 3600/3750 Loss D: -0.0043, loss G: -0.5019\n",
+ "Epoch [2/10] Batch 3700/3750 Loss D: -0.0027, loss G: -0.5062\n",
+ "Epoch [3/10] Batch 100/3750 Loss D: -0.0027, loss G: -0.4981\n",
+ "Epoch [3/10] Batch 200/3750 Loss D: -0.0029, loss G: -0.4924\n",
+ "Epoch [3/10] Batch 300/3750 Loss D: -0.0026, loss G: -0.5082\n",
+ "Epoch [3/10] Batch 400/3750 Loss D: -0.0037, loss G: -0.4852\n",
+ "Epoch [3/10] Batch 500/3750 Loss D: 0.0004, loss G: -0.5082\n",
+ "Epoch [3/10] Batch 600/3750 Loss D: -0.0006, loss G: -0.4986\n",
+ "Epoch [3/10] Batch 700/3750 Loss D: -0.0051, loss G: -0.4893\n",
+ "Epoch [3/10] Batch 800/3750 Loss D: -0.0021, loss G: -0.5034\n",
+ "Epoch [3/10] Batch 900/3750 Loss D: -0.0001, loss G: -0.5017\n",
+ "Epoch [3/10] Batch 1000/3750 Loss D: -0.0016, loss G: -0.4961\n",
+ "Epoch [3/10] Batch 1100/3750 Loss D: -0.0022, loss G: -0.5031\n",
+ "Epoch [3/10] Batch 1200/3750 Loss D: -0.0043, loss G: -0.5001\n",
+ "Epoch [3/10] Batch 1300/3750 Loss D: -0.0009, loss G: -0.5042\n",
+ "Epoch [3/10] Batch 1400/3750 Loss D: -0.0039, loss G: -0.4915\n",
+ "Epoch [3/10] Batch 1500/3750 Loss D: -0.0063, loss G: -0.5021\n",
+ "Epoch [3/10] Batch 1600/3750 Loss D: -0.0030, loss G: -0.4951\n",
+ "Epoch [3/10] Batch 1700/3750 Loss D: -0.0026, loss G: -0.5071\n",
+ "Epoch [3/10] Batch 1800/3750 Loss D: -0.0024, loss G: -0.4934\n",
+ "Epoch [3/10] Batch 1900/3750 Loss D: -0.0051, loss G: -0.4925\n",
+ "Epoch [3/10] Batch 2000/3750 Loss D: -0.0090, loss G: -0.5091\n",
+ "Epoch [3/10] Batch 2100/3750 Loss D: -0.0024, loss G: -0.4979\n",
+ "Epoch [3/10] Batch 2200/3750 Loss D: -0.0030, loss G: -0.4936\n",
+ "Epoch [3/10] Batch 2300/3750 Loss D: -0.0053, loss G: -0.5066\n",
+ "Epoch [3/10] Batch 2400/3750 Loss D: -0.0047, loss G: -0.4986\n",
+ "Epoch [3/10] Batch 2500/3750 Loss D: -0.0043, loss G: -0.4995\n",
+ "Epoch [3/10] Batch 2600/3750 Loss D: -0.0049, loss G: -0.4927\n",
+ "Epoch [3/10] Batch 2700/3750 Loss D: -0.0065, loss G: -0.4936\n",
+ "Epoch [3/10] Batch 2800/3750 Loss D: -0.0072, loss G: -0.5075\n",
+ "Epoch [3/10] Batch 2900/3750 Loss D: -0.0047, loss G: -0.5044\n",
+ "Epoch [3/10] Batch 3000/3750 Loss D: -0.0041, loss G: -0.4885\n",
+ "Epoch [3/10] Batch 3100/3750 Loss D: -0.0049, loss G: -0.5111\n",
+ "Epoch [3/10] Batch 3200/3750 Loss D: -0.0010, loss G: -0.4951\n",
+ "Epoch [3/10] Batch 3300/3750 Loss D: -0.0062, loss G: -0.5155\n",
+ "Epoch [3/10] Batch 3400/3750 Loss D: -0.0024, loss G: -0.4925\n",
+ "Epoch [3/10] Batch 3500/3750 Loss D: -0.0059, loss G: -0.5156\n",
+ "Epoch [3/10] Batch 3600/3750 Loss D: -0.0030, loss G: -0.4933\n",
+ "Epoch [3/10] Batch 3700/3750 Loss D: -0.0036, loss G: -0.4990\n",
+ "Epoch [4/10] Batch 100/3750 Loss D: -0.0026, loss G: -0.4958\n",
+ "Epoch [4/10] Batch 200/3750 Loss D: -0.0009, loss G: -0.5010\n",
+ "Epoch [4/10] Batch 300/3750 Loss D: -0.0024, loss G: -0.5063\n",
+ "Epoch [4/10] Batch 400/3750 Loss D: -0.0052, loss G: -0.5031\n",
+ "Epoch [4/10] Batch 500/3750 Loss D: -0.0029, loss G: -0.5017\n",
+ "Epoch [4/10] Batch 600/3750 Loss D: -0.0027, loss G: -0.4863\n",
+ "Epoch [4/10] Batch 700/3750 Loss D: -0.0037, loss G: -0.4967\n",
+ "Epoch [4/10] Batch 800/3750 Loss D: -0.0010, loss G: -0.5009\n",
+ "Epoch [4/10] Batch 900/3750 Loss D: -0.0046, loss G: -0.5101\n",
+ "Epoch [4/10] Batch 1000/3750 Loss D: -0.0026, loss G: -0.4984\n",
+ "Epoch [4/10] Batch 1100/3750 Loss D: 0.0007, loss G: -0.4998\n",
+ "Epoch [4/10] Batch 1200/3750 Loss D: -0.0029, loss G: -0.5035\n",
+ "Epoch [4/10] Batch 1300/3750 Loss D: -0.0044, loss G: -0.4957\n",
+ "Epoch [4/10] Batch 1400/3750 Loss D: -0.0033, loss G: -0.5013\n",
+ "Epoch [4/10] Batch 1500/3750 Loss D: -0.0019, loss G: -0.4905\n",
+ "Epoch [4/10] Batch 1600/3750 Loss D: -0.0037, loss G: -0.5051\n",
+ "Epoch [4/10] Batch 1700/3750 Loss D: -0.0051, loss G: -0.4912\n",
+ "Epoch [4/10] Batch 1800/3750 Loss D: -0.0019, loss G: -0.4996\n",
+ "Epoch [4/10] Batch 1900/3750 Loss D: -0.0047, loss G: -0.5016\n",
+ "Epoch [4/10] Batch 2000/3750 Loss D: -0.0039, loss G: -0.4955\n",
+ "Epoch [4/10] Batch 2100/3750 Loss D: -0.0056, loss G: -0.5072\n",
+ "Epoch [4/10] Batch 2200/3750 Loss D: -0.0048, loss G: -0.4935\n",
+ "Epoch [4/10] Batch 2300/3750 Loss D: -0.0035, loss G: -0.5056\n",
+ "Epoch [4/10] Batch 2400/3750 Loss D: -0.0022, loss G: -0.5014\n",
+ "Epoch [4/10] Batch 2500/3750 Loss D: -0.0005, loss G: -0.4924\n",
+ "Epoch [4/10] Batch 2600/3750 Loss D: -0.0054, loss G: -0.5005\n",
+ "Epoch [4/10] Batch 2700/3750 Loss D: -0.0021, loss G: -0.4993\n",
+ "Epoch [4/10] Batch 2800/3750 Loss D: -0.0054, loss G: -0.4990\n",
+ "Epoch [4/10] Batch 2900/3750 Loss D: -0.0009, loss G: -0.5074\n",
+ "Epoch [4/10] Batch 3000/3750 Loss D: -0.0045, loss G: -0.4914\n",
+ "Epoch [4/10] Batch 3100/3750 Loss D: -0.0059, loss G: -0.4960\n",
+ "Epoch [4/10] Batch 3200/3750 Loss D: -0.0033, loss G: -0.5098\n",
+ "Epoch [4/10] Batch 3300/3750 Loss D: -0.0039, loss G: -0.4936\n",
+ "Epoch [4/10] Batch 3400/3750 Loss D: -0.0023, loss G: -0.4972\n",
+ "Epoch [4/10] Batch 3500/3750 Loss D: -0.0019, loss G: -0.5032\n",
+ "Epoch [4/10] Batch 3600/3750 Loss D: -0.0049, loss G: -0.4915\n",
+ "Epoch [4/10] Batch 3700/3750 Loss D: -0.0022, loss G: -0.5065\n",
+ "Epoch [5/10] Batch 100/3750 Loss D: -0.0025, loss G: -0.4994\n",
+ "Epoch [5/10] Batch 200/3750 Loss D: -0.0034, loss G: -0.5005\n",
+ "Epoch [5/10] Batch 300/3750 Loss D: -0.0050, loss G: -0.5015\n",
+ "Epoch [5/10] Batch 400/3750 Loss D: -0.0036, loss G: -0.4977\n",
+ "Epoch [5/10] Batch 500/3750 Loss D: -0.0033, loss G: -0.4982\n",
+ "Epoch [5/10] Batch 600/3750 Loss D: -0.0044, loss G: -0.4963\n",
+ "Epoch [5/10] Batch 700/3750 Loss D: -0.0036, loss G: -0.5141\n",
+ "Epoch [5/10] Batch 800/3750 Loss D: -0.0025, loss G: -0.4914\n",
+ "Epoch [5/10] Batch 900/3750 Loss D: -0.0025, loss G: -0.4961\n",
+ "Epoch [5/10] Batch 1000/3750 Loss D: -0.0040, loss G: -0.4935\n",
+ "Epoch [5/10] Batch 1100/3750 Loss D: -0.0028, loss G: -0.4985\n",
+ "Epoch [5/10] Batch 1200/3750 Loss D: -0.0052, loss G: -0.4888\n",
+ "Epoch [5/10] Batch 1300/3750 Loss D: -0.0043, loss G: -0.4984\n",
+ "Epoch [5/10] Batch 1400/3750 Loss D: -0.0039, loss G: -0.4922\n",
+ "Epoch [5/10] Batch 1500/3750 Loss D: -0.0033, loss G: -0.4976\n",
+ "Epoch [5/10] Batch 1600/3750 Loss D: -0.0014, loss G: -0.4988\n",
+ "Epoch [5/10] Batch 1700/3750 Loss D: -0.0069, loss G: -0.5106\n",
+ "Epoch [5/10] Batch 1800/3750 Loss D: 0.0003, loss G: -0.4913\n",
+ "Epoch [5/10] Batch 1900/3750 Loss D: -0.0041, loss G: -0.4869\n",
+ "Epoch [5/10] Batch 2000/3750 Loss D: -0.0028, loss G: -0.5061\n",
+ "Epoch [5/10] Batch 2100/3750 Loss D: -0.0032, loss G: -0.4949\n",
+ "Epoch [5/10] Batch 2200/3750 Loss D: -0.0006, loss G: -0.4980\n",
+ "Epoch [5/10] Batch 2300/3750 Loss D: -0.0043, loss G: -0.5005\n",
+ "Epoch [5/10] Batch 2400/3750 Loss D: -0.0028, loss G: -0.5026\n",
+ "Epoch [5/10] Batch 2500/3750 Loss D: -0.0015, loss G: -0.4894\n",
+ "Epoch [5/10] Batch 2600/3750 Loss D: -0.0016, loss G: -0.5002\n",
+ "Epoch [5/10] Batch 2700/3750 Loss D: -0.0027, loss G: -0.5040\n",
+ "Epoch [5/10] Batch 2800/3750 Loss D: -0.0024, loss G: -0.4950\n",
+ "Epoch [5/10] Batch 2900/3750 Loss D: -0.0043, loss G: -0.5065\n",
+ "Epoch [5/10] Batch 3000/3750 Loss D: -0.0023, loss G: -0.4973\n",
+ "Epoch [5/10] Batch 3100/3750 Loss D: -0.0054, loss G: -0.5013\n",
+ "Epoch [5/10] Batch 3200/3750 Loss D: -0.0055, loss G: -0.4915\n",
+ "Epoch [5/10] Batch 3300/3750 Loss D: -0.0005, loss G: -0.5003\n",
+ "Epoch [5/10] Batch 3400/3750 Loss D: 0.0021, loss G: -0.5123\n",
+ "Epoch [5/10] Batch 3500/3750 Loss D: -0.0025, loss G: -0.5061\n",
+ "Epoch [5/10] Batch 3600/3750 Loss D: -0.0056, loss G: -0.4878\n",
+ "Epoch [5/10] Batch 3700/3750 Loss D: -0.0010, loss G: -0.4904\n",
+ "Epoch [6/10] Batch 100/3750 Loss D: -0.0011, loss G: -0.5075\n",
+ "Epoch [6/10] Batch 200/3750 Loss D: -0.0012, loss G: -0.4982\n",
+ "Epoch [6/10] Batch 300/3750 Loss D: -0.0027, loss G: -0.4844\n",
+ "Epoch [6/10] Batch 400/3750 Loss D: -0.0009, loss G: -0.5052\n",
+ "Epoch [6/10] Batch 500/3750 Loss D: -0.0051, loss G: -0.5053\n",
+ "Epoch [6/10] Batch 600/3750 Loss D: -0.0033, loss G: -0.4884\n",
+ "Epoch [6/10] Batch 700/3750 Loss D: -0.0037, loss G: -0.5062\n",
+ "Epoch [6/10] Batch 800/3750 Loss D: 0.0000, loss G: -0.5074\n",
+ "Epoch [6/10] Batch 900/3750 Loss D: -0.0024, loss G: -0.4913\n",
+ "Epoch [6/10] Batch 1000/3750 Loss D: 0.0030, loss G: -0.4941\n",
+ "Epoch [6/10] Batch 1100/3750 Loss D: -0.0030, loss G: -0.5037\n",
+ "Epoch [6/10] Batch 1200/3750 Loss D: -0.0021, loss G: -0.4996\n",
+ "Epoch [6/10] Batch 1300/3750 Loss D: -0.0017, loss G: -0.4996\n",
+ "Epoch [6/10] Batch 1400/3750 Loss D: -0.0006, loss G: -0.4933\n",
+ "Epoch [6/10] Batch 1500/3750 Loss D: 0.0062, loss G: -0.5148\n",
+ "Epoch [6/10] Batch 1600/3750 Loss D: -0.0035, loss G: -0.4997\n",
+ "Epoch [6/10] Batch 1700/3750 Loss D: -0.0019, loss G: -0.4888\n",
+ "Epoch [6/10] Batch 1800/3750 Loss D: -0.0025, loss G: -0.5085\n",
+ "Epoch [6/10] Batch 1900/3750 Loss D: -0.0057, loss G: -0.5050\n",
+ "Epoch [6/10] Batch 2000/3750 Loss D: -0.0018, loss G: -0.4961\n",
+ "Epoch [6/10] Batch 2100/3750 Loss D: -0.0034, loss G: -0.4926\n",
+ "Epoch [6/10] Batch 2200/3750 Loss D: -0.0040, loss G: -0.4948\n",
+ "Epoch [6/10] Batch 2300/3750 Loss D: 0.0001, loss G: -0.5038\n",
+ "Epoch [6/10] Batch 2400/3750 Loss D: 0.0039, loss G: -0.5102\n",
+ "Epoch [6/10] Batch 2500/3750 Loss D: 0.0011, loss G: -0.4952\n",
+ "Epoch [6/10] Batch 2600/3750 Loss D: -0.0023, loss G: -0.4910\n",
+ "Epoch [6/10] Batch 2700/3750 Loss D: -0.0027, loss G: -0.5096\n",
+ "Epoch [6/10] Batch 2800/3750 Loss D: -0.0070, loss G: -0.5112\n",
+ "Epoch [6/10] Batch 2900/3750 Loss D: -0.0019, loss G: -0.4948\n",
+ "Epoch [6/10] Batch 3000/3750 Loss D: -0.0026, loss G: -0.4915\n",
+ "Epoch [6/10] Batch 3100/3750 Loss D: -0.0029, loss G: -0.4974\n",
+ "Epoch [6/10] Batch 3200/3750 Loss D: -0.0053, loss G: -0.5091\n",
+ "Epoch [6/10] Batch 3300/3750 Loss D: -0.0027, loss G: -0.4999\n",
+ "Epoch [6/10] Batch 3400/3750 Loss D: -0.0028, loss G: -0.4978\n",
+ "Epoch [6/10] Batch 3500/3750 Loss D: -0.0019, loss G: -0.5035\n",
+ "Epoch [6/10] Batch 3600/3750 Loss D: -0.0014, loss G: -0.4923\n",
+ "Epoch [6/10] Batch 3700/3750 Loss D: -0.0024, loss G: -0.4987\n",
+ "Epoch [7/10] Batch 100/3750 Loss D: -0.0047, loss G: -0.5067\n",
+ "Epoch [7/10] Batch 200/3750 Loss D: -0.0064, loss G: -0.4877\n",
+ "Epoch [7/10] Batch 300/3750 Loss D: -0.0028, loss G: -0.5004\n",
+ "Epoch [7/10] Batch 400/3750 Loss D: -0.0040, loss G: -0.4992\n",
+ "Epoch [7/10] Batch 500/3750 Loss D: -0.0020, loss G: -0.5005\n",
+ "Epoch [7/10] Batch 600/3750 Loss D: -0.0026, loss G: -0.5004\n",
+ "Epoch [7/10] Batch 700/3750 Loss D: -0.0010, loss G: -0.5040\n",
+ "Epoch [7/10] Batch 800/3750 Loss D: -0.0050, loss G: -0.4982\n",
+ "Epoch [7/10] Batch 900/3750 Loss D: -0.0040, loss G: -0.4954\n",
+ "Epoch [7/10] Batch 1000/3750 Loss D: -0.0051, loss G: -0.4927\n",
+ "Epoch [7/10] Batch 1100/3750 Loss D: -0.0013, loss G: -0.4957\n",
+ "Epoch [7/10] Batch 1200/3750 Loss D: 0.0025, loss G: -0.5183\n",
+ "Epoch [7/10] Batch 1300/3750 Loss D: -0.0030, loss G: -0.4889\n",
+ "Epoch [7/10] Batch 1400/3750 Loss D: -0.0030, loss G: -0.4885\n",
+ "Epoch [7/10] Batch 1500/3750 Loss D: -0.0034, loss G: -0.5052\n",
+ "Epoch [7/10] Batch 1600/3750 Loss D: -0.0031, loss G: -0.4985\n",
+ "Epoch [7/10] Batch 1700/3750 Loss D: -0.0030, loss G: -0.4966\n",
+ "Epoch [7/10] Batch 1800/3750 Loss D: -0.0015, loss G: -0.5048\n",
+ "Epoch [7/10] Batch 1900/3750 Loss D: -0.0026, loss G: -0.4986\n",
+ "Epoch [7/10] Batch 2000/3750 Loss D: -0.0030, loss G: -0.5044\n",
+ "Epoch [7/10] Batch 2100/3750 Loss D: -0.0025, loss G: -0.4895\n",
+ "Epoch [7/10] Batch 2200/3750 Loss D: -0.0022, loss G: -0.4996\n",
+ "Epoch [7/10] Batch 2300/3750 Loss D: -0.0007, loss G: -0.5045\n",
+ "Epoch [7/10] Batch 2400/3750 Loss D: -0.0024, loss G: -0.4849\n",
+ "Epoch [7/10] Batch 2500/3750 Loss D: -0.0053, loss G: -0.5045\n",
+ "Epoch [7/10] Batch 2600/3750 Loss D: -0.0017, loss G: -0.4959\n",
+ "Epoch [7/10] Batch 2700/3750 Loss D: -0.0008, loss G: -0.4996\n",
+ "Epoch [7/10] Batch 2800/3750 Loss D: -0.0026, loss G: -0.5021\n",
+ "Epoch [7/10] Batch 2900/3750 Loss D: -0.0024, loss G: -0.4979\n",
+ "Epoch [7/10] Batch 3000/3750 Loss D: -0.0027, loss G: -0.4986\n",
+ "Epoch [7/10] Batch 3100/3750 Loss D: -0.0033, loss G: -0.5038\n",
+ "Epoch [7/10] Batch 3200/3750 Loss D: -0.0062, loss G: -0.4876\n",
+ "Epoch [7/10] Batch 3300/3750 Loss D: -0.0021, loss G: -0.4927\n",
+ "Epoch [7/10] Batch 3400/3750 Loss D: -0.0034, loss G: -0.5103\n",
+ "Epoch [7/10] Batch 3500/3750 Loss D: -0.0023, loss G: -0.4956\n",
+ "Epoch [7/10] Batch 3600/3750 Loss D: -0.0041, loss G: -0.4939\n",
+ "Epoch [7/10] Batch 3700/3750 Loss D: -0.0046, loss G: -0.5083\n",
+ "Epoch [8/10] Batch 100/3750 Loss D: -0.0029, loss G: -0.4985\n",
+ "Epoch [8/10] Batch 200/3750 Loss D: -0.0009, loss G: -0.5109\n",
+ "Epoch [8/10] Batch 300/3750 Loss D: -0.0048, loss G: -0.4945\n",
+ "Epoch [8/10] Batch 400/3750 Loss D: -0.0041, loss G: -0.4953\n",
+ "Epoch [8/10] Batch 500/3750 Loss D: -0.0009, loss G: -0.4949\n",
+ "Epoch [8/10] Batch 600/3750 Loss D: -0.0021, loss G: -0.5110\n",
+ "Epoch [8/10] Batch 700/3750 Loss D: -0.0038, loss G: -0.4874\n",
+ "Epoch [8/10] Batch 800/3750 Loss D: -0.0010, loss G: -0.4979\n",
+ "Epoch [8/10] Batch 900/3750 Loss D: -0.0024, loss G: -0.5055\n",
+ "Epoch [8/10] Batch 1000/3750 Loss D: -0.0019, loss G: -0.4990\n",
+ "Epoch [8/10] Batch 1100/3750 Loss D: -0.0059, loss G: -0.4890\n",
+ "Epoch [8/10] Batch 1200/3750 Loss D: -0.0031, loss G: -0.5019\n",
+ "Epoch [8/10] Batch 1300/3750 Loss D: -0.0001, loss G: -0.5003\n",
+ "Epoch [8/10] Batch 1400/3750 Loss D: -0.0022, loss G: -0.4961\n",
+ "Epoch [8/10] Batch 1500/3750 Loss D: -0.0039, loss G: -0.4850\n",
+ "Epoch [8/10] Batch 1600/3750 Loss D: -0.0001, loss G: -0.4995\n",
+ "Epoch [8/10] Batch 1700/3750 Loss D: -0.0017, loss G: -0.4933\n",
+ "Epoch [8/10] Batch 1800/3750 Loss D: -0.0023, loss G: -0.4929\n",
+ "Epoch [8/10] Batch 1900/3750 Loss D: -0.0082, loss G: -0.5114\n",
+ "Epoch [8/10] Batch 2000/3750 Loss D: -0.0029, loss G: -0.4958\n",
+ "Epoch [8/10] Batch 2100/3750 Loss D: -0.0037, loss G: -0.4947\n",
+ "Epoch [8/10] Batch 2200/3750 Loss D: -0.0033, loss G: -0.4997\n",
+ "Epoch [8/10] Batch 2300/3750 Loss D: -0.0022, loss G: -0.5159\n",
+ "Epoch [8/10] Batch 2400/3750 Loss D: -0.0018, loss G: -0.4965\n",
+ "Epoch [8/10] Batch 2500/3750 Loss D: -0.0051, loss G: -0.4886\n",
+ "Epoch [8/10] Batch 2600/3750 Loss D: -0.0039, loss G: -0.5092\n",
+ "Epoch [8/10] Batch 2700/3750 Loss D: -0.0042, loss G: -0.4885\n",
+ "Epoch [8/10] Batch 2800/3750 Loss D: -0.0001, loss G: -0.4892\n",
+ "Epoch [8/10] Batch 2900/3750 Loss D: -0.0054, loss G: -0.5047\n",
+ "Epoch [8/10] Batch 3000/3750 Loss D: -0.0019, loss G: -0.4976\n",
+ "Epoch [8/10] Batch 3100/3750 Loss D: -0.0020, loss G: -0.4906\n",
+ "Epoch [8/10] Batch 3200/3750 Loss D: -0.0027, loss G: -0.4951\n",
+ "Epoch [8/10] Batch 3300/3750 Loss D: -0.0041, loss G: -0.4963\n",
+ "Epoch [8/10] Batch 3400/3750 Loss D: -0.0067, loss G: -0.5086\n",
+ "Epoch [8/10] Batch 3500/3750 Loss D: -0.0049, loss G: -0.4841\n",
+ "Epoch [8/10] Batch 3600/3750 Loss D: -0.0020, loss G: -0.5020\n",
+ "Epoch [8/10] Batch 3700/3750 Loss D: -0.0010, loss G: -0.4989\n",
+ "Epoch [9/10] Batch 100/3750 Loss D: -0.0031, loss G: -0.4877\n",
+ "Epoch [9/10] Batch 200/3750 Loss D: -0.0038, loss G: -0.5043\n",
+ "Epoch [9/10] Batch 300/3750 Loss D: -0.0012, loss G: -0.5083\n",
+ "Epoch [9/10] Batch 400/3750 Loss D: -0.0023, loss G: -0.4994\n",
+ "Epoch [9/10] Batch 500/3750 Loss D: -0.0025, loss G: -0.4888\n",
+ "Epoch [9/10] Batch 600/3750 Loss D: -0.0004, loss G: -0.5074\n",
+ "Epoch [9/10] Batch 700/3750 Loss D: -0.0021, loss G: -0.5071\n",
+ "Epoch [9/10] Batch 800/3750 Loss D: -0.0026, loss G: -0.4994\n",
+ "Epoch [9/10] Batch 900/3750 Loss D: -0.0026, loss G: -0.4893\n",
+ "Epoch [9/10] Batch 1000/3750 Loss D: -0.0021, loss G: -0.4991\n",
+ "Epoch [9/10] Batch 1100/3750 Loss D: 0.0005, loss G: -0.5113\n",
+ "Epoch [9/10] Batch 1200/3750 Loss D: -0.0019, loss G: -0.4907\n",
+ "Epoch [9/10] Batch 1300/3750 Loss D: -0.0030, loss G: -0.4914\n",
+ "Epoch [9/10] Batch 1400/3750 Loss D: -0.0027, loss G: -0.5031\n",
+ "Epoch [9/10] Batch 1500/3750 Loss D: -0.0029, loss G: -0.5028\n",
+ "Epoch [9/10] Batch 1600/3750 Loss D: -0.0011, loss G: -0.4989\n",
+ "Epoch [9/10] Batch 1700/3750 Loss D: -0.0051, loss G: -0.5052\n",
+ "Epoch [9/10] Batch 1800/3750 Loss D: -0.0021, loss G: -0.4996\n",
+ "Epoch [9/10] Batch 1900/3750 Loss D: -0.0019, loss G: -0.5016\n",
+ "Epoch [9/10] Batch 2000/3750 Loss D: -0.0046, loss G: -0.4891\n",
+ "Epoch [9/10] Batch 2100/3750 Loss D: -0.0041, loss G: -0.4936\n",
+ "Epoch [9/10] Batch 2200/3750 Loss D: -0.0056, loss G: -0.5113\n",
+ "Epoch [9/10] Batch 2300/3750 Loss D: -0.0013, loss G: -0.4882\n",
+ "Epoch [9/10] Batch 2400/3750 Loss D: 0.0018, loss G: -0.4893\n",
+ "Epoch [9/10] Batch 2500/3750 Loss D: -0.0021, loss G: -0.5094\n",
+ "Epoch [9/10] Batch 2600/3750 Loss D: -0.0040, loss G: -0.5018\n",
+ "Epoch [9/10] Batch 2700/3750 Loss D: -0.0017, loss G: -0.4905\n",
+ "Epoch [9/10] Batch 2800/3750 Loss D: -0.0016, loss G: -0.5050\n",
+ "Epoch [9/10] Batch 2900/3750 Loss D: -0.0025, loss G: -0.4918\n",
+ "Epoch [9/10] Batch 3000/3750 Loss D: -0.0036, loss G: -0.4962\n",
+ "Epoch [9/10] Batch 3100/3750 Loss D: -0.0008, loss G: -0.5071\n",
+ "Epoch [9/10] Batch 3200/3750 Loss D: -0.0034, loss G: -0.4921\n",
+ "Epoch [9/10] Batch 3300/3750 Loss D: -0.0017, loss G: -0.5013\n",
+ "Epoch [9/10] Batch 3400/3750 Loss D: -0.0035, loss G: -0.5001\n",
+ "Epoch [9/10] Batch 3500/3750 Loss D: -0.0003, loss G: -0.5042\n",
+ "Epoch [9/10] Batch 3600/3750 Loss D: -0.0024, loss G: -0.4967\n",
+ "Epoch [9/10] Batch 3700/3750 Loss D: -0.0040, loss G: -0.5010\n",
+ "Epoch [10/10] Batch 100/3750 Loss D: -0.0025, loss G: -0.4959\n",
+ "Epoch [10/10] Batch 200/3750 Loss D: -0.0025, loss G: -0.4948\n",
+ "Epoch [10/10] Batch 300/3750 Loss D: -0.0028, loss G: -0.5041\n",
+ "Epoch [10/10] Batch 400/3750 Loss D: -0.0043, loss G: -0.4954\n",
+ "Epoch [10/10] Batch 500/3750 Loss D: -0.0039, loss G: -0.4957\n",
+ "Epoch [10/10] Batch 600/3750 Loss D: -0.0030, loss G: -0.5023\n",
+ "Epoch [10/10] Batch 700/3750 Loss D: -0.0014, loss G: -0.4961\n",
+ "Epoch [10/10] Batch 800/3750 Loss D: -0.0025, loss G: -0.5009\n",
+ "Epoch [10/10] Batch 900/3750 Loss D: -0.0005, loss G: -0.4962\n",
+ "Epoch [10/10] Batch 1000/3750 Loss D: -0.0012, loss G: -0.5014\n",
+ "Epoch [10/10] Batch 1100/3750 Loss D: -0.0035, loss G: -0.4973\n",
+ "Epoch [10/10] Batch 1200/3750 Loss D: -0.0008, loss G: -0.4894\n",
+ "Epoch [10/10] Batch 1300/3750 Loss D: -0.0031, loss G: -0.5034\n",
+ "Epoch [10/10] Batch 1400/3750 Loss D: -0.0001, loss G: -0.5063\n",
+ "Epoch [10/10] Batch 1500/3750 Loss D: -0.0028, loss G: -0.4955\n",
+ "Epoch [10/10] Batch 1600/3750 Loss D: -0.0026, loss G: -0.4992\n",
+ "Epoch [10/10] Batch 1700/3750 Loss D: -0.0031, loss G: -0.5024\n",
+ "Epoch [10/10] Batch 1800/3750 Loss D: -0.0029, loss G: -0.4947\n",
+ "Epoch [10/10] Batch 1900/3750 Loss D: 0.0011, loss G: -0.4962\n",
+ "Epoch [10/10] Batch 2000/3750 Loss D: -0.0056, loss G: -0.5062\n",
+ "Epoch [10/10] Batch 2100/3750 Loss D: -0.0021, loss G: -0.4998\n",
+ "Epoch [10/10] Batch 2200/3750 Loss D: -0.0057, loss G: -0.4840\n",
+ "Epoch [10/10] Batch 2300/3750 Loss D: -0.0029, loss G: -0.4959\n",
+ "Epoch [10/10] Batch 2400/3750 Loss D: -0.0036, loss G: -0.5013\n",
+ "Epoch [10/10] Batch 2500/3750 Loss D: -0.0038, loss G: -0.4980\n",
+ "Epoch [10/10] Batch 2600/3750 Loss D: -0.0046, loss G: -0.4966\n",
+ "Epoch [10/10] Batch 2700/3750 Loss D: -0.0030, loss G: -0.5035\n",
+ "Epoch [10/10] Batch 2800/3750 Loss D: -0.0024, loss G: -0.4980\n",
+ "Epoch [10/10] Batch 2900/3750 Loss D: -0.0032, loss G: -0.4970\n",
+ "Epoch [10/10] Batch 3000/3750 Loss D: -0.0036, loss G: -0.4967\n",
+ "Epoch [10/10] Batch 3100/3750 Loss D: -0.0034, loss G: -0.4930\n",
+ "Epoch [10/10] Batch 3200/3750 Loss D: -0.0037, loss G: -0.5034\n",
+ "Epoch [10/10] Batch 3300/3750 Loss D: -0.0041, loss G: -0.4950\n",
+ "Epoch [10/10] Batch 3400/3750 Loss D: -0.0041, loss G: -0.5000\n",
+ "Epoch [10/10] Batch 3500/3750 Loss D: -0.0026, loss G: -0.4896\n",
+ "Epoch [10/10] Batch 3600/3750 Loss D: -0.0034, loss G: -0.4970\n",
+ "Epoch [10/10] Batch 3700/3750 Loss D: -0.0031, loss G: -0.4990\n"
+ ]
+ }
+ ],
+ "source": [
+        "# legacy accumulators from the original GAN template (unused in the WGAN loop below)\n",
+        "ERRD_x = np.zeros(config.num_epochs)\n",
+        "ERRD_z = np.zeros(config.num_epochs)\n",
+        "ERRG = np.zeros(config.num_epochs)\n",
+        "N = len(dataloader)\n",
+        "\n",
+        "# WGAN modifies the standard GAN in four ways:\n",
+        "# 1. remove the sigmoid in the last layer of the discriminator (classification -> regression)\n",
+        "# 2. no log in the loss (estimate the Wasserstein distance directly)\n",
+        "# 3. clip the critic's parameters to [-c, c] (enforces Lipschitz continuity)\n",
+        "# 4. no momentum-based optimizer; use RMSprop or SGD instead\n",
+ "\n",
+        "for epoch in range(config.num_epochs):\n",
+        "    for iteration, (images, cat) in enumerate(dataloader):\n",
+        "        #######\n",
+        "        # Critic stage: max E[critic(real)] - E[critic(fake)]\n",
+        "        # (the WGAN paper takes n_critic = 5 critic steps per generator step;\n",
+        "        # this loop takes one)\n",
+        "        #######\n",
+        "        discriminator.zero_grad()\n",
+        "        # real\n",
+        "        input_data = images.view(images.shape[0], -1)\n",
+        "        output_real = discriminator(input_data)\n",
+        "\n",
+        "        # fake\n",
+        "        noise.data.normal_(0, 1)\n",
+        "        fake = generator(noise)\n",
+        "        output_fake = discriminator(fake.detach())\n",
+        "\n",
+        "        # modification2: no log in the loss (negated Wasserstein estimate)\n",
+        "        loss_critic = -(torch.mean(output_real) - torch.mean(output_fake))\n",
+        "        loss_critic.backward()\n",
+        "        optim_D.step()\n",
+        "\n",
+        "        # modification3: clip the critic's params to c=0.01 (Lipschitz continuity)\n",
+        "        for param in discriminator.parameters():\n",
+        "            param.data.clamp_(-0.01, 0.01)\n",
+        "\n",
+        "        #######\n",
+        "        # Generator stage: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]\n",
+        "        #######\n",
+        "        generator.zero_grad()\n",
+        "        gen_fake = discriminator(fake)\n",
+        "        loss_gen = -torch.mean(gen_fake)\n",
+        "        loss_gen.backward()\n",
+        "        optim_G.step()\n",
+        "\n",
+        "        # Print losses occasionally and write image grids to tensorboard\n",
+        "        if iteration % config.print_freq == 0 and iteration > 0:\n",
+        "            generator.eval()\n",
+        "            discriminator.eval()\n",
+        "            print(\n",
+        "                f\"Epoch [{epoch+1}/{config.num_epochs}] Batch {iteration}/{len(dataloader)} \\\n",
+        "                      Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}\"\n",
+        "            )\n",
+        "\n",
+        "            with torch.no_grad():\n",
+        "                # use the fixed noise so the grid tracks the same samples over training,\n",
+        "                # and reshape the flat 784-dim outputs to images for make_grid\n",
+        "                fake = generator(fixed_noise).view(-1, 1, 28, 28)\n",
+        "                # take out (up to) 32 examples\n",
+        "                img_grid_real = torchvision.utils.make_grid(\n",
+        "                    images[:32], normalize=True\n",
+        "                )\n",
+        "                img_grid_fake = torchvision.utils.make_grid(\n",
+        "                    fake[:32], normalize=True\n",
+        "                )\n",
+        "\n",
+        "                writer_real.add_image(\"Real\", img_grid_real, global_step=step)\n",
+        "                writer_fake.add_image(\"Fake\", img_grid_fake, global_step=step)\n",
+        "\n",
+        "                step += 1\n",
+        "            generator.train()\n",
+        "            discriminator.train()\n"
+ ]
+ },
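+      {
+       "cell_type": "markdown",
+       "metadata": {},
+       "source": [
+        "Task 2 above asks to replace weight clipping with a gradient penalty. Below is a minimal sketch of the WGAN-GP term, assuming the flat 784-dim inputs used in this notebook; the name `gradient_penalty` and the coefficient `lambda_gp=10` follow the linked paper's defaults. To use it, add the returned term to `loss_critic`, drop the clipping loop, and (per the paper) switch to Adam with `betas=(0, 0.9)`."
+       ]
+      },
+      {
+       "cell_type": "code",
+       "execution_count": null,
+       "metadata": {},
+       "outputs": [],
+       "source": [
+        "def gradient_penalty(critic, real, fake, lambda_gp=10.0):\n",
+        "    # sample points on straight lines between real and fake samples\n",
+        "    batch_size = real.size(0)\n",
+        "    alpha = torch.rand(batch_size, 1).expand_as(real)\n",
+        "    interpolated = (alpha * real + (1 - alpha) * fake).requires_grad_(True)\n",
+        "\n",
+        "    score = critic(interpolated)\n",
+        "\n",
+        "    # gradient of the critic score w.r.t. the interpolated points\n",
+        "    gradients = torch.autograd.grad(\n",
+        "        outputs=score,\n",
+        "        inputs=interpolated,\n",
+        "        grad_outputs=torch.ones_like(score),\n",
+        "        create_graph=True,\n",
+        "    )[0]\n",
+        "\n",
+        "    # penalize deviation of the per-sample gradient norm from 1\n",
+        "    gradient_norm = gradients.view(batch_size, -1).norm(2, dim=1)\n",
+        "    return lambda_gp * ((gradient_norm - 1) ** 2).mean()\n",
+        "\n",
+        "# usage inside the critic stage (replaces the clipping loop):\n",
+        "# gp = gradient_penalty(discriminator, input_data, fake.detach())\n",
+        "# loss_critic = -(torch.mean(output_real) - torch.mean(output_fake)) + gp"
+       ]
+      },
+      {
+       "cell_type": "markdown",
+       "metadata": {},
+       "source": [
+        "Task 3 asks for conditional generation. One common approach, following the linked cGAN paper, is to concatenate a label embedding to the inputs of both networks; the sketch below does this (the class names `CondGenerator`/`CondCritic` and the embedding size are assumptions, not part of the original template). The `cat` labels already produced by the dataloader can be passed alongside the images and noise."
+       ]
+      },
+      {
+       "cell_type": "code",
+       "execution_count": null,
+       "metadata": {},
+       "outputs": [],
+       "source": [
+        "class CondGenerator(nn.Module):\n",
+        "    def __init__(self, num_classes=10):\n",
+        "        super().__init__()\n",
+        "        self.embed = nn.Embedding(num_classes, num_classes)\n",
+        "        self.model = nn.Sequential(\n",
+        "            nn.Linear(config.noise_size + num_classes, 200),\n",
+        "            nn.ReLU(inplace=True),\n",
+        "            nn.Linear(200, 28*28),\n",
+        "            nn.Sigmoid())\n",
+        "\n",
+        "    def forward(self, z, labels):\n",
+        "        # condition on the class by concatenating its embedding to the noise\n",
+        "        return self.model(torch.cat([z, self.embed(labels)], dim=1))\n",
+        "\n",
+        "class CondCritic(nn.Module):\n",
+        "    def __init__(self, num_classes=10):\n",
+        "        super().__init__()\n",
+        "        self.embed = nn.Embedding(num_classes, num_classes)\n",
+        "        self.model = nn.Sequential(\n",
+        "            nn.Linear(28*28 + num_classes, 200),\n",
+        "            nn.ReLU(inplace=True),\n",
+        "            nn.Linear(200, 50),\n",
+        "            nn.ReLU(inplace=True),\n",
+        "            nn.Linear(50, 1))\n",
+        "\n",
+        "    def forward(self, x, labels):\n",
+        "        # condition the critic on the same label embedding\n",
+        "        return self.model(torch.cat([x, self.embed(labels)], dim=1))"
+       ]
+      },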
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {
+ "id": "sDYUxxGIUdru",
+ "outputId": "10c9d0f3-f820-45bb-f763-3c9f0b84f136",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 401
+ }
+ },
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "