added to save best cells
shibuiwilliam authored Aug 26, 2018
1 parent b8dabce commit ee44313
Showing 2 changed files with 38 additions and 97 deletions.
31 changes: 28 additions & 3 deletions ENAS.py
@@ -1,6 +1,7 @@
import numpy as np
import os
import csv
import pickle
import sys
import shutil
import gc
@@ -220,8 +221,9 @@ def write_record(self, epoch, lr, reward, val_loss):
with open(record_file, "a") as f:
writer = csv.writer(f, lineterminator='\n')
if not os.path.exists(record_file):
writer.writerow(["epoch", "lr", "reward", "val_loss"])
writer.writerow([epoch, lr, reward, val_loss])
writer.writerow(["epoch", "lr", "reward", "val_loss", "best_val_acc"])
writer.writerow([epoch, lr, reward, val_loss, self.best_val_acc])
print("saved records so far")

def read_record(self):
record_file = "{0}_record.csv".format(self.child_network_name)
@@ -231,15 +233,36 @@ def read_record(self):
reader = csv.reader(f)
for row in reader:
rec.append(row)
print("loaded records")
return rec
else:
return None

def save_best_cell(self):
normal_cell_file = "{0}_normal_cell.pkl".format(self.child_network_name)
with open(normal_cell_file, "wb") as f:
pickle.dump(self.best_normal_cell, f)
reduction_cell_file = "{0}_reduction_cell.pkl".format(self.child_network_name)
with open(reduction_cell_file, "wb") as f:
pickle.dump(self.best_reduction_cell, f)
print("saved best cells")

def load_best_cell(self):
normal_cell_file = "{0}_normal_cell.pkl".format(self.child_network_name)
with open(normal_cell_file, "rb") as f:
self.best_normal_cell = pickle.load(f)
reduction_cell_file = "{0}_reduction_cell.pkl".format(self.child_network_name)
with open(reduction_cell_file, "rb") as f:
self.best_reduction_cell = pickle.load(f)
print("loaded best cells")

def search_neural_architecture(self):
if self.start_from_record:
rec = self.read_record()
if rec is not None:
starting_epoch = int(rec[-1][0]) + 1
self.best_val_acc = float(rec[-1][4])
self.load_best_cell()
else:
starting_epoch = 0
for e in range(starting_epoch, self.child_epochs):
@@ -274,12 +297,14 @@ def search_neural_architecture(self):
val_acc = CNC.evaluate_child_network(x_val_batch, y_val_batch)
print(val_acc)
self.reward = val_acc[1]
self.write_record(e, self.child_lr_scedule[e], self.reward, val_acc[0])

if self.best_val_acc < val_acc[1]:
self.best_val_acc = val_acc[1]
self.best_normal_cell = sample_cell["normal_cell"]
self.best_reduction_cell = sample_cell["reduction_cell"]

self.write_record(e, self.child_lr_scedule[e], self.reward, val_acc[0])
self.save_best_cell()

child_train_record = {}
child_train_record["normal_cell"] = sample_cell["normal_cell"]
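Taken together, the additions to ENAS.py form a small checkpoint-and-resume loop: each search epoch appends a CSV row that now includes best_val_acc, the current best normal and reduction cells are pickled, and a restarted run reads the last row back and reloads the cells. The following is a minimal standalone sketch of that pattern, not the repository's code: the SearchState class, file names, and resume_epoch helper are hypothetical stand-ins, and the sketch tests whether the CSV exists before opening it in append mode so the header row is actually written on the first epoch.

import csv
import os
import pickle


class SearchState:
    """Hypothetical stand-in for the persistence pieces added in this commit."""

    def __init__(self, name):
        self.name = name
        self.best_val_acc = 0.0
        self.best_normal_cell = None
        self.best_reduction_cell = None

    def write_record(self, epoch, lr, reward, val_loss):
        record_file = "{0}_record.csv".format(self.name)
        # Test for the file before opening it in append mode; once the file
        # is open for appending it already exists, so no header would be written.
        write_header = not os.path.exists(record_file)
        with open(record_file, "a") as f:
            writer = csv.writer(f, lineterminator="\n")
            if write_header:
                writer.writerow(["epoch", "lr", "reward", "val_loss", "best_val_acc"])
            writer.writerow([epoch, lr, reward, val_loss, self.best_val_acc])

    def read_record(self):
        record_file = "{0}_record.csv".format(self.name)
        if not os.path.exists(record_file):
            return None
        with open(record_file) as f:
            return list(csv.reader(f))

    def save_best_cells(self):
        cells = {"normal": self.best_normal_cell, "reduction": self.best_reduction_cell}
        for kind, cell in cells.items():
            with open("{0}_{1}_cell.pkl".format(self.name, kind), "wb") as f:
                pickle.dump(cell, f)

    def load_best_cells(self):
        for kind in ("normal", "reduction"):
            with open("{0}_{1}_cell.pkl".format(self.name, kind), "rb") as f:
                setattr(self, "best_{0}_cell".format(kind), pickle.load(f))

    def resume_epoch(self):
        # Restore best_val_acc and the pickled cells from the last record,
        # then continue from the following epoch, as in the diff above.
        rec = self.read_record()
        if rec is None or len(rec) < 2:  # no file yet, or header row only
            return 0
        self.best_val_acc = float(rec[-1][4])
        self.load_best_cells()
        return int(rec[-1][0]) + 1

A driving loop would start at epoch = state.resume_epoch(), update best_val_acc and the cells whenever validation accuracy improves, and call write_record(...) and save_best_cells() at the end of each epoch, mirroring the ordering in search_neural_architecture above.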
104 changes: 10 additions & 94 deletions ENAS_Keras_MNIST.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {
"colab": {
"autoexec": {
@@ -27,17 +27,7 @@
"id": "L8PLWMnE1e_5",
"outputId": "a22a2442-0ba3-4cfd-90d9-8d2103478b64"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/anaconda3/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
" from ._conv import register_converters as _register_converters\n",
"Using TensorFlow backend.\n"
]
}
],
"outputs": [],
"source": [
"import numpy as np\n",
"import os\n",
@@ -65,7 +55,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -89,7 +79,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {
"colab": {
"autoexec": {
@@ -107,22 +97,9 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x_train shape: (12665, 28, 28, 1)\n",
"x_test shape: (2115, 28, 28, 1)\n",
"y_train shape: (12665, 2)\n",
"y_test shape: (2115, 2)\n",
"12665 train samples\n",
"2115 test samples\n"
]
}
],
"outputs": [],
"source": [
"child_classes = 2\n",
"\n",
@@ -164,7 +141,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -180,7 +157,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -190,7 +167,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {
"scrolled": true
},
@@ -247,68 +224,7 @@
"id": "H-aUGKEOBnff",
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SEARCH EPOCH: 1 / 150\n",
"normal_cell: {2: {'L': {'input_layer': 1, 'oper_id': 4}, 'R': {'input_layer': 0, 'oper_id': 4}}, 3: {'L': {'input_layer': 2, 'oper_id': 2}, 'R': {'input_layer': 1, 'oper_id': 0}}, 4: {'L': {'input_layer': 0, 'oper_id': 0}, 'R': {'input_layer': 0, 'oper_id': 0}}, 5: {'L': {'input_layer': 2, 'oper_id': 1}, 'R': {'input_layer': 4, 'oper_id': 4}}}\n",
"reduction_cell: {2: {'L': {'input_layer': 0, 'oper_id': 2}, 'R': {'input_layer': 1, 'oper_id': 4}}, 3: {'L': {'input_layer': 1, 'oper_id': 0}, 'R': {'input_layer': 0, 'oper_id': 2}}, 4: {'L': {'input_layer': 1, 'oper_id': 3}, 'R': {'input_layer': 3, 'oper_id': 3}}, 5: {'L': {'input_layer': 2, 'oper_id': 1}, 'R': {'input_layer': 3, 'oper_id': 4}}}\n",
"Epoch 1/1\n",
"12665/12665 [==============================] - 46s 4ms/step - loss: 2.6787 - acc: 0.8779\n",
"keeping weight: conv2d_normal_1_2_1_28x28x1_28x28x64_128\n",
"keeping weight: conv2d_normal_1_2_0_28x28x1_28x28x64_128\n",
"keeping weight: sepconv3x3_sepconv2d_normal_1_4_0_28x28x1_28x28x64_137\n",
"keeping weight: sepconv3x3_sepconv2d_normal_1_4_0_28x28x1_28x28x64_137\n",
"keeping weight: sepconv3x3_sepconv2d_normal_1_3_1_28x28x1_28x28x64_137\n",
"keeping weight: sepconv5x5_sepconv2d_normal_1_5_2_28x28x64_28x28x64_5760\n",
"keeping weight: sepconv3x3_sepconv2d_normal_1_4_0_28x28x64_28x28x64_4736\n",
"keeping weight: sepconv3x3_sepconv2d_normal_1_4_0_28x28x64_28x28x64_4736\n",
"keeping weight: sepconv3x3_sepconv2d_normal_1_3_1_28x28x64_28x28x64_4736\n",
"keeping weight: sepconv5x5_sepconv2d_normal_1_5_2_28x28x64_28x28x64_5760\n",
"keeping weight: conv2d_reduction_1_2_2_14x14x128_14x14x64_8256\n",
"keeping weight: conv2d_reduction_1_2_2_14x14x128_14x14x64_8256\n",
"keeping weight: conv2d_reduction_1_2_0_14x14x1_14x14x128_256\n",
"keeping weight: sepconv3x3_sepconv2d_reduction_1_3_1_28x28x128_14x14x128_17664\n",
"keeping weight: sepconv5x5_sepconv2d_reduction_1_5_2_14x14x128_14x14x128_19712\n",
"keeping weight: sepconv3x3_sepconv2d_reduction_1_3_1_14x14x128_14x14x128_17664\n",
"keeping weight: conv2d_reduction_1_3_0_14x14x1_14x14x128_256\n",
"keeping weight: sepconv5x5_sepconv2d_reduction_1_5_2_14x14x128_14x14x128_19712\n",
"keeping weight: sepconv3x3_sepconv2d_normal_2_4_0_28x28x128_28x28x128_17664\n",
"keeping weight: sepconv3x3_sepconv2d_normal_2_4_0_28x28x128_28x28x128_17664\n",
"keeping weight: conv2d_normal_2_2_2_14x14x128_14x14x64_8256\n",
"keeping weight: conv2d_normal_2_2_2_14x14x128_14x14x64_8256\n",
"keeping weight: sepconv3x3_sepconv2d_normal_2_4_0_28x28x128_28x28x128_17664\n",
"keeping weight: sepconv3x3_sepconv2d_normal_2_4_0_28x28x128_28x28x128_17664\n",
"keeping weight: conv2d_normal_2_2_1_14x14x256_14x14x128_32896\n",
"keeping weight: sepconv3x3_sepconv2d_normal_2_3_1_14x14x256_14x14x128_35200\n",
"keeping weight: sepconv5x5_sepconv2d_normal_2_5_2_14x14x128_14x14x128_19712\n",
"keeping weight: conv2d_normal_2_5_5_14x14x128_14x14x64_8256\n",
"keeping weight: conv2d_normal_2_5_5_14x14x128_14x14x64_8256\n",
"keeping weight: sepconv3x3_sepconv2d_normal_2_3_1_14x14x128_14x14x128_17664\n",
"keeping weight: sepconv5x5_sepconv2d_normal_2_5_2_14x14x128_14x14x128_19712\n",
"keeping weight: conv2d_reduction_2_2_2_7x7x256_7x7x128_32896\n",
"keeping weight: conv2d_reduction_2_2_2_7x7x256_7x7x128_32896\n",
"keeping weight: sepconv3x3_sepconv2d_reduction_2_3_1_14x14x256_7x7x256_68096\n",
"keeping weight: sepconv5x5_sepconv2d_reduction_2_5_2_7x7x256_7x7x256_72192\n",
"keeping weight: sepconv3x3_sepconv2d_reduction_2_3_1_7x7x256_7x7x256_68096\n",
"keeping weight: sepconv5x5_sepconv2d_reduction_2_5_2_7x7x256_7x7x256_72192\n",
"keeping weight: dense_fixed_999_999_999_512_2_1026\n",
"128/128 [==============================] - 1s 11ms/step\n",
"[5.546589970588684, 0.65625]\n",
"epoch: 1\n",
"record: \n",
"normal_cell: {2: {'L': {'input_layer': 1, 'oper_id': 4}, 'R': {'input_layer': 0, 'oper_id': 4}}, 3: {'L': {'input_layer': 2, 'oper_id': 2}, 'R': {'input_layer': 1, 'oper_id': 0}}, 4: {'L': {'input_layer': 0, 'oper_id': 0}, 'R': {'input_layer': 0, 'oper_id': 0}}, 5: {'L': {'input_layer': 2, 'oper_id': 1}, 'R': {'input_layer': 4, 'oper_id': 4}}}\n",
"reduction_cell: {2: {'L': {'input_layer': 0, 'oper_id': 2}, 'R': {'input_layer': 1, 'oper_id': 4}}, 3: {'L': {'input_layer': 1, 'oper_id': 0}, 'R': {'input_layer': 0, 'oper_id': 2}}, 4: {'L': {'input_layer': 1, 'oper_id': 3}, 'R': {'input_layer': 3, 'oper_id': 3}}, 5: {'L': {'input_layer': 2, 'oper_id': 1}, 'R': {'input_layer': 3, 'oper_id': 4}}}\n",
"val_loss: 5.546589970588684\n",
"reward: 0.65625\n",
"---------- train controller rnn ----------\n",
"---------- training normalcontroller ----------\n",
"---------- training reductioncontroller ----------\n"
]
}
],
"outputs": [],
"source": [
"ENAS.search_neural_architecture()"
]
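On the notebook side, the commit only resets each code cell's execution_count to null and empties its outputs list. If one wanted to apply the same cleanup programmatically rather than by re-saving the notebook by hand (the commit does not say how it was done), a hedged sketch using the nbformat library could look like this:

import nbformat

# Notebook changed in this commit; any .ipynb path works the same way.
path = "ENAS_Keras_MNIST.ipynb"

nb = nbformat.read(path, as_version=4)
for cell in nb.cells:
    if cell.cell_type == "code":
        cell.outputs = []            # drops the stored stdout/stderr blocks
        cell.execution_count = None  # serialized as null in the JSON
nbformat.write(nb, path)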
