forked from martin-gorner/tensorflow-mnist-tutorial
-
Notifications
You must be signed in to change notification settings - Fork 0
/
mnist_4.0_batchnorm_five_layers_sigmoid.py
212 lines (175 loc) · 9.78 KB
/
mnist_4.0_batchnorm_five_layers_sigmoid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
# encoding: UTF-8
# Copyright 2016 Google.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
import tensorflowvisu
import math
from tensorflow.examples.tutorials.mnist import input_data as mnist_data
# Log which TensorFlow release this run is using.
print("Tensorflow version " + tf.__version__)
# Fix the random seed to 0. This is done before any variable or op below is created.
# NOTE(review): presumably intended to make the random weight initialisation repeatable
# from run to run - confirm against the tf.set_random_seed documentation.
tf.set_random_seed(0)
# neural network with 5 layers
#
# · · · · · · · · · · (input data, flattened pixels) X [batch, 784] # 784 = 28*28
# \x/x\x/x\x/x\x/x\x/ -- fully connected layer (sigmoid+BN) W1 [784, 200] S1[200] O1[200]
# · · · · · · · · · Y1 [batch, 200]
# \x/x\x/x\x/x\x/ -- fully connected layer (sigmoid+BN) W2 [200, 100] S2[100] O2[100]
# · · · · · · · Y2 [batch, 100]
# \x/x\x/x\x/ -- fully connected layer (sigmoid+BN) W3 [100, 60] S3[60] O3[60]
# · · · · · Y3 [batch, 60]
# \x/x\x/ -- fully connected layer (sigmoid+BN) W4 [60, 30] S4[30] O4[30]
# · · · Y4 [batch, 30]
# \x/ -- fully connected layer (softmax, no BN) W5 [30, 10] B5[10]
# · Y5 [batch, 10]
# Download images and labels into mnist.test (10K images+labels) and mnist.train (60K images+labels)
# The dataset is cached in the local "data" directory, with one-hot encoded labels.
# NOTE(review): reshape=False presumably keeps each image as [28, 28, 1] (matching the X
# placeholder below) and validation_size=0 presumably keeps every training image in
# mnist.train - confirm against the read_data_sets documentation.
mnist = mnist_data.read_data_sets("data", one_hot=True, reshape=False, validation_size=0)
# input X: 28x28 grayscale images, the first dimension (None) will index the images in the mini-batch
X = tf.placeholder(tf.float32, [None, 28, 28, 1])
# correct answers will go here (one row of 10 values per image)
Y_ = tf.placeholder(tf.float32, [None, 10])
# variable learning rate, fed on every optimisation step by training_step
lr = tf.placeholder(tf.float32)
# train/test selector for batch normalisation: training_step feeds False when running on
# training batches and True when evaluating the test set
tst = tf.placeholder(tf.bool)
# training iteration, fed with the step index when the moving averages are updated
# NOTE(review): this module-level name shadows Python's builtin iter()
iter = tf.placeholder(tf.int32)
# five layers and their number of neurons (the last layer has 10 softmax neurons)
L = 200
M = 100
N = 60
P = 30
Q = 10
# Weights initialised with small random values between -0.2 and +0.2
# When using RELUs, make sure biases are initialised with small *positive* values for example 0.1 = tf.ones([K])/10
# For each hidden layer: Wn is the weight matrix, Sn the batch-norm scale (initialised to ones)
# and On the batch-norm offset (initialised to zeros); they are passed to batchnorm() as Scale/Offset.
# NOTE(review): a random seed is set earlier in the file; the random initial values presumably
# depend on the order in which these ops are created - confirm before reordering or restructuring.
W1 = tf.Variable(tf.truncated_normal([784, L], stddev=0.1)) # 784 = 28 * 28
S1 = tf.Variable(tf.ones([L]))
O1 = tf.Variable(tf.zeros([L]))
W2 = tf.Variable(tf.truncated_normal([L, M], stddev=0.1))
S2 = tf.Variable(tf.ones([M]))
O2 = tf.Variable(tf.zeros([M]))
W3 = tf.Variable(tf.truncated_normal([M, N], stddev=0.1))
S3 = tf.Variable(tf.ones([N]))
O3 = tf.Variable(tf.zeros([N]))
W4 = tf.Variable(tf.truncated_normal([N, P], stddev=0.1))
S4 = tf.Variable(tf.ones([P]))
O4 = tf.Variable(tf.zeros([P]))
# The output layer is not batch-normalised: it has a plain bias B5 instead of scale/offset.
W5 = tf.Variable(tf.truncated_normal([P, Q], stddev=0.1))
B5 = tf.Variable(tf.zeros([Q]))
## Batch normalisation conclusions with sigmoid activation function:
# BN is applied between logits and the activation function
# On Sigmoids it is very clear that without BN, the sigmoids saturate, with BN, they output
# a clean gaussian distribution of values, especially with high initial learning rates.
# sigmoid, no batch-norm, lr(0.003, 0.0001, 2000) => 97.5%
# sigmoid, batch-norm lr(0.03, 0.0001, 1000) => 98%
# sigmoid, batch-norm, no offsets => 97.3%
# sigmoid, batch-norm, no scales => 98.1% but cannot hold fast learning rate at start
# sigmoid, batch-norm, no scales, no offsets => 96%
# Both scales and offsets are useful with sigmoids.
# With RELUs, the scale variables can be omitted.
# Biases are not useful with batch norm, offsets are to be used instead
# Steady 98.5% accuracy using these parameters:
# moving average decay: 0.998 (equivalent to averaging over two epochs)
# learning rate decay from 0.03 to 0.0001 speed 1000 => max 98.59 at 6500 iterations, 98.54 at 10K it, 98% at 1300it, 98.5% at 3200it
def batchnorm(Ylogits, Offset, Scale, is_test, iteration, decay=0.998, epsilon=1e-5):
    """Batch-normalise a layer's logits and build the op that tracks their statistics.

    The mean and variance of ``Ylogits`` are computed over axis 0 (the mini-batch
    axis for the callers in this file). Depending on ``is_test`` the normalisation
    uses either those batch statistics or their exponential moving averages.

    Args:
        Ylogits: pre-activation tensor to normalise.
        Offset: offset passed to ``tf.nn.batch_normalization``.
        Scale: scale passed to ``tf.nn.batch_normalization``.
        is_test: boolean tensor; when true the moving averages are used, otherwise
            the statistics of the current batch.
        iteration: training iteration, handed to the moving average as its update
            count.
        decay: decay of the exponential moving average. Defaults to the previously
            hard-coded 0.998.
        epsilon: variance epsilon passed to ``tf.nn.batch_normalization``. Defaults
            to the previously hard-coded 1e-5.

    Returns:
        A tuple ``(Ybn, update_moving_averages)``: the normalised tensor and the op
        that updates the moving averages of mean and variance. The update op is not
        run automatically; the caller has to run it.
    """
    # adding the iteration prevents from averaging across non-existing iterations
    exp_moving_avg = tf.train.ExponentialMovingAverage(decay, iteration)
    mean, variance = tf.nn.moments(Ylogits, [0])
    update_moving_averages = exp_moving_avg.apply([mean, variance])
    # Select the statistics to normalise with: moving averages at test time,
    # current-batch values during training.
    m = tf.cond(is_test, lambda: exp_moving_avg.average(mean), lambda: mean)
    v = tf.cond(is_test, lambda: exp_moving_avg.average(variance), lambda: variance)
    Ybn = tf.nn.batch_normalization(Ylogits, m, v, Offset, Scale, epsilon)
    return Ybn, update_moving_averages
