Skip to content

Commit

Permalink
Update AGGAN and README.md
Browse files Browse the repository at this point in the history
  • Loading branch information
Kyushik committed Sep 27, 2019
1 parent 5fcd18d commit c0b113f
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 888 deletions.
2 changes: 1 addition & 1 deletion 08_InfoGAN_MNIST.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
"version": "3.6.5"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion 09_CycleGAN_Horse2Zebra.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
"version": "3.6.5"
}
},
"nbformat": 4,
Expand Down
40 changes: 11 additions & 29 deletions 11_AGGAN.ipynb → 11_AGGAN_Horse2Zebra.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# CycleGAN Horse2Zebra\n",
"# AGGAN Horse2Zebra\n",
"\n",
"This notebook is for implementing `Unsupervised-Attention-guided-Image-to-Image-Translation` from the paper [Unsupervised Attention-guided Image-to-Image Traslation](https://arxiv.org/abs/1806.02311) with [Tensorflow](https://www.tensorflow.org). <br>\n",
"[Horse2Zebra dataset](https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/), which is 256x256 size, will be used. \n",
Expand All @@ -21,7 +21,7 @@
"outputs": [],
"source": [
"import os\n",
"os.environ['CUDA_VISIBLE_DEVICES'] = '5'"
"os.environ['CUDA_VISIBLE_DEVICES'] = '0'"
]
},
{
Expand Down Expand Up @@ -81,19 +81,17 @@
},
"outputs": [],
"source": [
"algorithm = 'AGGAN'\n",
"algorithm = 'AGGAN_Horse2Zebra'\n",
"\n",
"img_size = 256\n",
"batch_size = 1\n",
"num_epoch = 200\n",
"\n",
"beta1 = 0.5\n",
"\n",
"\n",
"start_attention_epoch = 30 # When stop the training of attention network\n",
"attention_th = 0.1 # Mask selection threshold\n",
"\n",
"\n",
"init_lr = 0.0002\n",
"start_decay_epoch = 100\n",
"\n",
Expand All @@ -102,7 +100,8 @@
"load_model = False\n",
"train_model = True\n",
"\n",
"save_path = \"./../saved_models/\" + date_time + \"_\" + algorithm"
"save_path = \"./../saved_models/\" + date_time + \"_\" + algorithm\n",
"load_path = \"./saved_models/20190809-11-04-47_AGGAN_Horse2Zebra/model/model\" "
]
},
{
Expand Down Expand Up @@ -283,7 +282,6 @@
"def Discriminator(x, network_name, is_training, training_flag, reuse=False):\n",
" with tf.variable_scope(network_name, reuse=reuse):\n",
"\n",
"\n",
" # First conv layer (C64)\n",
" h1 = tf.layers.conv2d(x, filters=64, kernel_size=4, strides=2, padding='SAME')\n",
" h1 = tf.cond(training_flag>0, lambda:tf.contrib.layers.instance_norm(h1) ,lambda:h1)\n",
Expand All @@ -307,7 +305,6 @@
" # Output layer \n",
" logit = tf.layers.conv2d(h4, filters=1, kernel_size=4, strides=1, padding='SAME')\n",
" output = tf.sigmoid(logit)\n",
" \n",
" \n",
" return logit, output"
]
Expand All @@ -329,9 +326,7 @@
"source": [
"def GAN(x, y, is_training, training_flag):\n",
" \n",
" #################\n",
" ### Generator ###\n",
" #################\n",
" ############################## Generator ##############################\n",
" \n",
" # X -> Y\n",
" y_gen = Generator(x, 'Gx', is_training)\n",
Expand All @@ -349,9 +344,7 @@
" gen_image_x = tf.multiply(tf.concat([y_mask]*3, axis=3), x_gen)+\\\n",
" tf.multiply((1-tf.concat([y_mask]*3,axis=3)),y)\n",
" \n",
" #############\n",
" ### Cycle ###\n",
" #############\n",
" ############################### Cycle ################################\n",
" \n",
" # X->Y\n",
" y_cycle = Generator(gen_image_x, 'Gx', is_training, reuse = True)\n",
Expand All @@ -369,9 +362,7 @@
" gen_cycle_image_x = tf.multiply(tf.concat([y_mask_cycle]*3, axis=3), x_cycle)+\\\n",
" tf.multiply((1-tf.concat([y_mask_cycle]*3,axis=3)),gen_image_y)\n",
" \n",
" #####################\n",
" ### Discriminator ###\n",
" #####################\n",
" ########################### Discriminator #############################\n",
" \n",
" #Image for step after attention epoch\n",
" Dx_in = tf.where((tf.concat([x_mask]*3, axis=3))>attention_th, x, tf.zeros_like(x))\n",
Expand All @@ -396,7 +387,6 @@
" D_logit_x_fake, D_out_x_fake = Discriminator(fake_image_x, 'Dx', is_training, training_flag, reuse=True) \n",
" D_logit_y_fake, D_out_y_fake = Discriminator(fake_image_y, 'Dy', is_training, training_flag, reuse=True)\n",
"\n",
" \n",
" # Loss function\n",
" Dx_loss = tf.reduce_mean(tf.square(D_logit_x_real-1)) + tf.reduce_mean(tf.square(D_logit_x_fake))\n",
" Dy_loss = tf.reduce_mean(tf.square(D_logit_y_real-1)) + tf.reduce_mean(tf.square(D_logit_y_fake))\n",
Expand All @@ -406,7 +396,6 @@
" Gx_loss = tf.reduce_mean(tf.square(D_logit_x_fake-1)) \n",
" Gy_loss = tf.reduce_mean(tf.square(D_logit_y_fake-1)) \n",
" \n",
" \n",
" cycle_loss = tf.reduce_mean(tf.abs(gen_cycle_image_x-x)) + tf.reduce_mean(tf.abs(gen_cycle_image_y-y))\n",
" \n",
" G_loss = Gx_loss + Gy_loss + 10*cycle_loss\n",
Expand Down Expand Up @@ -510,13 +499,9 @@
],
"source": [
"Saver = tf.train.Saver()\n",
"load_model = True\n",
"\n",
"if load_model == True:\n",
" ## For training\n",
"# Saver.restore(sess, tf.train.latest_checkpoint('./saved_models/'))\n",
" \n",
" ## For Test\n",
" Saver.restore(sess, tf.train.latest_checkpoint('./saved_models/Final_model/'))"
" Saver.restore(sess, load_path)"
]
},
{
Expand Down Expand Up @@ -699,9 +684,6 @@
" ax1[5].imshow(x_mask_cycle_test,cmap=plt.get_cmap('gray'))\n",
" ax1[5].axis('off')\n",
" ax1[5].set_title('Y Mask')\n",
" \n",
" \n",
" \n",
"\n",
" plt.show()\n"
]
Expand Down Expand Up @@ -896,7 +878,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
"version": "3.6.5"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit c0b113f

Please sign in to comment.