Skip to content

Commit 19d6ca0

Browse files
decsterrobertnishihara
authored andcommitted
Support constructing TensorFlowVariables from multiple tf operations (ray-project#2182)
1 parent d699bfb commit 19d6ca0

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

python/ray/experimental/tfutils.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class TensorFlowVariables(object):
2828
assignment_nodes (Dict[str, tf.Tensor]): Nodes that assign weights.
2929
"""
3030

31-
def __init__(self, loss, sess=None, input_variables=None):
31+
def __init__(self, output, sess=None, input_variables=None):
3232
"""Creates TensorFlowVariables containing extracted variables.
3333
3434
The variables are extracted by performing a BFS search on the
@@ -38,18 +38,20 @@ def __init__(self, loss, sess=None, input_variables=None):
3838
variable has a placeholder and assignment operation created for it.
3939
4040
Args:
41-
loss (tf.Operation): The tensorflow operation to extract all
42-
variables from.
41+
output (tf.Operation, List[tf.Operation]): The tensorflow
42+
operation to extract all variables from.
4343
sess (tf.Session): Session used for running the get and set
4444
methods.
4545
input_variables (List[tf.Variables]): Variables to include in the
4646
list.
4747
"""
4848
import tensorflow as tf
4949
self.sess = sess
50-
queue = deque([loss])
50+
if not isinstance(output, (list, tuple)):
51+
output = [output]
52+
queue = deque(output)
5153
variable_names = []
52-
explored_inputs = {loss}
54+
explored_inputs = set(output)
5355

5456
# We do a BFS on the dependency graph of the input function to find
5557
# the variables.

test/tensorflow_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def testTensorFlowVariables(self):
128128
variables2.set_flat(flat_weights)
129129
assert_almost_equal(flat_weights, variables2.get_flat())
130130

131-
variables3 = ray.experimental.TensorFlowVariables(loss2)
131+
variables3 = ray.experimental.TensorFlowVariables([loss2])
132132
self.assertEqual(variables3.sess, None)
133133
sess = tf.Session()
134134
variables3.set_session(sess)

0 commit comments

Comments
 (0)