def no_batchnorm(Ylogits, Offset, Scale, is_test, iteration):
    """Stand-in with the same signature as batchnorm that performs no normalisation.

    Returns the logits untouched together with a no-op in place of the
    moving-average update op. All other arguments are ignored.
    """
    nothing_to_update = tf.no_op()
    return Ylogits, nothing_to_update
# The model
# Flatten every 28x28x1 image into a row of 784 values.
XX = tf.reshape(X, [-1, 784])


def _sigmoid_bn_layer(inputs, weights, offset, scale):
    """Build one hidden layer: matmul, then batch norm, then sigmoid.

    Returns a tuple (logits, batch-normalised logits, moving-average update op,
    sigmoid activations).
    """
    logits = tf.matmul(inputs, weights)
    normalised, ema_update = batchnorm(logits, offset, scale, tst, iter)
    return logits, normalised, ema_update, tf.nn.sigmoid(normalised)


# Four batch-normalised sigmoid layers, each one fed by the activations of the previous one.
Y1l, Y1bn, update_ema1, Y1 = _sigmoid_bn_layer(XX, W1, O1, S1)
Y2l, Y2bn, update_ema2, Y2 = _sigmoid_bn_layer(Y1, W2, O2, S2)
Y3l, Y3bn, update_ema3, Y3 = _sigmoid_bn_layer(Y2, W3, O3, S3)
Y4l, Y4bn, update_ema4, Y4 = _sigmoid_bn_layer(Y3, W4, O4, S4)

