Skip to content

Commit d43c58c

Browse files
committed
update random_forest
1 parent 6e94cd9 commit d43c58c

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

examples/2_BasicModels/random_forest.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import tensorflow as tf
1414
from tensorflow.contrib.tensor_forest.python import tensor_forest
15+
from tensorflow.python.ops import resources
1516

1617
# Ignore all GPUs, tf random forest does not benefit from it.
1718
import os
@@ -51,11 +52,12 @@
5152
correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y, tf.int64))
5253
accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
5354

54-
# Initialize the variables (i.e. assign their default value)
55-
init_vars = tf.global_variables_initializer()
55+
# Initialize the variables (i.e. assign their default value) and forest resources
56+
init_vars = tf.group(tf.global_variables_initializer(),
57+
resources.initialize_resources(resources.shared_resources()))
5658

5759
# Start TensorFlow session
58-
sess = tf.train.MonitoredSession()
60+
sess = tf.Session()
5961

6062
# Run the initializer
6163
sess.run(init_vars)

notebooks/2_BasicModels/random_forest.ipynb

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,9 @@
126126
"correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y, tf.int64))\n",
127127
"accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
128128
"\n",
129-
"# Initialize the variables (i.e. assign their default value)\n",
130-
"init_vars = tf.global_variables_initializer()"
129+
"# Initialize the variables (i.e. assign their default value) and forest resources\n",
130+
"init_vars = tf.group(tf.global_variables_initializer(),\n",
131+
" resources.initialize_resources(resources.shared_resources()))"
131132
]
132133
},
133134
{

0 commit comments

Comments
 (0)