Skip to content

Commit 8e03823

Browse files
Refactor nearest_neighbor for TF1.0
Signed-off-by: Norman Heckscher <norman.heckscher@gmail.com>
1 parent 7839ba2 commit 8e03823

File tree

2 files changed

+25
-16
lines changed

2 files changed

+25
-16
lines changed

examples/2_BasicModels/nearest_neighbor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,14 @@
2626

2727
# Nearest Neighbor calculation using L1 Distance
2828
# Calculate L1 Distance
29-
distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.neg(xte))), reduction_indices=1)
29+
distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), reduction_indices=1)
3030
# Prediction: Get min distance index (Nearest neighbor)
3131
pred = tf.arg_min(distance, 0)
3232

3333
accuracy = 0.
3434

3535
# Initializing the variables
36-
init = tf.initialize_all_variables()
36+
init = tf.global_variables_initializer()
3737

3838
# Launch the graph
3939
with tf.Session() as sess:

notebooks/2_BasicModels/nearest_neighbor.ipynb

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
},
1919
{
2020
"cell_type": "code",
21-
"execution_count": 2,
21+
"execution_count": 1,
2222
"metadata": {
2323
"collapsed": false
2424
},
@@ -27,10 +27,10 @@
2727
"name": "stdout",
2828
"output_type": "stream",
2929
"text": [
30-
"Extracting /tmp/data/train-images-idx3-ubyte.gz\n",
31-
"Extracting /tmp/data/train-labels-idx1-ubyte.gz\n",
32-
"Extracting /tmp/data/t10k-images-idx3-ubyte.gz\n",
33-
"Extracting /tmp/data/t10k-labels-idx1-ubyte.gz\n"
30+
"Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
31+
"Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
32+
"Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
33+
"Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n"
3434
]
3535
}
3636
],
@@ -40,14 +40,14 @@
4040
"\n",
4141
"# Import MINST data\n",
4242
"from tensorflow.examples.tutorials.mnist import input_data\n",
43-
"mnist = input_data.read_data_sets(\"/tmp/data/\", one_hot=True)"
43+
"mnist = input_data.read_data_sets(\"MNIST_data/\", one_hot=True)"
4444
]
4545
},
4646
{
4747
"cell_type": "code",
48-
"execution_count": 3,
48+
"execution_count": 2,
4949
"metadata": {
50-
"collapsed": true
50+
"collapsed": false
5151
},
5252
"outputs": [],
5353
"source": [
@@ -61,19 +61,19 @@
6161
"\n",
6262
"# Nearest Neighbor calculation using L1 Distance\n",
6363
"# Calculate L1 Distance\n",
64-
"distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.neg(xte))), reduction_indices=1)\n",
64+
"distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), reduction_indices=1)\n",
6565
"# Prediction: Get min distance index (Nearest neighbor)\n",
6666
"pred = tf.arg_min(distance, 0)\n",
6767
"\n",
6868
"accuracy = 0.\n",
6969
"\n",
7070
"# Initializing the variables\n",
71-
"init = tf.initialize_all_variables()"
71+
"init = tf.global_variables_initializer()"
7272
]
7373
},
7474
{
7575
"cell_type": "code",
76-
"execution_count": 4,
76+
"execution_count": 3,
7777
"metadata": {
7878
"collapsed": false
7979
},
@@ -305,6 +305,15 @@
305305
" print \"Done!\"\n",
306306
" print \"Accuracy:\", accuracy"
307307
]
308+
},
309+
{
310+
"cell_type": "code",
311+
"execution_count": null,
312+
"metadata": {
313+
"collapsed": true
314+
},
315+
"outputs": [],
316+
"source": []
308317
}
309318
],
310319
"metadata": {
@@ -316,16 +325,16 @@
316325
"language_info": {
317326
"codemirror_mode": {
318327
"name": "ipython",
319-
"version": 2.0
328+
"version": 2
320329
},
321330
"file_extension": ".py",
322331
"mimetype": "text/x-python",
323332
"name": "python",
324333
"nbconvert_exporter": "python",
325334
"pygments_lexer": "ipython2",
326-
"version": "2.7.11"
335+
"version": "2.7.13"
327336
}
328337
},
329338
"nbformat": 4,
330339
"nbformat_minor": 0
331-
}
340+
}

0 commit comments

Comments
 (0)