import sys
import time
import shutil
import numpy as np

import os
import cv2

os.environ["PATH"] += ";D:/CUDA/v8.0/bin;"  # make the CUDA 8.0 DLLs visible before importing TensorFlow
import tensorflow as tf

os.environ["TF_CPP_MIN_LOG_LEVEL"] = '1'  # lower the TensorFlow logging verbosity

| 14 | +""" |
| 15 | +Source: Aymeric Damien: cnn_mnist.py |
| 16 | + https://github.com/aymericdamien/TensorFlow-Examples/ |
| 17 | +Modify: Yonv1943 2018-07-13 13:30:40 |
| 18 | +
|
| 19 | +2018-07-13 Stable, complete |
| 20 | +2018-07-13 Add TensorBoard GRAPHS HISTOGRAM |
| 21 | +2018-07-14 Add two dropout layer, lift accuracy to 99.8%>= in test_set |
| 22 | +2018-07-14 Remove accuracy from TensorFLow Calculate |
| 23 | +2018-07-14 Change to three layers network, not softmax |
| 24 | +2018-07-14 Change to one layer network |
| 25 | +""" |


class Global(object):  # global configuration (hyper-parameters and paths)
    batch_size = 500
    batch_epoch = 55000 // batch_size  # the MNIST training set holds 55000 images
    train_epoch = 2 ** 5  # test-set accuracy nearly 90%, about 15 s (Intel i3-3110M, GTX 720M)

    data_dir = 'MNIST_data'
    txt_path = 'tf_training_info.txt'

    model_save_dir = 'mnist_model'
    model_save_name = 'mnist_model'
    model_save_path = os.path.join(model_save_dir, model_save_name)


G = Global()


def get_mnist_data(data_dir='MNIST_data'):
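    """Load MNIST via the TF tutorial helper and return
    (train_image, train_label, test_image, test_label) as float32 arrays,
    with the training set truncated to a whole number of batches."""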
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets(data_dir, one_hot=True)

    train_image = mnist.train.images
    train_label = mnist.train.labels

    train_image = train_image[:G.batch_epoch * G.batch_size]
    train_label = train_label[:G.batch_epoch * G.batch_size]

    test_image = mnist.test.images
    test_label = mnist.test.labels

    data_para = (train_image, train_label, test_image, test_label)
    data_para = [np.array(ary, dtype=np.float32) for ary in data_para]
    return data_para


def init_session():
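    """Build a single-layer softmax classifier (784 -> 10) and its training op."""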
    image = tf.placeholder(tf.float32, [None, 784], name='Input')  # img: 28x28
    label = tf.placeholder(tf.float32, [None, 10], name='Label')  # 0~9 == 10 classes

    w1 = tf.get_variable(shape=[784, 10], name='Weights1')
    b1 = tf.get_variable(shape=[10], name='Bias1')

    logits = tf.matmul(image, w1) + b1
    pred = tf.nn.softmax(logits)
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=label, logits=logits))
    optimizer = tf.train.AdamOptimizer().minimize(loss)

    sess_para = (image, label, pred, loss, optimizer)
    return sess_para


def train_session(sess_para, data_para):
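    """Train the model for G.train_epoch epochs, log the average loss per epoch
    to G.txt_path, and save the final checkpoint to G.model_save_path."""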
    (train_image, train_label, test_image, test_label) = data_para
    (image, label, pred, loss, optimizer) = sess_para

    shutil.rmtree(G.model_save_dir, ignore_errors=True)
    logs = open(G.txt_path, 'a')
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    '''train loop init'''
    predict, feed_train_label = None, None
    time0 = time1 = time.time()
    print('|Train_epoch: %d |batch: epoch*size: %dx%d' % (G.train_epoch, G.batch_epoch, G.batch_size))
    for train_epoch in range(G.train_epoch):
        loss_sum = 0.0
        for i in range(G.batch_epoch):
            j = i * G.batch_size
            feed_train_image = train_image[j: j + G.batch_size]
            feed_train_label = train_label[j: j + G.batch_size]

            feed_dict = {image: feed_train_image, label: feed_train_label}
            predict, loss_batch, _ = sess.run([pred, loss, optimizer], feed_dict)

            loss_sum += loss_batch
            if i % (G.batch_epoch // 16 + 1) == 0:  # coarse progress bar
                print(end='=')
                sys.stdout.flush()

        ave_cost = loss_sum / G.batch_epoch
        logs.write('%e\n' % ave_cost)

        accuracy = np.average(np.equal(np.argmax(predict, 1), np.argmax(feed_train_label, 1)))

        time2 = time.time()
        print(end="\n|Time: %4.1f|%2d |Loss: %.2e |Inac: %.2e |"
                  % (time2 - time1, train_epoch, ave_cost, 1 - accuracy))
        time1 = time2
    print()
    print('|Time: %.2f |epoch_batch: %d_%dx%d ' % (time.time() - time0, G.train_epoch, G.batch_epoch, G.batch_size))

    '''save'''
    os.makedirs(G.model_save_dir)
    tf.train.Saver().save(sess, G.model_save_path)
    print('|model saved in:', G.model_save_path)
    draw_plot(G.txt_path)

    sess.close()
    logs.close()


def eval_session(sess_para, data_para):
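    """Restore the saved model and report accuracy on a slice of the
    training set and on the full test set."""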
    (train_image, train_label, test_image, test_label) = data_para
    (image, label, pred, loss, optimizer) = sess_para

    sess = tf.Session()
    tf.train.Saver().restore(sess, G.model_save_path)
    '''evaluation'''
    for print_info, feed_image, feed_label in [
        ['Train_set', train_image[:len(test_image)], train_label[:len(test_label)]],
        ['Test_set ', test_image, test_label],
    ]:
        feed_dict = {image: feed_image, label: feed_label}
        predicts = pred.eval(feed_dict, session=sess)
        accuracy = np.average(np.equal(np.argmax(predicts, 1), np.argmax(feed_label, 1)))
        inaccuracy = 1.0 - accuracy
        print("|%s |Accuracy: %2.4f%% |Inaccuracy: %.2e" % (print_info, accuracy * 100, inaccuracy))
    sess.close()


def real_time_session(sess_para, window_name='cv2_mouse_paint', size=16):
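    """Open an OpenCV canvas, let the user draw a digit with the mouse, and
    show the restored model's prediction on the canvas in real time.
    Backspace clears the canvas; Enter or Esc quits."""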
    (image, label, pred, loss, optimizer) = sess_para

    feed_dict = dict()
    feed_dict[image] = np.array([])

    sess = tf.Session()
    tf.train.Saver().restore(sess, G.model_save_path)

    def paint_brush(event, x, y, flags, param):  # mouse callback function
        global ix, iy, drawing

        if event == cv2.EVENT_LBUTTONDOWN:
            ix, iy = x, y
            drawing = True
        elif event == cv2.EVENT_MOUSEMOVE and 'drawing' in globals():
            cv2.line(img, (ix, iy), (x, y), 255, size)
            ix, iy = x, y

            '''hand-writing recognition'''
            cv2.rectangle(img, (0, 0), (img.shape[1], 64), 0, -1)  # clear the text strip at the top
            input_image = (np.reshape(cv2.resize(img, (28, 28)), (1, 784)) / 256.0).astype(np.float32)
            feed_dict[image] = input_image
            predicts = pred.eval(feed_dict, session=sess)
            predict = np.argsort(predicts[0])[::-1]  # classes sorted by descending probability

            cv2.putText(img, str(predict[0]), (16, 55), cv2.FONT_HERSHEY_SIMPLEX, 2.0, 255, 1, cv2.LINE_AA)
            cv2.putText(img, str(predict[1:]), (64, 48), cv2.FONT_HERSHEY_SIMPLEX, 1.0, 255, 1, cv2.LINE_AA)
        elif event == cv2.EVENT_LBUTTONUP:
            del drawing
        elif event == cv2.EVENT_RBUTTONDOWN:
            cv2.rectangle(img, (0, 0), (28 * size, 28 * size), 0, -1)  # right click clears the canvas

    img = np.zeros((28 * size, 28 * size), np.uint8)
    cv2.namedWindow(window_name)
    cv2.setMouseCallback(window_name, paint_brush)

    not_break = True
    while not_break:
        cv2.imshow(window_name, img)
        k = cv2.waitKey(1) & 0xFF
        img = np.zeros((28 * size, 28 * size), np.uint8) if k == 8 else img  # clear canvas (Backspace)
        not_break = not bool(k == 13 or k == 27)  # quit (Enter or Esc)
    cv2.destroyWindow(window_name)

    sess.close()


def draw_plot(ary_path):
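    """Plot the per-epoch average loss values stored in the training log file."""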
    import matplotlib.pyplot as plt

    ary = np.loadtxt(ary_path)

    x_pts = list(range(ary.shape[0]))
    y_pts = ary
    plt.plot(x_pts, y_pts, linestyle='dashed', marker='x', markersize=3)
    plt.show()


def mouse_paint(window_name='cv2_mouse_paint', size=16):
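    """Stand-alone drawing canvas (no recognition): draw with the left mouse
    button, clear with Backspace, quit with Enter or Esc; returns the image."""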
    def paint_brush(event, x, y, flags, param):  # mouse callback function
        global ix, iy, drawing

        if event == cv2.EVENT_LBUTTONDOWN:
            ix, iy = x, y
            drawing = True
        elif event == cv2.EVENT_MOUSEMOVE and 'drawing' in globals():
            # 'var_name' in globals(); learned from: https://stackoverflow.com/a/1592581/9293137
            cv2.line(img, (ix, iy), (x, y), 255, size)
            ix, iy = x, y
        elif event == cv2.EVENT_LBUTTONUP:
            del drawing

    img = np.zeros((28 * size, 28 * size), np.uint8)
    cv2.namedWindow(window_name)
    cv2.setMouseCallback(window_name, paint_brush)

    not_break = True
    while not_break:
        cv2.imshow(window_name, img)
        k = cv2.waitKey(1) & 0xFF
        img = np.zeros((28 * size, 28 * size), np.uint8) if k == 8 else img  # clear canvas (Backspace)
        not_break = not bool(k == 13 or k == 27)  # quit (Enter or Esc)
    cv2.destroyWindow(window_name)
    return img


def run():
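    # For a first run, uncomment the data loading and training lines below;
    # real_time_session only needs the checkpoint saved by train_session.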
    # data_para = get_mnist_data(G.data_dir)
    sess_para = init_session()

    # train_session(sess_para, data_para)
    # eval_session(sess_para, data_para)
    real_time_session(sess_para)


if __name__ == '__main__':
    run()