# Output layer: weights plus bias, no batch norm, softmax on top.
Ylogits = tf.matmul(Y4, W5) + B5
Y = tf.nn.softmax(Ylogits)

# One op that runs the moving-average updates of all four hidden layers.
update_ema = tf.group(update_ema1, update_ema2, update_ema3, update_ema4)
# Cross-entropy loss (= -sum(Y_i * log(Yi)) ), computed from the raw logits.
# TensorFlow's softmax_cross_entropy_with_logits is used instead of applying log() to
# the softmax output by hand, to avoid the numerical problems of log(0).
# The mean over the batch is multiplied by 100, i.e. normalised for batches of 100 images.
per_image_cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=Y_, logits=Ylogits)
cross_entropy = tf.reduce_mean(per_image_cross_entropy) * 100

# Accuracy of the model, between 0 (worst) and 1 (best): the fraction of images
# whose highest-scoring class matches the position of the maximum in the label row.
predicted_class = tf.argmax(Y, 1)
expected_class = tf.argmax(Y_, 1)
correct_prediction = tf.equal(predicted_class, expected_class)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# matplotlib visualisation
def _flatten_and_concat(tensors):
    """Reshape each tensor to a flat vector and join them into a single vector."""
    return tf.concat([tf.reshape(t, [-1]) for t in tensors], 0)


# Flat vectors of all weights, of all offsets plus the final bias, of the hidden
# activations and of the hidden logits; training_step fetches the last two and
# hands them to the visualisation's histograms.
allweights = _flatten_and_concat([W1, W2, W3, W4, W5])
allbiases = _flatten_and_concat([O1, O2, O3, O4, B5])
# to use for sigmoid
allactivations = _flatten_and_concat([Y1, Y2, Y3, Y4])
# to use for RELU
#allactivations = tf.concat([tf.reduce_max(Y1, [0]), tf.reduce_max(Y2, [0]), tf.reduce_max(Y3, [0]), tf.reduce_max(Y4, [0])], 0)
alllogits = _flatten_and_concat([Y1l, Y2l, Y3l, Y4l])

