Skip to content

Commit fed0ee1

Browse files
author
Vineet John
committed
Added softmax and cnn image classification processors
1 parent e0e7ae6 commit fed0ee1

File tree

5 files changed

+143
-18
lines changed

5 files changed

+143
-18
lines changed
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import tensorflow as tf
2+
from tensorflow.examples.tutorials.mnist import input_data
3+
4+
from processors.processor import Processor
5+
from utils import log_helper
6+
7+
log = log_helper.get_logger("CNNProcessor")
8+
9+
10+
class CNNProcessor(Processor):
11+
12+
def process(self):
13+
log.info("CNNProcessor begun")
14+
15+
# Import data
16+
data_dir = "/tmp/tensorflow/mnist/input_data"
17+
mnist = input_data.read_data_sets(data_dir, one_hot=True)
18+
19+
x = tf.placeholder(tf.float32, [None, 784])
20+
y_ = tf.placeholder(tf.float32, [None, 10])
21+
22+
sess = tf.InteractiveSession()
23+
24+
W_conv1 = weight_variable([5, 5, 1, 32])
25+
b_conv1 = bias_variable([32])
26+
27+
x_image = tf.reshape(x, [-1, 28, 28, 1])
28+
29+
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
30+
h_pool1 = max_pool_2x2(h_conv1)
31+
32+
W_conv2 = weight_variable([5, 5, 32, 64])
33+
b_conv2 = bias_variable([64])
34+
35+
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
36+
h_pool2 = max_pool_2x2(h_conv2)
37+
38+
W_fc1 = weight_variable([7 * 7 * 64, 1024])
39+
b_fc1 = bias_variable([1024])
40+
41+
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
42+
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
43+
44+
keep_prob = tf.placeholder(tf.float32)
45+
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
46+
47+
W_fc2 = weight_variable([1024, 10])
48+
b_fc2 = bias_variable([10])
49+
50+
y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
51+
52+
cross_entropy = tf.reduce_mean(
53+
tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
54+
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
55+
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
56+
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
57+
sess.run(tf.global_variables_initializer())
58+
for i in range(1001):
59+
batch = mnist.train.next_batch(50)
60+
if i % 100 == 0:
61+
train_accuracy = accuracy.eval(feed_dict={
62+
x: batch[0], y_: batch[1], keep_prob: 1.0})
63+
print("step %d, training accuracy %g" % (i, train_accuracy))
64+
train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
65+
66+
print("test accuracy %g" % accuracy.eval(feed_dict={
67+
x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
68+
69+
log.info("CNNProcessor concluded")
70+
71+
def weight_variable(shape):
72+
initial = tf.truncated_normal(shape, stddev=0.1)
73+
return tf.Variable(initial)
74+
75+
def bias_variable(shape):
76+
initial = tf.constant(0.1, shape=shape)
77+
return tf.Variable(initial)
78+
79+
def conv2d(x, W):
80+
return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
81+
82+
def max_pool_2x2(x):
83+
return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
84+
strides=[1, 2, 2, 1], padding='SAME')

tensorflow-mnist/processors/mnist_processor.py

Lines changed: 0 additions & 13 deletions
This file was deleted.
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import tensorflow as tf
2+
from tensorflow.examples.tutorials.mnist import input_data
3+
4+
from processors.processor import Processor
5+
from utils import log_helper
6+
7+
log = log_helper.get_logger("CNNProcessor")
8+
9+
10+
class SoftmaxRegressionProcessor(Processor):
11+
12+
def process(self):
13+
log.info("CNNProcessor begun")
14+
15+
# Import data
16+
data_dir = "/tmp/tensorflow/mnist/input_data"
17+
mnist = input_data.read_data_sets(data_dir, one_hot=True)
18+
19+
# Create the model
20+
x = tf.placeholder(tf.float32, [None, 784])
21+
W = tf.Variable(tf.zeros([784, 10]))
22+
b = tf.Variable(tf.zeros([10]))
23+
y = tf.matmul(x, W) + b
24+
25+
# Define loss and optimizer
26+
y_ = tf.placeholder(tf.float32, [None, 10])
27+
28+
cross_entropy = tf.reduce_mean(
29+
tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
30+
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
31+
32+
sess = tf.InteractiveSession()
33+
tf.global_variables_initializer().run()
34+
# Train
35+
for _ in range(1000):
36+
batch_xs, batch_ys = mnist.train.next_batch(100)
37+
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
38+
39+
# Test trained model
40+
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
41+
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
42+
print(sess.run(accuracy, feed_dict={x: mnist.test.images,
43+
y_: mnist.test.labels}))
44+
45+
log.info("CNNProcessor concluded")
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#!/usr/bin/env bash
22

33
CODEDIR=$(dirname "$0")"/../"
4-
DATADIR=$(dirname "$0")"/../data/"
54

65
# Run Gaussian Mixture Model processor
7-
/usr/bin/python3 "$CODEDIR"/tensorflow_mnist.py
6+
/usr/bin/python3 "$CODEDIR"/tensorflow_mnist.py

tensorflow-mnist/tensorflow_mnist.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,28 @@
22

33
import sys
44

5-
from processors.mnist_processor import MnistProcessor
5+
from processors.cnn_processor import CNNProcessor
6+
from processors.softmax_regression_processor import SoftmaxRegressionProcessor
67
from utils.options import Options
78

89

910
def main(argv):
11+
processor = None
1012
options = parse_args(argv)
11-
processor = MnistProcessor(options)
12-
processor.process()
13+
14+
if options.mode == "softmax":
15+
processor = SoftmaxRegressionProcessor(options)
16+
elif options.mode == "cnn":
17+
processor = CNNProcessor(options)
18+
19+
if processor:
20+
processor.process()
1321

1422

1523
def parse_args(argv):
1624
parser = ArgumentParser(prog="tensorflow_mnist")
25+
parser.add_argument('--mode', metavar='Run mode', type=str, required=True)
26+
1727
return parser.parse_args(argv, namespace=Options)
1828

1929

0 commit comments

Comments
 (0)