Skip to content

Commit 4aa66e3

Browse files
XJTUWYDzsdonghao
authored andcommitted
🚀 Tenary Weight and DoReFa-Net in TensorFlow (TensorLayer) (tensorlayer#440)
* compress add four apis TenaryConv2d,TenaryDenseLayer,DorefaConv2d,DorefaDenselyLayer and build different tutorials for bnn,twn,dorefa based on mnist and cifar10 datasets * four apis four apis * fiexed some bugs of format fiexed some bugs of format * add bitw and bita for apis add bitw and bita for apis * Add files via upload * Add files via upload * Add files via upload * do some explain about twn and dorefa * fix some issue and delete some comment fix some issue and delete some comment * add some comment * use yapf
1 parent a7e29dd commit 4aa66e3

File tree

6 files changed

+1586
-0
lines changed

6 files changed

+1586
-0
lines changed
Lines changed: 297 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,297 @@
1+
#! /usr/bin/python
2+
# -*- coding: utf-8 -*-
3+
"""
4+
5+
- 1. This model has 1,068,298 parameters and uses the Dorefa compression strategy (weight: 1 bit, activation: 1 bit);
6+
after 500 epochs of training on a GPU, an accuracy of 41.1% was reached.
7+
8+
- 2. For simplified CNN layers see "Convolutional layer (Simplified)"
9+
on the Read the Docs website.
10+
11+
- 3. Data augmentation without TFRecord see `tutorial_image_preprocess.py` !!
12+
13+
Links
14+
-------
15+
.. https://www.tensorflow.org/versions/r0.9/tutorials/deep_cnn/index.html
16+
.. https://github.com/tensorflow/tensorflow/tree/r0.9/tensorflow/models/image/cifar10
17+
18+
Note
19+
------
20+
The optimizers between official code and this code are different.
21+
22+
Description
23+
-----------
24+
The images are processed as follows:
25+
.. They are cropped to 24 x 24 pixels, centrally for evaluation or randomly for training.
26+
.. They are approximately whitened to make the model insensitive to dynamic range.
27+
28+
For training, we additionally apply a series of random distortions to
29+
artificially increase the data set size:
30+
.. Randomly flip the image from left to right.
31+
.. Randomly distort the image brightness.
32+
.. Randomly distort the image contrast.
33+
34+
Speed Up
35+
--------
36+
Reading images from disk and distorting them can use a non-trivial amount
37+
of processing time. To prevent these operations from slowing down training,
38+
we run them inside 32 separate threads (the `num_threads` value used below) which continuously fill a TensorFlow queue.
39+
40+
"""
41+
42+
import os, time
43+
import tensorflow as tf
44+
import tensorlayer as tl
45+
46+
# Checkpoint settings. NOTE: both names are assigned again further down the
# script before they are first read, so the later values are the effective ones.
model_file_name = "./model_cifar10_tfrecord.ckpt"
resume = False  # load model, resume from previous checkpoint?

## Download data, and convert to TFRecord format, see ```tutorial_tfrecord.py```
X_train, y_train, X_test, y_test = tl.files.load_cifar10_dataset(
    shape=(-1, 32, 32, 3),
    plotable=False,
)

# Report what was loaded; the expected shape is noted next to each entry.
for _caption, _array in (
        ('X_train.shape', X_train),  # (50000, 32, 32, 3)
        ('y_train.shape', y_train),  # (50000,)
        ('X_test.shape', X_test),  # (10000, 32, 32, 3)
        ('y_test.shape', y_test),  # (10000,)
):
    print(_caption, _array.shape)
del _caption, _array  # keep the module namespace as it was

print('X %s y %s' % (X_test.dtype, y_test.dtype))
57+
58+
59+
def data_to_tfrecord(images, labels, filename):
    """Serialize image/label pairs into a TFRecord file.

    Every image is stored as its raw bytes (``img.tobytes()``) under the
    ``img_raw`` key, and its label as an int64 under the ``label`` key.
    If ``filename`` already exists, a notice is printed and nothing is written.

    Parameters
    ----------
    images : iterable of array-like
        Images to store; each item must provide ``tobytes()``.
    labels : indexable
        ``labels[i]`` is the class of the i-th image; it is converted with ``int``.
    filename : str
        Path of the TFRecord file to create.

    Notes
    -----
    The writer is always closed, and a file left incomplete by an error is
    removed so that a later run does not mistake it for a finished conversion.
    """
    if os.path.isfile(filename):
        print("%s exists" % filename)
        return
    print("Converting data into %s ..." % filename)

    writer = tf.python_io.TFRecordWriter(filename)
    finished = False
    try:
        for index, img in enumerate(images):
            # The bytes can be turned back into an image with, e.g.,
            #   np.frombuffer(img_raw, <dtype>).reshape([32, 32, 3])
            # (the reader in this file assumes float32 -- confirm the dtype
            # of ``images`` matches).
            img_raw = img.tobytes()
            label = int(labels[index])
            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
                        'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
                    }))
            writer.write(example.SerializeToString())  # Serialize To String
        finished = True
    finally:
        writer.close()
        # A truncated file would be silently reused by the "exists" check above.
        if not finished and os.path.isfile(filename):
            os.remove(filename)
