"""In progress.
Parag K. Mital, Jan 2016.
"""
# %%
import tensorflow as tf
from libs.connections import conv2d, linear
from collections import namedtuple
from math import sqrt
# %%
def residual_network(x, n_outputs,
activation=tf.nn.relu):
"""Builds a residual network.

    Parameters
    ----------
x : Placeholder
Input to the network
    n_outputs : int
        Number of outputs of the final softmax layer
    activation : callable, optional
        Nonlinearity to apply after each convolution

    Returns
    -------
    net : Tensor
        Output tensor of the final softmax layer

Raises
------
ValueError
        If a 2D Tensor is input whose second dimension is not a perfect
        square, since it then cannot be reshaped into a square 4D Tensor.
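
    Examples
    --------
    A minimal usage sketch (mirrors test_mnist below), assuming a
    flattened 28x28 input:

    >>> x = tf.placeholder(tf.float32, [None, 784])
    >>> y_pred = residual_network(x, 10)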
"""
# %%
LayerBlock = namedtuple(
'LayerBlock', ['num_repeats', 'num_filters', 'bottleneck_size'])
blocks = [LayerBlock(3, 128, 32),
LayerBlock(3, 256, 64),
LayerBlock(3, 512, 128),
LayerBlock(3, 1024, 256)]
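    # Each repeat is a bottleneck (as in He et al.'s ResNet): a 1x1 conv
    # squeezes to bottleneck_size channels, a 3x3 conv operates at that
    # reduced width, and a final 1x1 conv expands back to num_filters.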
# %%
input_shape = x.get_shape().as_list()
if len(input_shape) == 2:
ndim = int(sqrt(input_shape[1]))
if ndim * ndim != input_shape[1]:
raise ValueError('input_shape should be square')
x = tf.reshape(x, [-1, ndim, ndim, 1])
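        # e.g. a flat [None, 784] MNIST batch becomes [None, 28, 28, 1]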
# %%
# First convolution expands to 64 channels and downsamples
net = conv2d(x, 64, k_h=7, k_w=7,
name='conv1',
activation=activation)
# %%
    # 3x3 max pooling with stride 2 downsamples again
net = tf.nn.max_pool(
net, [1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')
# %%
    # Project up to the first block's width so the first residual
    # addition has matching channel counts
net = conv2d(net, blocks[0].num_filters, k_h=1, k_w=1,
stride_h=1, stride_w=1, padding='VALID', name='conv2')
# %%
# Loop through all res blocks
for block_i, block in enumerate(blocks):
for repeat_i in range(block.num_repeats):
name = 'block_%d/repeat_%d' % (block_i, repeat_i)
conv = conv2d(net, block.bottleneck_size, k_h=1, k_w=1,
padding='VALID', stride_h=1, stride_w=1,
activation=activation,
name=name + '/conv_in')
conv = conv2d(conv, block.bottleneck_size, k_h=3, k_w=3,
padding='SAME', stride_h=1, stride_w=1,
activation=activation,
name=name + '/conv_bottleneck')
conv = conv2d(conv, block.num_filters, k_h=1, k_w=1,
padding='VALID', stride_h=1, stride_w=1,
activation=activation,
name=name + '/conv_out')
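            # Identity shortcut: add the block's input back to its output;
            # shapes match because conv_out restores num_filters channels.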
net = conv + net
try:
            # Upscale to the next block's width with a 1x1 projection;
            # the last block has no successor, so the IndexError skips it.
next_block = blocks[block_i + 1]
net = conv2d(net, next_block.num_filters, k_h=1, k_w=1,
padding='SAME', stride_h=1, stride_w=1, bias=False,
name='block_%d/conv_upscale' % block_i)
except IndexError:
pass
# %%
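    # Global average pooling: average each feature map over its full
    # spatial extent, leaving a [batch, 1, 1, channels] tensor.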
net = tf.nn.avg_pool(net,
ksize=[1, net.get_shape().as_list()[1],
net.get_shape().as_list()[2], 1],
strides=[1, 1, 1, 1], padding='VALID')
net = tf.reshape(
net,
[-1, net.get_shape().as_list()[1] *
net.get_shape().as_list()[2] *
net.get_shape().as_list()[3]])
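    # The flattened [batch, channels] features feed a final fully
    # connected layer with a softmax over the n_outputs classes.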
net = linear(net, n_outputs, activation=tf.nn.softmax)
# %%
return net


def test_mnist():
"""Test the resnet on MNIST."""
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
y_pred = residual_network(x, 10)
# %% Define loss/eval/training functions
    # Clip y_pred so tf.log never sees 0 (which would produce NaNs)
    cross_entropy = -tf.reduce_sum(
        y * tf.log(tf.clip_by_value(y_pred, 1e-10, 1.0)))
optimizer = tf.train.AdamOptimizer().minimize(cross_entropy)
# %% Monitor accuracy
correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
    # %% We now create a new session to actually perform the
    # initialization of the variables:
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# %% We'll train in minibatches and report accuracy:
batch_size = 50
n_epochs = 5
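    # Note: this is a fairly deep network, so 5 epochs may be slow on CPU.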
for epoch_i in range(n_epochs):
# Training
train_accuracy = 0
for batch_i in range(mnist.train.num_examples // batch_size):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
train_accuracy += sess.run([optimizer, accuracy], feed_dict={
x: batch_xs, y: batch_ys})[1]
train_accuracy /= (mnist.train.num_examples // batch_size)
# Validation
valid_accuracy = 0
for batch_i in range(mnist.validation.num_examples // batch_size):
batch_xs, batch_ys = mnist.validation.next_batch(batch_size)
valid_accuracy += sess.run(accuracy,
feed_dict={
x: batch_xs,
y: batch_ys
})
valid_accuracy /= (mnist.validation.num_examples // batch_size)
print('epoch:', epoch_i, ', train:',
train_accuracy, ', valid:', valid_accuracy)


if __name__ == '__main__':
test_mnist()