Skip to content

Commit 76a0d48

Browse files
author
yichen
committed
add gpu option
1 parent 08dce44 commit 76a0d48

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

src/ml/basic/nn_tensor.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,20 @@ def train_neural_network(X, Y):
6060
accuracy = tf.reduce_mean(tf.cast(correct,'float'))
6161
print('accuracy: ', accuracy.eval({X:mnist.test.images, Y:mnist.test.labels}))
6262

63-
t1 = datetime.datetime.now()
64-
train_neural_network(X,Y)
65-
t2 = datetime.datetime.now()
66-
print(datetime.timedelta(t1, t2))
63+
64+
if __name__ == '__main__':
65+
t1 = datetime.datetime.now()
66+
# with tf.device(assign_to_device('/cpu:0')):
67+
with tf.device('/cpu:0'):
68+
train_neural_network(X,Y)
69+
t2 = datetime.datetime.now()
70+
71+
print(t2 - t1)
72+
73+
# with tf.device(assign_to_device('/gpu:0', ps_device='/cpu:0')):
74+
# with tf.device(assign_to_device('/gpu:0')):
75+
# with tf.device('/gpu:0'):
76+
# train_neural_network(X,Y)
77+
# t3 = datetime.datetime.now()
78+
79+
# print(t3 - t2)

0 commit comments

Comments
 (0)