# Image tensors for the visualisation: I is fetched with training batches and It with the
# test set in training_step.
# NOTE(review): the exact meaning of the 1000 / lines=25 arguments is defined in tensorflowvisu - confirm there.
I = tensorflowvisu.tf_format_mnist_images(X, Y, Y_)
It = tensorflowvisu.tf_format_mnist_images(X, Y, Y_, 1000, lines=25)
datavis = tensorflowvisu.MnistDataVis(
    title4="Logits",
    title5="activations",
    histogram4colornum=2,
    histogram5colornum=2
)
# training step, the learning rate is a placeholder
# Adam optimiser minimising the cross-entropy; the learning rate comes from the `lr`
# placeholder, so it can be changed on every step.
train_step = tf.train.AdamOptimizer(lr).minimize(cross_entropy)
# Build the variable initialiser and run it once in a new session.
# NOTE(review): the initialiser is created after minimize(), presumably so that any
# variables added by the optimiser are initialised as well - confirm before reordering.
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
# You can call this function in a loop to train the model, 100 images at a time
def training_step(i, update_test_data, update_train_data):
    """Run one training iteration on a batch of 100 images, refreshing the visualisation on request.

    Args:
        i: iteration index; drives the learning-rate decay, appears in the log
            output and is fed to the `iter` placeholder for the moving averages.
        update_test_data: when true, evaluate on the whole test set and push the
            results to the visualisation.
        update_train_data: when true, evaluate on the current training batch and
            push the results to the visualisation.
    """
    # training on batches of 100 images with 100 labels
    batch_X, batch_Y = mnist.train.next_batch(100)

    # learning rate decay (with batch norm): exponential decay from max to min.
    # Values for running without batch norm: max 0.003, min 0.0001, decay speed 2000.
    max_learning_rate = 0.03
    min_learning_rate = 0.0001
    decay_speed = 1000.0
    learning_rate_range = max_learning_rate - min_learning_rate
    learning_rate = min_learning_rate + learning_rate_range * math.exp(-i/decay_speed)

    # compute training values for visualisation
    if update_train_data:
        train_fetches = [accuracy, cross_entropy, I, alllogits, allactivations]
        train_feed = {X: batch_X, Y_: batch_Y, tst: False}
        a, c, im, al, ac = sess.run(train_fetches, train_feed)
        print("".join([str(i), ": accuracy:", str(a), " loss: ", str(c), " (lr:", str(learning_rate), ")"]))
        datavis.append_training_curves_data(i, a, c)
        datavis.update_image1(im)
        datavis.append_data_histograms(i, al, ac)

    # compute test values for visualisation
    if update_test_data:
        test_feed = {X: mnist.test.images, Y_: mnist.test.labels, tst: True}
        a, c, im = sess.run([accuracy, cross_entropy, It], test_feed)
        epoch = i*100//mnist.train.images.shape[0] + 1
        print("".join([str(i), ": ********* epoch ", str(epoch), " ********* test accuracy:", str(a), " test loss: ", str(c)]))
        datavis.append_test_curves_data(i, a, c)
        datavis.update_image2(im)

    # the backpropagation training step, followed by a separate run that
    # updates the batch-norm moving averages on the same batch
    optimise_feed = {X: batch_X, Y_: batch_Y, lr: learning_rate, tst: False}
    sess.run(train_step, optimise_feed)
    moving_average_feed = {X: batch_X, Y_: batch_Y, tst: False, iter: i}
    sess.run(update_ema, moving_average_feed)
# Drive the training loop through the visualisation, which calls training_step.
datavis.animate(
    training_step,
    iterations=10000+1,
    train_data_update_freq=20,
    test_data_update_freq=100,
    more_tests_at_start=True
)
# to save the animation as a movie, add save_movie=True as an argument to datavis.animate
# to disable the visualisation use the following line instead of the datavis.animate line
# for i in range(10000+1): training_step(i, i % 100 == 0, i % 20 == 0)

# Report the maximum test accuracy recorded by the visualisation object.
best_test_accuracy = datavis.get_max_test_accuracy()
print("max test accuracy: " + str(best_test_accuracy))
# Some results to expect:
# (In all runs, if sigmoids are used, all biases are initialised at 0, if RELUs are used,
# all biases are initialised at 0.1 apart from the last one which is initialised at 0.)
## decaying learning rate from 0.003 to 0.0001 decay_speed 2000, 10K iterations
# final test accuracy = 0.9813 (sigmoid - training cross-entropy not stabilised)
# final test accuracy = 0.9842 (relu - training set fully learned, test accuracy stable)