|
| 1 | +import tensorflow as tf |
| 2 | +import numpy as np |
| 3 | + |
| 4 | +# getting dataset and training ,Testing data |
| 5 | +mnist_data = tf.contrib.learn.datasets.load_dataset('mnist') |
| 6 | +x_train = mnist_data.train.images |
| 7 | +y_train = np.asarray(mnist_data.train.labels, dtype=np.int32) |
| 8 | +x_eval = mnist_data.test.images |
| 9 | +y_eval = np.asarray(mnist_data.test.labels, dtype=np.int32) |
| 10 | + |
| 11 | +x_predict =x_eval[:1] |
| 12 | + |
| 13 | + |
| 14 | +# creating linear regression model y = Wx+b |
| 15 | +def model_fn(features, labels, mode): |
| 16 | + x = tf.reshape(features['x'],[-1,784]) |
| 17 | + W = tf.get_variable(name='W', shape=[784, 10], dtype=tf.float32) |
| 18 | + b = tf.get_variable(name='b', shape=[10], dtype=tf.float32) |
| 19 | + y = tf.add(tf.matmul(x, W), b) |
| 20 | + |
| 21 | + if mode == tf.estimator.ModeKeys.PREDICT: |
| 22 | + return tf.estimator.EstimatorSpec(mode=mode, predictions=tf.nn.softmax(logits=y)) |
| 23 | + |
| 24 | + onehot_labels = tf.one_hot(indices=tf.cast(labels, tf.int32), depth=10) |
| 25 | + loss = tf.losses.softmax_cross_entropy(onehot_labels=onehot_labels, logits=y) |
| 26 | + |
| 27 | +#training |
| 28 | + if mode == tf.estimator.ModeKeys.TRAIN: |
| 29 | + optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) |
| 30 | + train_step = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step()) |
| 31 | + |
| 32 | + return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_step) |
| 33 | + |
| 34 | +#testing and evaluation |
| 35 | + eval_metric_ops={'accuracy':tf.metrics.accuracy(labels=labels,predictions=tf.argmax(y, 1))} |
| 36 | + return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metric_ops) |
| 37 | + |
| 38 | + |
| 39 | + |
| 40 | + |
| 41 | + |
| 42 | +estimator= tf.estimator.Estimator(model_fn= model_fn) |
| 43 | + |
| 44 | +train_input_fn = tf.estimator.inputs.numpy_input_fn(x={'x':x_train}, y=y_train, batch_size=100, num_epochs=None, shuffle=True) |
| 45 | + |
| 46 | +eval_input_fn = tf.estimator.inputs.numpy_input_fn(x={'x':x_eval}, y=y_eval, num_epochs=1, shuffle=False) |
| 47 | + |
| 48 | +predict_input_fn = tf.estimator.inputs.numpy_input_fn(x={'x':x_predict}, num_epochs=1, shuffle=False) |
| 49 | + |
| 50 | +estimator.train(input_fn=train_input_fn, steps=20000) |
| 51 | + |
| 52 | +print(estimator.evaluate(input_fn=eval_input_fn)) |
| 53 | +print(list(estimator.predict(input_fn=predict_input_fn))) |
0 commit comments