Skip to content

Commit 63350dd

Browse files
authored
Add files via upload
1 parent 2ddf6c7 commit 63350dd

File tree

5 files changed

+53
-0
lines changed

5 files changed

+53
-0
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import tensorflow as tf
2+
import numpy as np
3+
4+
# getting dataset and training ,Testing data
5+
mnist_data = tf.contrib.learn.datasets.load_dataset('mnist')
6+
x_train = mnist_data.train.images
7+
y_train = np.asarray(mnist_data.train.labels, dtype=np.int32)
8+
x_eval = mnist_data.test.images
9+
y_eval = np.asarray(mnist_data.test.labels, dtype=np.int32)
10+
11+
x_predict =x_eval[:1]
12+
13+
14+
# creating linear regression model y = Wx+b
15+
def model_fn(features, labels, mode):
16+
x = tf.reshape(features['x'],[-1,784])
17+
W = tf.get_variable(name='W', shape=[784, 10], dtype=tf.float32)
18+
b = tf.get_variable(name='b', shape=[10], dtype=tf.float32)
19+
y = tf.add(tf.matmul(x, W), b)
20+
21+
if mode == tf.estimator.ModeKeys.PREDICT:
22+
return tf.estimator.EstimatorSpec(mode=mode, predictions=tf.nn.softmax(logits=y))
23+
24+
onehot_labels = tf.one_hot(indices=tf.cast(labels, tf.int32), depth=10)
25+
loss = tf.losses.softmax_cross_entropy(onehot_labels=onehot_labels, logits=y)
26+
27+
#training
28+
if mode == tf.estimator.ModeKeys.TRAIN:
29+
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
30+
train_step = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())
31+
32+
return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_step)
33+
34+
#testing and evaluation
35+
eval_metric_ops={'accuracy':tf.metrics.accuracy(labels=labels,predictions=tf.argmax(y, 1))}
36+
return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
37+
38+
39+
40+
41+
42+
estimator= tf.estimator.Estimator(model_fn= model_fn)
43+
44+
train_input_fn = tf.estimator.inputs.numpy_input_fn(x={'x':x_train}, y=y_train, batch_size=100, num_epochs=None, shuffle=True)
45+
46+
eval_input_fn = tf.estimator.inputs.numpy_input_fn(x={'x':x_eval}, y=y_eval, num_epochs=1, shuffle=False)
47+
48+
predict_input_fn = tf.estimator.inputs.numpy_input_fn(x={'x':x_predict}, num_epochs=1, shuffle=False)
49+
50+
estimator.train(input_fn=train_input_fn, steps=20000)
51+
52+
print(estimator.evaluate(input_fn=eval_input_fn))
53+
print(list(estimator.predict(input_fn=predict_input_fn)))

0 commit comments

Comments
 (0)