Skip to content

Commit

Permalink
Update the codes (Add load and save model)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kyushik committed Aug 10, 2019
1 parent b33610b commit dbcc548
Show file tree
Hide file tree
Showing 13 changed files with 1,200 additions and 632 deletions.
90 changes: 67 additions & 23 deletions 01_VAE_MNIST.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@
"import numpy as np\n",
"import tensorflow as tf\n",
"import matplotlib.pyplot as plt\n",
"import random"
"import random\n",
"import datetime\n",
"import os"
]
},
{
Expand All @@ -54,6 +56,8 @@
"metadata": {},
"outputs": [],
"source": [
"algorithm = 'VAE'\n",
"\n",
"img_size = 28\n",
"data_size = img_size**2\n",
"\n",
Expand All @@ -65,7 +69,15 @@
"n_hidden = 256\n",
"n_latent = 128\n",
"\n",
"learning_rate = 1e-3"
"learning_rate = 1e-3\n",
"\n",
"date_time = datetime.datetime.now().strftime(\"%Y%m%d-%H-%M-%S\")\n",
"\n",
"load_model = False\n",
"train_model = True\n",
"\n",
"save_path = \"./saved_models/\" + date_time + \"_\" + algorithm\n",
"load_path = \"./saved_models/20190809-11-04-47_VAE/model/model\" "
]
},
{
Expand Down Expand Up @@ -282,6 +294,25 @@
"sess.run(init)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"Saver = tf.train.Saver()\n",
"\n",
"if load_model == True:\n",
" Saver.restore(sess, load_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -322,27 +353,28 @@
}
],
"source": [
"# Training\n",
"data_x = x_train\n",
"len_data = x_train.shape[0]\n",
"\n",
"for i in range(num_epoch):\n",
" # Shuffle the data \n",
" np.random.shuffle(data_x)\n",
" \n",
" # Making mini-batch\n",
" for j in range(0, len_data, batch_size):\n",
" if j + batch_size < len_data:\n",
" data_x_in = data_x[j : j + batch_size, :]\n",
" else:\n",
" data_x_in = data_x[j : len_data, :]\n",
" \n",
" # Run Optimizer!\n",
" _, loss_train = sess.run([train_step, loss], feed_dict = {x: data_x_in, keep_prob_encoder: 0.9, keep_prob_decoder: 0.9})\n",
"if train_model:\n",
" # Training\n",
" data_x = x_train\n",
" len_data = x_train.shape[0]\n",
"\n",
" for i in range(num_epoch):\n",
" # Shuffle the data \n",
" np.random.shuffle(data_x)\n",
"\n",
" print(\"Batch: {} / {}\".format(j, len_data), end=\"\\r\")\n",
" \n",
" print(\"Epoch: \" + str(i+1) + ' / ' + \"Loss: \" + str(loss_train))"
" # Making mini-batch\n",
" for j in range(0, len_data, batch_size):\n",
" if j + batch_size < len_data:\n",
" data_x_in = data_x[j : j + batch_size, :]\n",
" else:\n",
" data_x_in = data_x[j : len_data, :]\n",
"\n",
" # Run Optimizer!\n",
" _, loss_train = sess.run([train_step, loss], feed_dict = {x: data_x_in, keep_prob_encoder: 0.9, keep_prob_decoder: 0.9})\n",
"\n",
" print(\"Batch: {} / {}\".format(j, len_data), end=\"\\r\")\n",
"\n",
" print(\"Epoch: \" + str(i+1) + ' / ' + \"Loss: \" + str(loss_train))"
]
},
{
Expand Down Expand Up @@ -684,12 +716,24 @@
" plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Save Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"os.mkdir(save_path)\n",
"\n",
"Saver.save(sess, save_path + \"/model/model\")\n",
"print(\"Model is saved in {}\".format(save_path + \"/model/model\"))"
]
}
],
"metadata": {
Expand Down
124 changes: 84 additions & 40 deletions 02_GAN_MLP_MNIST.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@
"import numpy as np\n",
"import tensorflow as tf\n",
"import matplotlib.pyplot as plt\n",
"import random"
"import random\n",
"import datetime\n",
"import os"
]
},
{
Expand All @@ -54,6 +56,8 @@
"metadata": {},
"outputs": [],
"source": [
"algorithm = 'GAN'\n",
"\n",
"img_size = 28\n",
"data_size = img_size**2\n",
"\n",
Expand All @@ -66,7 +70,15 @@
"learning_rate_g = 2e-4\n",
"learning_rate_d = 2e-4\n",
"\n",
"show_result_epoch = 5"
"show_result_epoch = 5\n",
"\n",
"date_time = datetime.datetime.now().strftime(\"%Y%m%d-%H-%M-%S\")\n",
"\n",
"load_model = False\n",
"train_model = True\n",
"\n",
"save_path = \"./saved_models/\" + date_time + \"_\" + algorithm\n",
"load_path = \"./saved_models/20190809-11-04-47_CycleGAN_Horse2Zebra/model/model\" "
]
},
{
Expand Down Expand Up @@ -279,6 +291,25 @@
"sess.run(init)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"Saver = tf.train.Saver()\n",
"\n",
"if load_model == True:\n",
" Saver.restore(sess, load_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -1003,45 +1034,46 @@
}
],
"source": [
"# Training\n",
"data_x = x_train\n",
"len_data = x_train.shape[0]\n",
"if train_model:\n",
" # Training\n",
" data_x = x_train\n",
" len_data = x_train.shape[0]\n",
"\n",
"for i in range(num_epoch):\n",
" # Shuffle the data \n",
" np.random.shuffle(data_x)\n",
" \n",
" # Making mini-batch\n",
" for j in range(0, len_data, batch_size):\n",
" if j + batch_size < len_data:\n",
" data_x_in = data_x[j : j + batch_size, :]\n",
" else:\n",
" data_x_in = data_x[j : len_data, :]\n",
" \n",
" sampled_z = np.random.uniform(-1, 1, size=(data_x_in.shape[0] , n_latent))\n",
" \n",
" # Run Optimizer!\n",
" _, loss_d = sess.run([train_step_d, d_loss], feed_dict = {x: data_x_in, z: sampled_z})\n",
" _, loss_g = sess.run([train_step_g, g_loss], feed_dict = {x: data_x_in, z: sampled_z})\n",
" \n",
" print(\"Batch: {} / {}\".format(j, len_data), end=\"\\r\")\n",
" \n",
"    # Print Progress\n",
" print(\"Epoch: {} / G Loss: {:.5f} / D Loss: {:.5f}\".format((i+1), loss_g, loss_d))\n",
" \n",
" # Show test images \n",
" z_test = np.random.uniform(-1, 1, size=(5, n_latent))\n",
" G_out = sess.run(G, feed_dict = {z: z_test})\n",
" output_reshape = np.reshape(G_out, [5, img_size, img_size])\n",
" \n",
" if i == 0 or (i+1) % show_result_epoch == 0:\n",
" f, ax = plt.subplots(1,5)\n",
" for j in range(5):\n",
" ax[j].imshow(output_reshape[j,:,:], cmap = 'gray')\n",
" ax[j].axis('off')\n",
" ax[j].set_title('Image '+str(j))\n",
" for i in range(num_epoch):\n",
" # Shuffle the data \n",
" np.random.shuffle(data_x)\n",
"\n",
" # Making mini-batch\n",
" for j in range(0, len_data, batch_size):\n",
" if j + batch_size < len_data:\n",
" data_x_in = data_x[j : j + batch_size, :]\n",
" else:\n",
" data_x_in = data_x[j : len_data, :]\n",
"\n",
" sampled_z = np.random.uniform(-1, 1, size=(data_x_in.shape[0] , n_latent))\n",
"\n",
" # Run Optimizer!\n",
" _, loss_d = sess.run([train_step_d, d_loss], feed_dict = {x: data_x_in, z: sampled_z})\n",
" _, loss_g = sess.run([train_step_g, g_loss], feed_dict = {x: data_x_in, z: sampled_z})\n",
"\n",
" print(\"Batch: {} / {}\".format(j, len_data), end=\"\\r\")\n",
"\n",
"        # Print Progress\n",
" print(\"Epoch: {} / G Loss: {:.5f} / D Loss: {:.5f}\".format((i+1), loss_g, loss_d))\n",
"\n",
" # Show test images \n",
" z_test = np.random.uniform(-1, 1, size=(5, n_latent))\n",
" G_out = sess.run(G, feed_dict = {z: z_test})\n",
" output_reshape = np.reshape(G_out, [5, img_size, img_size])\n",
"\n",
" plt.show()"
" if i == 0 or (i+1) % show_result_epoch == 0:\n",
" f, ax = plt.subplots(1,5)\n",
" for j in range(5):\n",
" ax[j].imshow(output_reshape[j,:,:], cmap = 'gray')\n",
" ax[j].axis('off')\n",
" ax[j].set_title('Image '+str(j))\n",
"\n",
" plt.show()"
]
},
{
Expand Down Expand Up @@ -1090,12 +1122,24 @@
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Save Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"os.mkdir(save_path)\n",
"\n",
"Saver.save(sess, save_path + \"/model/model\")\n",
"print(\"Model is saved in {}\".format(save_path + \"/model/model\"))"
]
}
],
"metadata": {
Expand Down
Loading

0 comments on commit dbcc548

Please sign in to comment.