86+
87+
88+
def _standardize_image(img):
    """Return ``img`` standardized per image with the op this TF version offers."""
    standardize = getattr(tf.image, 'per_image_standardization', None)  # TF 0.12+
    if standardize is None:  # earlier TF versions used the old name
        standardize = tf.image.per_image_whitening
    return standardize(img)


def read_and_decode(filename, is_train=None):
    """Return tensors that read one (image, label) example from a TFRecord file.

    Parameters
    ----------
    filename : str
        TFRecord file containing ``label`` (int64) and ``img_raw`` (bytes)
        features.
    is_train : bool or None
        ``True``  -- random 24x24 crop, random horizontal flip, random
        brightness and contrast, then per-image standardization.
        ``False`` -- central 24x24 crop/pad, then per-image standardization.
        ``None``  -- the decoded 32x32x3 image is returned untouched.

    Returns
    -------
    (img, label)
        ``img`` is the decoded (and optionally distorted) image tensor and
        ``label`` the class index cast to ``tf.int32``.
    """
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example, features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string),
        })
    # The raw bytes are interpreted as float32 values; this must match the
    # dtype of the arrays that were serialized.
    img = tf.decode_raw(features['img_raw'], tf.float32)
    img = tf.reshape(img, [32, 32, 3])

    if is_train is None:
        pass  # no preprocessing requested
    elif is_train:
        # You can do more image distortion here for training data
        # 1. Randomly crop a [height, width] section of the image.
        img = tf.random_crop(img, [24, 24, 3])
        # 2. Randomly flip the image horizontally.
        img = tf.image.random_flip_left_right(img)
        # 3. Randomly change brightness.
        img = tf.image.random_brightness(img, max_delta=63)
        # 4. Randomly change contrast.
        img = tf.image.random_contrast(img, lower=0.2, upper=1.8)
        # 5. Standardize the pixels of the image.
        img = _standardize_image(img)
    else:
        # 1. Crop the central [height, width] of the image.
        img = tf.image.resize_image_with_crop_or_pad(img, 24, 24)
        # 2. Standardize the pixels of the image.
        img = _standardize_image(img)

    label = tf.cast(features['label'], tf.int32)
    return img, label
130+
131+
132+
## Save data into TFRecord files
133+
## Save data into TFRecord files
data_to_tfrecord(
    images=X_train,
    labels=y_train,
    filename="train.cifar10",
)
data_to_tfrecord(
    images=X_test,
    labels=y_test,
    filename="test.cifar10",
)

batch_size = 128
model_file_name = "./model_cifar10_advanced.ckpt"
resume = False  # load model, resume from previous checkpoint?

