
Commit 6fe69be

Wapaul1 authored and robertnishihara committed
Selects from all variables now independent of graph, and uses standar… (ray-project#199)
* Smarter variable retrieval and doc update
* doc update and small fixes
* addressing robert's comments
1 parent 303d0fe commit 6fe69be

3 files changed: +204 -50 lines changed

doc/using-ray-with-tensorflow.md (+36 -24)

@@ -72,15 +72,21 @@ b.assign(np.zeros(1)) # This adds a node to the graph every time you call it.
 ## Complete Example
 
 Putting this all together, we would first create the graph on each worker using
-environment variables. Within the environment variables, we would define
-`get_weights` and `set_weights` methods. We would then use those methods to ship
-the weights (as lists of numpy arrays) between the processes without shipping
-the actual TensorFlow graphs, which are much more complex Python objects.
+environment variables. Within the environment variables, we would use the
+`get_weights` and `set_weights` methods of the `TensorFlowVariables` class. We
+would then use those methods to ship the weights (as a dictionary mapping
+variable names to numpy arrays) between the processes without shipping the
+actual TensorFlow graphs, which are much more complex Python objects. Note that
+to avoid name collisions with variables already created on the workers, we
+wrap the graph in a variable_scope with a random prefix in the environment
+variables and then pass `prefix=True` to `TensorFlowVariables` so it can
+properly decode the variable names.
 
 ```python
 import tensorflow as tf
 import numpy as np
 import ray
+import uuid
 
 ray.init(num_workers=5)
 
@@ -89,25 +95,31 @@ NUM_BATCHES = 1
 NUM_ITERS = 201
 
 def net_vars_initializer():
-  # Seed TensorFlow to make the script deterministic.
-  tf.set_random_seed(0)
-  # Define the inputs.
-  x_data = tf.placeholder(tf.float32, shape=[BATCH_SIZE])
-  y_data = tf.placeholder(tf.float32, shape=[BATCH_SIZE])
-  # Define the weights and computation.
-  w = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
-  b = tf.Variable(tf.zeros([1]))
-  y = w * x_data + b
-  # Define the loss.
-  loss = tf.reduce_mean(tf.square(y - y_data))
-  optimizer = tf.train.GradientDescentOptimizer(0.5)
-  train = optimizer.minimize(loss)
-  # Define the weight initializer and session.
-  init = tf.global_variables_initializer()
-  sess = tf.Session()
-  # Additional code for setting and getting the weights.
-  variables = ray.experimental.TensorFlowVariables(loss, sess)
-  # Return all of the data needed to use the network.
+  # Prefix should be random so that there is no conflict with variable names in
+  # the cluster setting.
+  prefix = str(uuid.uuid1().hex)
+  # Use the tensorflow variable_scope to prefix all of the variables
+  with tf.variable_scope(prefix):
+    # Seed TensorFlow to make the script deterministic.
+    tf.set_random_seed(0)
+    # Define the inputs.
+    x_data = tf.placeholder(tf.float32, shape=[BATCH_SIZE])
+    y_data = tf.placeholder(tf.float32, shape=[BATCH_SIZE])
+    # Define the weights and computation.
+    w = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
+    b = tf.Variable(tf.zeros([1]))
+    y = w * x_data + b
+    # Define the loss.
+    loss = tf.reduce_mean(tf.square(y - y_data))
+    optimizer = tf.train.GradientDescentOptimizer(0.5)
+    train = optimizer.minimize(loss)
+    # Define the weight initializer and session.
+    init = tf.global_variables_initializer()
+    sess = tf.Session()
+    # Additional code for setting and getting the weights, and use a prefix
+    # so that the variable names can be converted between workers.
+    variables = ray.experimental.TensorFlowVariables(loss, sess, prefix=True)
+    # Return all of the data needed to use the network.
   return variables, sess, train, loss, x_data, y_data, init
 
 def net_vars_reinitializer(net_vars):
@@ -131,7 +143,7 @@ def step(weights, x, y):
 variables, sess, _, loss, x_data, y_data, init = ray.env.net_vars
 # Initialize the network weights.
 sess.run(init)
-# Get the weights as a list of numpy arrays.
+# Get the weights as a dictionary of numpy arrays.
 weights = variables.get_weights()
 
 # Define a remote function for generating fake data.

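As context for the documentation change above: `get_weights` now returns a plain dictionary mapping variable names to numpy arrays, so the shipped weights can be combined on the driver with ordinary numpy code. A minimal sketch (the `average_weights` helper is hypothetical, not part of this commit):

```python
import numpy as np

def average_weights(weight_dicts):
    # Each element is a dict of variable name -> numpy array, as returned by
    # TensorFlowVariables.get_weights(); average each variable elementwise.
    return {name: np.mean([w[name] for w in weight_dicts], axis=0)
            for name in weight_dicts[0]}

# For example, averaging the weight dictionaries returned by two workers:
ws = [{"w": np.array([1.0]), "b": np.array([0.0])},
      {"w": np.array([3.0]), "b": np.array([2.0])}]
print(average_weights(ws))  # {'w': array([2.]), 'b': array([1.])}
```
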
python/ray/experimental/tfutils.py (+26 -11)

@@ -2,6 +2,7 @@
 from __future__ import division
 from __future__ import print_function
 import numpy as np
+from collections import deque, OrderedDict
 
 def unflatten(vector, shapes):
   i = 0
@@ -27,28 +28,42 @@ class TensorFlowVariables(object):
     assignment_placeholders (List[tf.placeholders]): The nodes that weights get
       passed to.
     assignment_nodes (List[tf.Tensor]): The nodes that assign the weights.
+    prefix (bool): Whether the variable names have a prefix that must be stripped.
   """
-  def __init__(self, loss, sess=None):
+  def __init__(self, loss, sess=None, prefix=False):
     """Creates a TensorFlowVariables instance."""
     import tensorflow as tf
     self.sess = sess
     self.loss = loss
-    variable_names = [op.node_def.name for op in loss.graph.get_operations() if op.node_def.op == "Variable"]
-    self.variables = [v for v in tf.trainable_variables() if v.op.node_def.name in variable_names]
+    self.prefix = prefix
+    queue = deque([loss])
+    variable_names = []
+
+    # We do a BFS on the dependency graph of the input function to find
+    # the variables.
+    while len(queue) != 0:
+      op = queue.popleft().op
+      queue.extend(op.inputs)
+      if op.node_def.op == "Variable":
+        variable_names.append(op.node_def.name)
+    self.variables = OrderedDict()
+    for v in [v for v in tf.global_variables() if v.op.node_def.name in variable_names]:
+      name = v.op.node_def.name.split("/", 1 if prefix else 0)[-1]
+      self.variables[name] = v
     self.assignment_placeholders = dict()
     self.assignment_nodes = []
 
     # Create new placeholders to put in custom weights.
-    for var in self.variables:
-      self.assignment_placeholders[var.op.node_def.name] = tf.placeholder(var.value().dtype, var.get_shape().as_list())
-      self.assignment_nodes.append(var.assign(self.assignment_placeholders[var.op.node_def.name]))
+    for k, var in self.variables.items():
+      self.assignment_placeholders[k] = tf.placeholder(var.value().dtype, var.get_shape().as_list())
+      self.assignment_nodes.append(var.assign(self.assignment_placeholders[k]))
 
   def set_session(self, sess):
     """Modifies the current session used by the class."""
     self.sess = sess
 
   def get_flat_size(self):
-    return sum([np.prod(v.get_shape().as_list()) for v in self.variables])
+    return sum([np.prod(v.get_shape().as_list()) for v in self.variables.values()])
 
   def _check_sess(self):
     """Checks if the session is set, and if not throw an error message."""
@@ -57,20 +72,20 @@ def _check_sess(self):
   def get_flat(self):
     """Gets the weights and returns them as a flat array."""
     self._check_sess()
-    return np.concatenate([v.eval(session=self.sess).flatten() for v in self.variables])
+    return np.concatenate([v.eval(session=self.sess).flatten() for v in self.variables.values()])
 
   def set_flat(self, new_weights):
     """Sets the weights to new_weights, converting from a flat array."""
     self._check_sess()
-    shapes = [v.get_shape().as_list() for v in self.variables]
+    shapes = [v.get_shape().as_list() for v in self.variables.values()]
     arrays = unflatten(new_weights, shapes)
-    placeholders = [self.assignment_placeholders[v.op.node_def.name] for v in self.variables]
+    placeholders = [self.assignment_placeholders[k] for k, v in self.variables.items()]
     self.sess.run(self.assignment_nodes, feed_dict=dict(zip(placeholders,arrays)))
 
   def get_weights(self):
     """Returns the weights of the variables of the loss function in a list."""
     self._check_sess()
-    return {v.op.node_def.name: v.eval(session=self.sess) for v in self.variables}
+    return {k: v.eval(session=self.sess) for k, v in self.variables.items()}
 
   def set_weights(self, new_weights):
     """Sets the weights to new_weights."""

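The heart of the tfutils.py change is the breadth-first traversal in `__init__`: rather than collecting every `Variable` op in `loss.graph`, it walks backward from the loss tensor through each op's inputs and keeps only the variables the loss actually depends on, which is what makes several independent networks in one graph possible. A standalone sketch of that traversal, assuming a TF 1.x-style graph where variables appear as ops of type "Variable" (the `find_variable_op_names` helper is illustrative, not part of the commit):

```python
from collections import deque

def find_variable_op_names(tensor):
    # Breadth-first search over the dependency graph of `tensor`, following
    # each op's input tensors and collecting the names of "Variable" ops.
    # This mirrors the loop in TensorFlowVariables.__init__ above.
    queue = deque([tensor])
    names = []
    while queue:
        op = queue.popleft().op
        queue.extend(op.inputs)
        if op.node_def.op == "Variable":
            names.append(op.node_def.name)
    return names
```

With `prefix=True`, each discovered name such as `"<uuid>/w"` is then shortened by `name.split("/", 1)[-1]` to `"w"`, so the same network built under different variable scopes produces weight dictionaries with identical keys.
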
test/tensorflow_test.py (+142 -15)

@@ -3,25 +3,47 @@
 from __future__ import print_function
 
 import unittest
+import uuid
 import tensorflow as tf
 import ray
 from numpy.testing import assert_almost_equal
 
+def make_linear_network(w_name=None, b_name=None):
+  # Define the inputs.
+  x_data = tf.placeholder(tf.float32, shape=[100])
+  y_data = tf.placeholder(tf.float32, shape=[100])
+  # Define the weights and computation.
+  w = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name=w_name)
+  b = tf.Variable(tf.zeros([1]), name=b_name)
+  y = w * x_data + b
+  # Return the loss and weight initializer.
+  return tf.reduce_mean(tf.square(y - y_data)), tf.global_variables_initializer()
+
+def net_vars_initializer():
+  # Random prefix so variable names do not clash if we use nets with
+  # the same name.
+  prefix = str(uuid.uuid1().hex)
+  # Use the tensorflow variable_scope to prefix all of the variables
+  with tf.variable_scope(prefix):
+    # Create the network.
+    loss, init = make_linear_network()
+    sess = tf.Session()
+    # Additional code for setting and getting the weights.
+    variables = ray.experimental.TensorFlowVariables(loss, sess, prefix=True)
+    # Return all of the data needed to use the network.
+  return variables, init, sess
+
+def net_vars_reinitializer(net_vars):
+  return net_vars
+
 class TensorFlowTest(unittest.TestCase):
 
   def testTensorFlowVariables(self):
     ray.init(num_workers=2)
 
-    x_data = tf.placeholder(tf.float32, shape=[100])
-    y_data = tf.placeholder(tf.float32, shape=[100])
-
-    w = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
-    b = tf.Variable(tf.zeros([1]))
-    y = w * x_data + b
-    loss = tf.reduce_mean(tf.square(y - y_data))
-
     sess = tf.Session()
-    sess.run(tf.global_variables_initializer())
+    loss, init = make_linear_network()
+    sess.run(init)
 
     variables = ray.experimental.TensorFlowVariables(loss, sess)
     weights = variables.get_weights()
@@ -32,12 +54,8 @@ def testTensorFlowVariables(self):
     variables.set_weights(weights)
     self.assertEqual(weights, variables.get_weights())
 
-    w2 = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name="w")
-    b2 = tf.Variable(tf.zeros([1]), name="b")
-    y2 = w2 * x_data + b2
-    loss2 = tf.reduce_mean(tf.square(y2 - y_data))
-
-    sess.run(tf.global_variables_initializer())
+    loss2, init2 = make_linear_network("w", "b")
+    sess.run(init2)
 
     variables2 = ray.experimental.TensorFlowVariables(loss2, sess)
     weights2 = variables2.get_weights()
@@ -60,5 +78,114 @@ def testTensorFlowVariables(self):
 
     ray.worker.cleanup()
 
+  # Test that the variable names for the two different nets are not
+  # modified by TensorFlow to be unique (i.e. they should already
+  # be unique because of the variable prefix).
+  def testVariableNameCollision(self):
+    ray.init(num_workers=2)
+
+    ray.env.net1 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer)
+    ray.env.net2 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer)
+
+    net_vars1, init1, sess1 = ray.env.net1
+    net_vars2, init2, sess2 = ray.env.net2
+
+    # Initialize the networks
+    sess1.run(init1)
+    sess2.run(init2)
+
+    # This is checking that the variable names of the two nets are the same,
+    # i.e. that the names in the weight dictionaries are the same
+    ray.env.net1[0].set_weights(ray.env.net2[0].get_weights())
+
+    ray.worker.cleanup()
+
+  # Test that different networks on the same worker are independent and
+  # we can get/set their weights without any interaction.
+  def testNetworksIndependent(self):
+    # Note we use only one worker to ensure that all of the remote functions run on the same worker.
+    ray.init(num_workers=1)
+
+    ray.env.net1 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer)
+    ray.env.net2 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer)
+
+    net_vars1, init1, sess1 = ray.env.net1
+    net_vars2, init2, sess2 = ray.env.net2
+
+    # Initialize the networks
+    sess1.run(init1)
+    sess2.run(init2)
+
+    @ray.remote
+    def get_vars1():
+      return ray.env.net1[0].get_weights()
+
+    @ray.remote
+    def get_vars2():
+      return ray.env.net2[0].get_weights()
+
+    @ray.remote
+    def set_vars1(weights):
+      ray.env.net1[0].set_weights(weights)
+
+    @ray.remote
+    def set_vars2(weights):
+      ray.env.net2[0].set_weights(weights)
+
+    # Get the weights.
+    weights1 = net_vars1.get_weights()
+    weights2 = net_vars2.get_weights()
+    self.assertNotEqual(weights1, weights2)
+
+    # Swap the weights.
+    set_vars2.remote(weights1)
+    set_vars1.remote(weights2)
+
+    # Get the new weights.
+    new_weights1 = ray.get(get_vars1.remote())
+    new_weights2 = ray.get(get_vars2.remote())
+    self.assertNotEqual(new_weights1, new_weights2)
+
+    # Check that the weights were swapped.
+    self.assertEqual(weights1, new_weights2)
+    self.assertEqual(weights2, new_weights1)
+
+    ray.worker.cleanup()
+
+  def testNetworkDriverWorkerIndependent(self):
+    ray.init(num_workers=1)
+
+    # Create a network on the driver locally.
+    sess1 = tf.Session()
+    loss1, init1 = make_linear_network()
+    net_vars1 = ray.experimental.TensorFlowVariables(loss1, sess1)
+    sess1.run(init1)
+
+    # Create a network on the driver via an environment variable.
+    ray.env.net = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer)
+
+    net_vars2, init2, sess2 = ray.env.net
+    sess2.run(init2)
+
+    # Get the weights.
+    weights1 = net_vars1.get_weights()
+    weights2 = net_vars2.get_weights()
+    self.assertNotEqual(weights1, weights2)
+
+    # Swap the weights.
+    net_vars1.set_weights(weights2)
+    net_vars2.set_weights(weights1)
+
+    # Get the new weights.
+    new_weights1 = net_vars1.get_weights()
+    new_weights2 = net_vars2.get_weights()
+    self.assertNotEqual(new_weights1, new_weights2)
+
+    # Check that the weights were swapped.
+    self.assertEqual(weights1, new_weights2)
+    self.assertEqual(weights2, new_weights1)
+
+    ray.worker.cleanup()
+
 if __name__ == "__main__":
   unittest.main(verbosity=2)

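The new tests exercise exactly that behavior: two networks built under different random scopes must expose identical weight-dictionary keys. A compressed illustration of the idea behind testVariableNameCollision (not part of the commit, and assuming the TF 1.x graph-mode API used throughout this change):

```python
import uuid
import tensorflow as tf
import ray

def build_net():
    # Build the same tiny network under a fresh random scope, as the test's
    # net_vars_initializer does.
    with tf.variable_scope(str(uuid.uuid1().hex)):
        w = tf.Variable(tf.zeros([1]), name="w")
        loss = tf.reduce_mean(w)
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    return ray.experimental.TensorFlowVariables(loss, sess, prefix=True)

net1, net2 = build_net(), build_net()
# prefix=True strips the random scope, so both weight dictionaries use the
# key "w" and weights can be copied directly between the two networks.
net2.set_weights(net1.get_weights())
```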