Skip to content

Commit a49f6c2

Browse files
zsdonghaoluomai
authored and committed
Add binary neural networks. (tensorlayer#418)
* add BinaryDenseLayer SignLayer etc * add example of binarynet cnn | add BinaryConv2d * rename scale layer\ * remove unused code * remove print params * rename function name in binarynet example * update all * rename sign act name * rename function * fix codacy; * rename sign * improve docs for sign * yapf
1 parent e7c2eda commit a49f6c2

File tree

5 files changed

+384
-3
lines changed

5 files changed

+384
-3
lines changed

docs/modules/activation.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ Swish
4949
------------
5050
.. autofunction:: swish
5151

52-
Differentiable Sign
52+
Sign
5353
---------------------
5454
.. autofunction:: sign
5555

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
#! /usr/bin/python
2+
# -*- coding: utf-8 -*-
3+
4+
import time
5+
import tensorflow as tf
6+
import tensorlayer as tl
7+
8+
# Load MNIST, requesting 4-D arrays of shape (-1, 28, 28, 1) for the conv layers.
(X_train, y_train, X_val, y_val, X_test, y_test) = tl.files.load_mnist_dataset(
    shape=(-1, 28, 28, 1)
)

# Interactive session used for every `sess.run` call in this script.
sess = tf.InteractiveSession()

# The placeholders below have a fixed leading dimension, so every feed must
# contain exactly `batch_size` examples.
batch_size = 128

x = tf.placeholder(dtype=tf.float32, shape=[batch_size, 28, 28, 1])
y_ = tf.placeholder(dtype=tf.int64, shape=[batch_size])
17+
18+
19+
def model(x, is_train=True, reuse=False):
    """Build the binarized CNN used in this example.

    The network is: input -> BinaryConv2d(32) -> max-pool -> batch-norm ->
    sign -> BinaryConv2d(64) -> max-pool -> sign -> flatten -> dropout ->
    BinaryDenseLayer(256) -> dropout -> BinaryDenseLayer(10).

    Parameters
    ----------
    x : placeholder or tensor
        Network input, wrapped by ``tl.layers.InputLayer``.
    is_train : bool
        Forwarded to ``BatchNormLayer`` and ``DropoutLayer``.
    reuse : bool
        Forwarded to ``tf.variable_scope("binarynet", ...)``; pass ``True``
        to build a second network that shares the first one's variables.

    Returns
    -------
    The last layer of the network (``'bout'``, 10 units).
    """
    with tf.variable_scope("binarynet", reuse=reuse):
        net = tl.layers.InputLayer(x, name='input')
        net = tl.layers.BinaryConv2d(net, 32, (5, 5), (1, 1), padding='SAME', name='bcnn1')
        net = tl.layers.MaxPool2d(net, (2, 2), (2, 2), padding='SAME', name='pool1')

        net = tl.layers.BatchNormLayer(net, is_train=is_train, name='bn')
        # Each layer in the scope needs its own name; this one used to be
        # 'sign2' as well, clashing with the sign layer after 'pool2'.
        net = tl.layers.SignLayer(net, name='sign1')
        net = tl.layers.BinaryConv2d(net, 64, (5, 5), (1, 1), padding='SAME', name='bcnn2')
        net = tl.layers.MaxPool2d(net, (2, 2), (2, 2), padding='SAME', name='pool2')

        net = tl.layers.SignLayer(net, name='sign2')
        net = tl.layers.FlattenLayer(net, name='flatten')
        net = tl.layers.DropoutLayer(net, 0.5, True, is_train, name='drop1')
        # net = tl.layers.DenseLayer(net, 256, act=tf.nn.relu, name='dense')
        net = tl.layers.BinaryDenseLayer(net, 256, name='dense')
        net = tl.layers.DropoutLayer(net, 0.5, True, is_train, name='drop2')
        # net = tl.layers.DenseLayer(net, 10, act=tf.identity, name='output')
        net = tl.layers.BinaryDenseLayer(net, 10, name='bout')
        # net = tl.layers.ScaleLayer(net, name='scale')
    return net
40+
41+
42+
# define inferences: the test network reuses the training network's variables
net_train = model(x, is_train=True, reuse=False)
net_test = model(x, is_train=False, reuse=True)

# cost for training
y = net_train.outputs
cost = tl.cost.cross_entropy(y, y_, name='xentropy')

# cost and accuracy for evaluation
y2 = net_test.outputs
cost_test = tl.cost.cross_entropy(y2, y_, name='xentropy2')
correct_prediction = tf.equal(tf.argmax(y2, 1), y_)
acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# define the optimizer
train_params = tl.layers.get_variables_with_name('binarynet', True, True)
train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(cost, var_list=train_params)

# initialize all variables in the session
tl.layers.initialize_global_variables(sess)

net_train.print_params()
net_train.print_layers()

n_epoch = 200
print_freq = 5

# print(sess.run(net_test.all_params)) # print real value of parameters


def _evaluate(inputs, targets, label):
    """Print the mean `cost_test` and `acc` over minibatches of a dataset.

    Parameters
    ----------
    inputs, targets
        Dataset arrays handed to ``tl.iterate.minibatches``.
    label : str
        Split name used in the printed lines (e.g. ``"train"``).
    """
    total_loss, total_acc, n_batch = 0, 0, 0
    for X_batch, y_batch in tl.iterate.minibatches(inputs, targets, batch_size, shuffle=True):
        err, ac = sess.run([cost_test, acc], feed_dict={x: X_batch, y_: y_batch})
        total_loss += err
        total_acc += ac
        n_batch += 1
    print(" %s loss: %f" % (label, total_loss / n_batch))
    print(" %s acc: %f" % (label, total_acc / n_batch))


for epoch in range(n_epoch):
    start_time = time.time()
    for X_train_a, y_train_a in tl.iterate.minibatches(X_train, y_train, batch_size, shuffle=True):
        sess.run(train_op, feed_dict={x: X_train_a, y_: y_train_a})

    # Report on the first epoch and then every `print_freq` epochs.
    if epoch + 1 == 1 or (epoch + 1) % print_freq == 0:
        print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time))
        _evaluate(X_train, y_train, "train")
        _evaluate(X_val, y_val, "val")

print('Evaluation')
_evaluate(X_test, y_test, "test")

tensorlayer/activation.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,9 @@ def _sign_grad(unused_op, grad):
123123

124124

125125
def sign(x): # https://github.com/AngusG/tensorflow-xnor-bnn/blob/master/models/binary_net.py#L36
126-
"""Differentiable sign function by clipping linear gradient into [-1, 1], usually be used for quantizing value in binary network, see `tf.sign <https://www.tensorflow.org/api_docs/python/tf/sign>`__.
126+
"""Sign function.
127+
128+
Clip and binarize a tensor using the straight-through estimator (STE) for the gradient; usually used for quantizing values in `Binarized Neural Networks <https://arxiv.org/abs/1602.02830>`__.
127129
128130
Parameters
129131
----------
@@ -141,7 +143,7 @@ def sign(x): # https://github.com/AngusG/tensorflow-xnor-bnn/blob/master/models
141143
142144
"""
143145
with tf.get_default_graph().gradient_override_map({"sign": "QuantizeGrad"}):
144-
return tf.sign(x, name='tl_sign')
146+
return tf.sign(x, name='sign')
145147

146148

147149
# if tf.__version__ > "1.7":

tensorlayer/layers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from .core import *
1111
from .convolution import *
12+
from .binary import *
1213
from .super_resolution import *
1314
from .normalization import *
1415
from .spatial_transformer import *

0 commit comments

Comments
 (0)