with tf.device('/cpu:0'):
    # Soft placement lets TensorFlow fall back to another device when an op
    # cannot run on the requested one.
    session_config = tf.ConfigProto(allow_soft_placement=True)
    sess = tf.Session(config=session_config)

    # prepare data in cpu
    x_train_, y_train_ = read_and_decode("train.cifar10", True)
    x_test_, y_test_ = read_and_decode("test.cifar10", False)

    # Training batches are shuffled; set the number of reader threads here.
    x_train_batch, y_train_batch = tf.train.shuffle_batch(
        [x_train_, y_train_],
        batch_size=batch_size,
        capacity=2000,
        min_after_dequeue=1000,
        num_threads=32,
    )
    # for testing, uses batch instead of shuffle_batch
    x_test_batch, y_test_batch = tf.train.batch(
        [x_test_, y_test_],
        batch_size=batch_size,
        capacity=50000,
        num_threads=32,
    )
150+
151+
def model(x_crop, y_, reuse):
    """Build the binarized CNN with local response normalization.

    The network is created inside the ``"model"`` variable scope so that a
    second call with ``reuse=True`` shares the weights of the first call.
    For more simplified CNN APIs, check tensorlayer.org

    Parameters
    ----------
    x_crop : tensor
        Batch of input images.
    y_ : tensor
        Integer class labels (compared against an ``int32`` argmax).
    reuse : bool or None
        Passed to ``tf.variable_scope`` to control variable sharing.

    Returns
    -------
    (net, cost, acc)
        The output layer, the loss (cross entropy plus L2 on the ``relu/W``
        variables) and the batch accuracy.
    """
    conv_init = tf.truncated_normal_initializer(stddev=5e-2)
    dense_init = tf.truncated_normal_initializer(stddev=0.04)
    dense_bias_init = tf.constant_initializer(value=0.1)

    with tf.variable_scope("model", reuse=reuse):
        net = tl.layers.InputLayer(x_crop, name='input')
        net = tl.layers.Conv2d(
            net, 64, (5, 5), (1, 1), act=tf.nn.relu, padding='SAME', W_init=conv_init, name='cnn1')
        net = tl.layers.SignLayer(net)
        net = tl.layers.MaxPool2d(net, (3, 3), (2, 2), padding='SAME', name='pool1')
        net = tl.layers.LocalResponseNormLayer(
            net, depth_radius=4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm1')
        net = tl.layers.BinaryConv2d(
            net, 64, (5, 5), (1, 1), act=tf.nn.relu, padding='SAME', W_init=conv_init, name='cnn2')
        net = tl.layers.LocalResponseNormLayer(
            net, depth_radius=4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm2')
        net = tl.layers.MaxPool2d(net, (3, 3), (2, 2), padding='SAME', name='pool2')
        net = tl.layers.FlattenLayer(net, name='flatten')  # output: (batch_size, 2304)
        net = tl.layers.SignLayer(net)
        net = tl.layers.BinaryDenseLayer(
            net, n_units=384, act=tf.nn.relu, W_init=dense_init, b_init=dense_bias_init,
            name='d1relu')  # output: (batch_size, 384)
        net = tl.layers.SignLayer(net)
        net = tl.layers.BinaryDenseLayer(
            net, n_units=192, act=tf.nn.relu, W_init=dense_init, b_init=dense_bias_init,
            name='d2relu')  # output: (batch_size, 192)
        net = tl.layers.DenseLayer(
            net, n_units=10, act=tf.identity, W_init=dense_init, name='output')  # output: (batch_size, 10)
        logits = net.outputs

        cross_entropy = tl.cost.cross_entropy(logits, y_, name='cost')
        # L2 for the MLP, without this, the accuracy will be reduced by 15%.
        weight_decay = sum(
            tf.contrib.layers.l2_regularizer(0.004)(w)
            for w in tl.layers.get_variables_with_name('relu/W', True, True))
        cost = cross_entropy + weight_decay

        # argmax yields int64, so cast before comparing with the labels.
        predicted = tf.cast(tf.argmax(logits, 1), tf.int32)
        hits = tf.equal(predicted, y_)
        acc = tf.reduce_mean(tf.cast(hits, tf.float32))

        return net, cost, acc
185+
186+
def model_batch_norm(x_crop, y_, reuse, is_train):
    """Build the full-precision CNN variant that uses batch normalization.

    Batch normalization should be placed before the rectifier, so the
    convolutions carry no activation or bias and ReLU is applied by the
    ``BatchNormLayer`` that follows each of them.

    Parameters
    ----------
    x_crop : tensor
        Batch of input images.
    y_ : tensor
        Integer class labels (compared against an ``int32`` argmax).
    reuse : bool or None
        Passed to ``tf.variable_scope("model", ...)`` to control variable sharing.
    is_train : bool
        Forwarded to ``BatchNormLayer`` to select training or inference mode.

    Returns
    -------
    (net, cost, acc)
        The output layer, the loss (cross entropy plus L2 on the ``relu/W``
        variables) and the batch accuracy.
    """
    W_init = tf.truncated_normal_initializer(stddev=5e-2)
    W_init2 = tf.truncated_normal_initializer(stddev=0.04)
    b_init2 = tf.constant_initializer(value=0.1)
    with tf.variable_scope("model", reuse=reuse):
        # ``InputLayer`` is only reachable through ``tl.layers``; the bare name
        # is not imported in this file and raised NameError.
        net = tl.layers.InputLayer(x_crop, name='input')

        net = tl.layers.Conv2d(net, 64, (5, 5), (1, 1), padding='SAME', W_init=W_init, b_init=None, name='cnn1')
        net = tl.layers.BatchNormLayer(net, is_train, act=tf.nn.relu, name='batch1')
        net = tl.layers.MaxPool2d(net, (3, 3), (2, 2), padding='SAME', name='pool1')
        net = tl.layers.Conv2d(net, 64, (5, 5), (1, 1), padding='SAME', W_init=W_init, b_init=None, name='cnn2')
        net = tl.layers.BatchNormLayer(net, is_train, act=tf.nn.relu, name='batch2')
        net = tl.layers.MaxPool2d(net, (3, 3), (2, 2), padding='SAME', name='pool2')
        net = tl.layers.FlattenLayer(net, name='flatten')  # output: (batch_size, 2304)
        net = tl.layers.DenseLayer(
            net, n_units=384, act=tf.nn.relu, W_init=W_init2, b_init=b_init2, name='d1relu')  # output: (batch_size, 384)
        net = tl.layers.DenseLayer(
            net, n_units=192, act=tf.nn.relu, W_init=W_init2, b_init=b_init2, name='d2relu')  # output: (batch_size, 192)
        net = tl.layers.DenseLayer(
            net, n_units=10, act=tf.identity, W_init=W_init2, name='output')  # output: (batch_size, 10)
        y = net.outputs

        ce = tl.cost.cross_entropy(y, y_, name='cost')
        # L2 for the MLP, without this, the accuracy will be reduced by 15%.
        L2 = 0
        for p in tl.layers.get_variables_with_name('relu/W', True, True):
            L2 += tf.contrib.layers.l2_regularizer(0.004)(p)
        cost = ce + L2

        # argmax yields int64, so cast before comparing with the labels.
        correct_prediction = tf.equal(tf.cast(tf.argmax(y, 1), tf.int32), y_)
        acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        return net, cost, acc
217+
218+
## You can also use placeholder to feed_dict in data after using
219+
## val, l = sess.run([x_train_batch, y_train_batch]) to get the data
220+
# x_crop = tf.placeholder(tf.float32, shape=[batch_size, 24, 24, 3])
221+
# y_ = tf.placeholder(tf.int32, shape=[batch_size,])
222+
# cost, acc, network = model(x_crop, y_, None)
223+
224+
# Build the training graph first, then an evaluation graph that shares its
# weights (reuse=True) but reads from the test queue.
with tf.device('/gpu:0'):  # <-- remove it if you don't have GPU
    ## using local response normalization
    network, cost, acc = model(x_train_batch, y_train_batch, reuse=False)
    _, cost_test, acc_test = model(x_test_batch, y_test_batch, reuse=True)
    ## you may want to try batch normalization
    # network, cost, acc, = model_batch_norm(x_train_batch, y_train_batch, None, is_train=True)
    # _, cost_test, acc_test = model_batch_norm(x_test_batch, y_test_batch, True, is_train=False)
231+
232+
## train
n_epoch = 50000
learning_rate = 0.0001
print_freq = 1
# Only whole batches are counted; a trailing partial batch is not part of an epoch.
n_step_epoch = len(y_train) // batch_size
n_step = n_epoch * n_step_epoch

with tf.device('/gpu:0'):  # <-- remove it if you don't have GPU
    optimizer = tf.train.AdamOptimizer(learning_rate)
    train_op = optimizer.minimize(cost)

# Initialize every variable, then optionally overwrite them from a checkpoint.
tl.layers.initialize_global_variables(sess)
if resume:
    print("Load existing model " + "!" * 10)
    saver = tf.train.Saver()
    saver.restore(sess, model_file_name)

network.print_params(False)
network.print_layers()

print(' learning_rate: %f' % learning_rate)
print(' batch_size: %d' % batch_size)
print(' n_epoch: %d, step in an epoch: %d, total n_step: %d' % (n_epoch, n_step_epoch, n_step))
254+
255+
# Start the threads that fill the input queues.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

# Build the Saver once: constructing a new one at every checkpoint keeps
# adding save/restore ops to the graph.
saver = tf.train.Saver()

step = 0
try:
    for epoch in range(n_epoch):
        start_time = time.time()
        train_loss, train_acc, n_batch = 0, 0, 0
        for s in range(n_step_epoch):
            ## You can also use placeholder to feed_dict in data after using
            # val, l = sess.run([x_train_batch, y_train_batch])
            # tl.visualize.images2d(val, second=3, saveable=False, name='batch', dtype=np.uint8, fig_idx=2020121)
            # err, ac, _ = sess.run([cost, acc, train_op], feed_dict={x_crop: val, y_: l})
            err, ac, _ = sess.run([cost, acc, train_op])
            step += 1
            train_loss += err
            train_acc += ac
            n_batch += 1

        if epoch + 1 == 1 or (epoch + 1) % print_freq == 0:
            # ``step`` already points past this epoch, so the range that was
            # just processed is [step - n_step_epoch, step].
            print("Epoch %d : Step %d-%d of %d took %fs" %
                  (epoch, step - n_step_epoch, step, n_step, time.time() - start_time))
            print(" train loss: %f" % (train_loss / n_batch))
            print(" train acc: %f" % (train_acc / n_batch))

            # Evaluate on whole test batches only.
            test_loss, test_acc, n_batch = 0, 0, 0
            for _ in range(int(len(y_test) / batch_size)):
                err, ac = sess.run([cost_test, acc_test])
                test_loss += err
                test_acc += ac
                n_batch += 1
            print(" test loss: %f" % (test_loss / n_batch))
            print(" test acc: %f" % (test_acc / n_batch))

        if (epoch + 1) % (print_freq * 50) == 0:
            print("Save model " + "!" * 10)
            save_path = saver.save(sess, model_file_name)
            # you can also save model into npz
            tl.files.save_npz(network.all_params, name='model.npz', sess=sess)
            # and restore it as follow:
            # tl.files.load_and_assign_npz(sess=sess, name='model.npz', network=network)
finally:
    # Always stop the queue threads and release the session, even when
    # training is interrupted or fails.
    coord.request_stop()
    coord.join(threads)
    sess.close()

0 commit comments

Comments
 (0)