Skip to content

Commit 0c4e666

Browse files
committed
fix random forest TF 1.4 compatibility
1 parent d3f3c83 commit 0c4e666

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

examples/2_BasicModels/random_forest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,15 @@
4747
loss_op = forest_graph.training_loss(X, Y)
4848

4949
# Measure the accuracy
50-
infer_op = forest_graph.inference_graph(X)
50+
infer_op, _, _ = forest_graph.inference_graph(X)
5151
correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y, tf.int64))
5252
accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
5353

5454
# Initialize the variables (i.e. assign their default value)
5555
init_vars = tf.global_variables_initializer()
5656

5757
# Start TensorFlow session
58-
sess = tf.Session()
58+
sess = tf.train.MonitoredSession()
5959

6060
# Run the initializer
6161
sess.run(init_vars)

notebooks/2_BasicModels/random_forest.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@
122122
"loss_op = forest_graph.training_loss(X, Y)\n",
123123
"\n",
124124
"# Measure the accuracy\n",
125-
"infer_op = forest_graph.inference_graph(X)\n",
125+
"infer_op, _, _ = forest_graph.inference_graph(X)\n",
126126
"correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y, tf.int64))\n",
127127
"accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
128128
"\n",
@@ -158,7 +158,7 @@
158158
],
159159
"source": [
160160
"# Start TensorFlow session\n",
161-
"sess = tf.Session()\n",
161+
"sess = tf.train.MonitoredSession()\n",
162162
"\n",
163163
"# Run the initializer\n",
164164
"sess.run(init_vars)\n",

0 commit comments

Comments
 (0)