
Commit f2b6a7b

Wapaul1 authored and robertnishihara committed
Polished TensorFlowVariables code and documentation (ray-project#566)
1 parent ca0f08d commit f2b6a7b

File tree

3 files changed: +210 -43 lines changed


doc/source/using-ray-with-tensorflow.rst (+66 -4)
@@ -82,8 +82,8 @@ unmanageably large over time.
     w.assign(np.zeros(1))  # This adds a node to the graph every time you call it.
     b.assign(np.zeros(1))  # This adds a node to the graph every time you call it.

-Complete Example
-----------------
+Complete Example for Weight Averaging
+-------------------------------------

 Putting this all together, we would first embed the graph in an actor. Within
 the actor, we would use the ``get_weights`` and ``set_weights`` methods of the
@@ -185,8 +185,8 @@ complex Python objects.
         if iteration % 20 == 0:
             print("Iteration {}: weights are {}".format(iteration, weights))

-How to Train in Parallel using Ray
-----------------------------------
+How to Train in Parallel using Ray and Gradients
+------------------------------------------------

 In some cases, you may want to do data-parallel training on your network. We use the network
 above to illustrate how to do this in Ray. The only differences are in the remote function
@@ -320,3 +320,65 @@ For reference, the full code is below:
         # and 0.3 used in generate_fake_x_y_data.
         if iteration % 20 == 0:
             print("Iteration {}: weights are {}".format(iteration, weights))
+
+.. autoclass:: ray.experimental.TensorFlowVariables
+   :members:
+
+Troubleshooting
+---------------
+
+Note that ``TensorFlowVariables`` uses variable names to determine what
+variables to set when calling ``set_weights``. One common issue arises when two
+networks are defined in the same TensorFlow graph. In this case, TensorFlow
+appends an underscore and integer to the names of variables to disambiguate
+them. This will cause ``TensorFlowVariables`` to fail. For example, if we have a
+class definition ``Network`` with a ``TensorFlowVariables`` instance:
+
+.. code-block:: python
+
+    import ray
+    import tensorflow as tf
+
+    class Network(object):
+        def __init__(self):
+            a = tf.Variable(1)
+            b = tf.Variable(1)
+            c = tf.add(a, b)
+            sess = tf.Session()
+            init = tf.global_variables_initializer()
+            sess.run(init)
+            self.variables = ray.experimental.TensorFlowVariables(c, sess)
+
+        def set_weights(self, weights):
+            self.variables.set_weights(weights)
+
+        def get_weights(self):
+            return self.variables.get_weights()
+
+and run the following code:
+
+.. code-block:: python
+
+    a = Network()
+    b = Network()
+    b.set_weights(a.get_weights())
+
+the code would fail. If we instead defined each network in its own TensorFlow
+graph, then it would work:
+
+.. code-block:: python
+
+    with tf.Graph().as_default():
+        a = Network()
+    with tf.Graph().as_default():
+        b = Network()
+    b.set_weights(a.get_weights())
+
+This issue does not occur between actors that contain a network, as each actor
+is in its own process, and thus is in its own graph. This also does not occur
+when using ``set_flat``.
+
+Another issue to keep in mind is that ``TensorFlowVariables`` needs to add new
+operations to the graph. If you close the graph and make it immutable, e.g., by
+creating a ``MonitoredTrainingSession``, the initialization will fail. To
+resolve this, simply create the instance before you close the graph.
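
A minimal sketch of that last workaround (assuming the TensorFlow 1.x API; the
toy loss here is illustrative and not taken from the commit itself): create the
``TensorFlowVariables`` instance while the graph is still mutable, then attach
the session afterwards with ``set_session``:

    import ray
    import tensorflow as tf

    w = tf.Variable(tf.zeros([1]))
    loss = tf.reduce_mean(tf.square(w))

    # Create the instance first, since it adds placeholder and assignment
    # operations to the graph.
    variables = ray.experimental.TensorFlowVariables(loss)

    # MonitoredTrainingSession finalizes the graph, so it must come last.
    sess = tf.train.MonitoredTrainingSession()
    variables.set_session(sess)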

python/ray/experimental/tfutils.py (+89 -31)
@@ -18,24 +18,34 @@ def unflatten(vector, shapes):


 class TensorFlowVariables(object):
-    """An object used to extract variables from a loss function.
-
-    This object also provides methods for getting and setting the weights of
-    the relevant variables.
+    """A class used to set and get weights for TensorFlow networks.

     Attributes:
         sess (tf.Session): The tensorflow session used to run assignment.
-        loss: The loss function passed in by the user.
-        variables (List[tf.Variable]): Extracted variables from the loss.
-        assignment_placeholders (List[tf.placeholders]): The nodes that weights
-            get passed to.
-        assignment_nodes (List[tf.Tensor]): The nodes that assign the weights.
+        variables (Dict[str, tf.Variable]): Extracted variables from the loss
+            or additional variables that are passed in.
+        placeholders (Dict[str, tf.placeholders]): Placeholders for weights.
+        assignment_nodes (Dict[str, tf.Tensor]): Nodes that assign weights.
     """
-    def __init__(self, loss, sess=None):
-        """Creates a TensorFlowVariables instance."""
+    def __init__(self, loss, sess=None, input_variables=None):
+        """Creates TensorFlowVariables containing extracted variables.
+
+        The variables are extracted by performing a BFS search on the
+        dependency graph with loss as the root node. After the tree is
+        traversed and those variables are collected, we append input_variables
+        to the collected variables. For each variable in the list, a
+        placeholder and an assignment operation are created.
+
+        Args:
+            loss (tf.Operation): The TensorFlow operation to extract all
+                variables from.
+            sess (tf.Session): Session used for running the get and set
+                methods.
+            input_variables (List[tf.Variable]): Variables to include in the
+                list.
+        """
         import tensorflow as tf
         self.sess = sess
-        self.loss = loss
         queue = deque([loss])
        variable_names = []
        explored_inputs = set([loss])
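
As an aside, a hypothetical sketch of how the new input_variables argument
behaves per this docstring (the variable names and toy loss are illustrative;
TensorFlow 1.x assumed):

    import ray
    import tensorflow as tf

    # `extra` is not reachable from the loss, so the BFS alone would miss it.
    extra = tf.Variable(tf.zeros([1]), name="extra")
    w = tf.Variable(tf.zeros([1]), name="w")
    loss = tf.reduce_mean(tf.square(w))

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # The BFS from `loss` finds `w`; `extra` is appended via input_variables.
    variables = ray.experimental.TensorFlowVariables(
        loss, sess, input_variables=[extra])
    assert set(variables.variables.keys()) == {"w", "extra"}
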
@@ -44,9 +54,10 @@ def __init__(self, loss, sess=None):
         # the variables.
         while len(queue) != 0:
             tf_obj = queue.popleft()
-
-            # The object put into the queue is not necessarily an operation, so
-            # we want the op attribute to get the operation underlying the
+            if tf_obj is None:
+                continue
+            # The object put into the queue is not necessarily an operation,
+            # so we want the op attribute to get the operation underlying the
             # object. Only operations contain the inputs that we can explore.
             if hasattr(tf_obj, "op"):
                 tf_obj = tf_obj.op
@@ -63,23 +74,37 @@ def __init__(self, loss, sess=None):
             if "Variable" in tf_obj.node_def.op:
                 variable_names.append(tf_obj.node_def.name)
         self.variables = OrderedDict()
-        for v in [v for v in tf.global_variables()
-                  if v.op.node_def.name in variable_names]:
+        variable_list = [v for v in tf.global_variables()
+                         if v.op.node_def.name in variable_names]
+        if input_variables is not None:
+            variable_list += input_variables
+        for v in variable_list:
             self.variables[v.op.node_def.name] = v
+
         self.placeholders = dict()
-        self.assignment_nodes = []
+        self.assignment_nodes = dict()

         # Create new placeholders to put in custom weights.
         for k, var in self.variables.items():
             self.placeholders[k] = tf.placeholder(var.value().dtype,
-                                                  var.get_shape().as_list())
-            self.assignment_nodes.append(var.assign(self.placeholders[k]))
+                                                  var.get_shape().as_list(),
+                                                  name="Placeholder_" + k)
+            self.assignment_nodes[k] = var.assign(self.placeholders[k])

     def set_session(self, sess):
-        """Modifies the current session used by the class."""
+        """Sets the current session used by the class.
+
+        Args:
+            sess (tf.Session): Session to set the attribute with.
+        """
         self.sess = sess

     def get_flat_size(self):
+        """Returns the total length of all of the flattened variables.
+
+        Returns:
+            The length of all flattened variables concatenated.
+        """
         return sum([np.prod(v.get_shape().as_list())
                     for v in self.variables.values()])

@@ -91,31 +116,64 @@ def _check_sess(self):
                 "calling set_session(sess).")

     def get_flat(self):
-        """Gets the weights and returns them as a flat array."""
+        """Gets the weights and returns them as a flat array.
+
+        Returns:
+            1D array containing the flattened weights.
+        """
         self._check_sess()
         return np.concatenate([v.eval(session=self.sess).flatten()
                                for v in self.variables.values()])

     def set_flat(self, new_weights):
-        """Sets the weights to new_weights, converting from a flat array."""
+        """Sets the weights to new_weights, converting from a flat array.
+
+        Note:
+            You can only set all weights in the network using this function,
+            i.e., the length of the array must match get_flat_size.
+
+        Args:
+            new_weights (np.ndarray): Flat array containing weights.
+        """
         self._check_sess()
         shapes = [v.get_shape().as_list() for v in self.variables.values()]
         arrays = unflatten(new_weights, shapes)
-        placeholders = [self.placeholders[k]
-                        for k, v in self.variables.items()]
-        self.sess.run(self.assignment_nodes,
+        placeholders = [self.placeholders[k] for k, v
+                        in self.variables.items()]
+        self.sess.run(list(self.assignment_nodes.values()),
                       feed_dict=dict(zip(placeholders, arrays)))

     def get_weights(self):
-        """Returns a list of the weights of the loss function variables."""
+        """Returns a dictionary containing the weights of the network.
+
+        Returns:
+            Dictionary mapping variable names to their weights.
+        """
         self._check_sess()
-        return {k: v.eval(session=self.sess)
-                for k, v in self.variables.items()}
+        return {k: v.eval(session=self.sess) for k, v
+                in self.variables.items()}

     def set_weights(self, new_weights):
-        """Sets the weights to new_weights."""
+        """Sets the weights to new_weights.
+
+        Note:
+            Can set subsets of variables as well, by only passing in the
+            variables you want to be set.
+
+        Args:
+            new_weights (Dict): Dictionary mapping variable names to their
+                weights.
+        """
         self._check_sess()
-        self.sess.run(self.assignment_nodes,
+        assign_list = [self.assignment_nodes[name]
+                       for name in new_weights.keys()
+                       if name in self.assignment_nodes]
+        assert assign_list, ("No variables in the input matched those in the "
+                             "network. Possible cause: Two networks were "
+                             "defined in the same TensorFlow graph. To fix "
+                             "this, place each network definition in its own "
+                             "tf.Graph.")
+        self.sess.run(assign_list,
                       feed_dict={self.placeholders[name]: value
                                  for (name, value) in new_weights.items()
                                  if name in self.placeholders})
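
For illustration, a small hypothetical usage sketch of the reworked API (the
variable names and values are made up; TensorFlow 1.x assumed), showing that
set_weights now accepts a subset of variables while set_flat still requires the
full vector:

    import numpy as np
    import ray
    import tensorflow as tf

    w = tf.Variable(tf.zeros([2]), name="w")
    b = tf.Variable(tf.zeros([1]), name="b")
    loss = tf.reduce_sum(w) + b[0]

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    variables = ray.experimental.TensorFlowVariables(loss, sess)

    # get_weights returns a dictionary keyed by variable name.
    weights = variables.get_weights()  # e.g., {"w": ..., "b": ...}

    # set_weights can update just a subset of the variables.
    variables.set_weights({"w": np.ones(2)})

    # set_flat must receive the full flattened parameter vector.
    variables.set_flat(np.zeros(variables.get_flat_size()))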

test/tensorflow_test.py (+55 -8)
@@ -22,6 +22,32 @@ def make_linear_network(w_name=None, b_name=None):
             tf.global_variables_initializer(), x_data, y_data)


+class LossActor(object):
+
+    def __init__(self, use_loss=True):
+        # Uses a separate graph for each network.
+        with tf.Graph().as_default():
+            # Create the network.
+            var = [tf.Variable(1)]
+            loss, init, _, _ = make_linear_network()
+            sess = tf.Session()
+            # Additional code for setting and getting the weights.
+            weights = ray.experimental.TensorFlowVariables(loss if use_loss
+                                                           else None,
+                                                           sess,
+                                                           input_variables=var)
+        # Return all of the data needed to use the network.
+        self.values = [weights, init, sess]
+        sess.run(init)
+
+    def set_and_get_weights(self, weights):
+        self.values[0].set_weights(weights)
+        return self.values[0].get_weights()
+
+    def get_weights(self):
+        return self.values[0].get_weights()
+
+
 class NetActor(object):

     def __init__(self):
@@ -102,7 +128,6 @@ def testTensorFlowVariables(self):

         variables2.set_weights(weights2)
         self.assertEqual(weights2, variables2.get_weights())
-
         flat_weights = variables2.get_flat() + 2.0
         variables2.set_flat(flat_weights)
         assert_almost_equal(flat_weights, variables2.get_flat())
@@ -114,7 +139,7 @@ def testTensorFlowVariables(self):
         self.assertEqual(variables3.sess, sess)

     # Test that the variable names for the two different nets are not
-    # modified by TensorFlow to be unique (i.e. they should already
+    # modified by TensorFlow to be unique (i.e., they should already
     # be unique because of the variable prefix).
     def testVariableNameCollision(self):
         ray.init(num_workers=2)
@@ -123,9 +148,31 @@ def testVariableNameCollision(self):
         net2 = NetActor()

         # This is checking that the variable names of the two nets are the
-        # same, i.e. that the names in the weight dictionaries are the same
+        # same, i.e., that the names in the weight dictionaries are the same.
         net1.values[0].set_weights(net2.values[0].get_weights())

+    # Test that TensorFlowVariables can take in additional variables through
+    # the input_variables arg and with no loss.
+    def testAdditionalVariablesNoLoss(self):
+        ray.init(num_workers=1)
+
+        net = LossActor(use_loss=False)
+        self.assertEqual(len(net.values[0].variables.items()), 1)
+        self.assertEqual(len(net.values[0].placeholders.items()), 1)
+
+        net.values[0].set_weights(net.values[0].get_weights())
+
+    # Test that TensorFlowVariables can take in additional variables through
+    # the input_variables arg and with a loss.
+    def testAdditionalVariablesWithLoss(self):
+        ray.init(num_workers=1)
+
+        net = LossActor()
+        self.assertEqual(len(net.values[0].variables.items()), 3)
+        self.assertEqual(len(net.values[0].placeholders.items()), 3)
+
+        net.values[0].set_weights(net.values[0].get_weights())
+
     # Test that different networks on the same worker are independent and
     # we can get/set their weights without any interaction.
     def testNetworksIndependent(self):
@@ -197,12 +244,12 @@ def testRemoteTrainingLoss(self):
         ray.init(num_workers=2)

         net = ray.remote(TrainActor).remote()
-        (loss, variables, _, sess, grads,
-         train, placeholders) = TrainActor().values
+        net_values = TrainActor().values
+        loss, variables, _, sess, grads, train, placeholders = net_values

-        before_acc = sess.run(loss,
-                              feed_dict=dict(zip(placeholders,
-                                                 [[2] * 100, [4] * 100])))
+        before_acc = sess.run(loss, feed_dict=dict(zip(placeholders,
+                                                       [[2] * 100,
+                                                        [4] * 100])))

         for _ in range(3):
             gradients_list = ray.get(
