
Commit 1a7e1c4

Wapaul1 authored and pcmoritz committed
Added example for compute grads in ray tutorial (ray-project#238)

* Added example for compute grads in ray
* Added formatting
* Removed need for placeholders in apply gradient
* Streamlined examples
* Fixed docs
* Added formatting
* Removed old references
* Simplified code some
* Addressed comments
* Changes to first code block
* Added test for training and updated code snippets
* Formatting
* Removed mean
* Removed all mention of mean
* Added comments
* Added comments
1 parent 1fec94e commit 1a7e1c4

File tree: 3 files changed (+185 -15 lines)


doc/using-ray-with-tensorflow.md (+147 -5)
@@ -31,7 +31,8 @@ y = w * x_data + b
 
 loss = tf.reduce_mean(tf.square(y - y_data))
 optimizer = tf.train.GradientDescentOptimizer(0.5)
-train = optimizer.minimize(loss)
+grads = optimizer.compute_gradients(loss)
+train = optimizer.apply_gradients(grads)
 
 init = tf.global_variables_initializer()
 sess = tf.Session()
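
In TensorFlow 1.x, `optimizer.minimize(loss)` is shorthand for `compute_gradients` followed by `apply_gradients`, so the change above preserves the training behavior while exposing the symbolic gradient tensors that the rest of the commit relies on. A minimal standalone sketch of the equivalence (illustrative only, not part of the commit):

```python
import tensorflow as tf

# Tiny linear model, mirroring the tutorial's setup.
x = tf.placeholder(tf.float32, shape=[None])
y = tf.placeholder(tf.float32, shape=[None])
w = tf.Variable(0.0)
b = tf.Variable(0.0)
loss = tf.reduce_mean(tf.square(w * x + b - y))

optimizer = tf.train.GradientDescentOptimizer(0.5)

# One-step form: minimize() computes and applies gradients internally.
train_direct = optimizer.minimize(loss)

# Two-step form used in the tutorial: compute_gradients returns a list of
# (gradient, variable) pairs, so the gradient values can be fetched, shipped
# between processes, averaged, and only then applied.
grads_and_vars = optimizer.compute_gradients(loss)
train_split = optimizer.apply_gradients(grads_and_vars)
```
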
@@ -106,14 +107,15 @@ def net_vars_initializer():
     # Define the loss.
     loss = tf.reduce_mean(tf.square(y - y_data))
     optimizer = tf.train.GradientDescentOptimizer(0.5)
-    train = optimizer.minimize(loss)
+    grads = optimizer.compute_gradients(loss)
+    train = optimizer.apply_gradients(grads)
     # Define the weight initializer and session.
     init = tf.global_variables_initializer()
     sess = tf.Session()
     # Additional code for setting and getting the weights
     variables = ray.experimental.TensorFlowVariables(loss, sess)
     # Return all of the data needed to use the network.
-    return variables, sess, train, loss, x_data, y_data, init
+    return variables, sess, grads, train, loss, x_data, y_data, init
 
 def net_vars_reinitializer(net_vars):
     return net_vars
@@ -125,15 +127,15 @@ ray.env.net_vars = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinit
 # new weights.
 @ray.remote
 def step(weights, x, y):
-    variables, sess, train, _, x_data, y_data, _ = ray.env.net_vars
+    variables, sess, _, train, _, x_data, y_data, _ = ray.env.net_vars
     # Set the weights in the network.
     variables.set_weights(weights)
     # Do one step of training.
     sess.run(train, feed_dict={x_data: x, y_data: y})
     # Return the new weights.
     return variables.get_weights()
 
-variables, sess, _, loss, x_data, y_data, init = ray.env.net_vars
+variables, sess, _, train, loss, x_data, y_data, init = ray.env.net_vars
 # Initialize the network weights.
 sess.run(init)
 # Get the weights as a dictionary of numpy arrays.
@@ -176,3 +178,143 @@ for iteration in range(NUM_ITERS):
     if iteration % 20 == 0:
         print("Iteration {}: weights are {}".format(iteration, weights))
 ```
+
+## How to Train in Parallel using Ray
+
+In some cases, you may want to do data-parallel training on your network. We use the network
+above to illustrate how to do this in Ray. The only differences are in the remote function
+`step` and the driver code.
+
+In the function `step`, we run the grad operation rather than the train operation to get the gradients.
+Since TensorFlow pairs each gradient with its variable in a tuple, we extract the gradients to avoid
+needless computation.
+
+### Extracting numerical gradients
+
+Code like the following can be used in a remote function to compute numerical gradients.
+
+```python
+x_values = [1] * 100
+y_values = [2] * 100
+numerical_grads = sess.run([grad[0] for grad in grads], feed_dict={x_data: x_values, y_data: y_values})
+```
+
+### Using the returned gradients to train the network
+
+By pairing the symbolic gradients with the numerical gradients in a `feed_dict`, we can update the network.
+
+```python
+# We can feed the gradient values in using the associated symbolic gradient
+# operation defined in TensorFlow.
+feed_dict = {grad[0]: numerical_grad for (grad, numerical_grad) in zip(grads, numerical_grads)}
+sess.run(train, feed_dict=feed_dict)
+```
+
+You can then run `variables.get_weights()` to see the updated weights of the network.
+
+For reference, the full code is below:
+
+```python
+import tensorflow as tf
+import numpy as np
+import ray
+
+ray.init(num_workers=5)
+
+BATCH_SIZE = 100
+NUM_BATCHES = 1
+NUM_ITERS = 201
+
+def net_vars_initializer():
+    # Use a separate graph for each network.
+    with tf.Graph().as_default():
+        # Seed TensorFlow to make the script deterministic.
+        tf.set_random_seed(0)
+        # Define the inputs.
+        x_data = tf.placeholder(tf.float32, shape=[BATCH_SIZE])
+        y_data = tf.placeholder(tf.float32, shape=[BATCH_SIZE])
+        # Define the weights and computation.
+        w = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
+        b = tf.Variable(tf.zeros([1]))
+        y = w * x_data + b
+        # Define the loss.
+        loss = tf.reduce_mean(tf.square(y - y_data))
+        optimizer = tf.train.GradientDescentOptimizer(0.5)
+        grads = optimizer.compute_gradients(loss)
+        train = optimizer.apply_gradients(grads)
+
+        # Define the weight initializer and session.
+        init = tf.global_variables_initializer()
+        sess = tf.Session()
+        # Additional code for setting and getting the weights
+        variables = ray.experimental.TensorFlowVariables(loss, sess)
+    # Return all of the data needed to use the network.
+    return variables, sess, grads, train, loss, x_data, y_data, init
+
+def net_vars_reinitializer(net_vars):
+    return net_vars
+
+# Define an environment variable for the network variables.
+ray.env.net_vars = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer)
+
+# Define a remote function that computes the gradients for one batch of data
+# with the given weights.
+@ray.remote
+def step(weights, x, y):
+    variables, sess, grads, _, _, x_data, y_data, _ = ray.env.net_vars
+    # Set the weights in the network.
+    variables.set_weights(weights)
+    # Compute the gradients. We only need the gradient values, so we extract them from the list.
+    actual_grads = sess.run([grad[0] for grad in grads], feed_dict={x_data: x, y_data: y})
+    return actual_grads
+
+
+variables, sess, grads, train, loss, x_data, y_data, init = ray.env.net_vars
+# Initialize the network weights.
+sess.run(init)
+# Get the weights as a dictionary of numpy arrays.
+weights = variables.get_weights()
+
+# Define a remote function for generating fake data.
+@ray.remote(num_return_vals=2)
+def generate_fake_x_y_data(num_data, seed=0):
+    # Seed numpy to make the script deterministic.
+    np.random.seed(seed)
+    x = np.random.rand(num_data)
+    y = x * 0.1 + 0.3
+    return x, y
+
+# Generate some training data.
+batch_ids = [generate_fake_x_y_data.remote(BATCH_SIZE, seed=i) for i in range(NUM_BATCHES)]
+x_ids = [x_id for x_id, y_id in batch_ids]
+y_ids = [y_id for x_id, y_id in batch_ids]
+# Generate some test data.
+x_test, y_test = ray.get(generate_fake_x_y_data.remote(BATCH_SIZE, seed=NUM_BATCHES))
+
+
+# Do some steps of training.
+for iteration in range(NUM_ITERS):
+    # Put the weights in the object store. This is optional. We could instead pass
+    # the variable weights directly into step.remote, in which case it would be
+    # placed in the object store under the hood. However, in that case multiple
+    # copies of the weights would be put in the object store, so this approach is
+    # more efficient.
+    weights_id = ray.put(weights)
+    # Call the remote function multiple times in parallel.
+    gradients_ids = [step.remote(weights_id, x_ids[i], y_ids[i]) for i in range(NUM_BATCHES)]
+    # Get all of the gradients.
+    gradients_list = ray.get(gradients_ids)
+
+    # Take the mean of the different gradients. Each element of gradients_list is a list
+    # of gradients, and we want to take the mean of each one.
+    mean_grads = [sum([gradients[i] for gradients in gradients_list]) / len(gradients_list) for i in range(len(gradients_list[0]))]
+
+    feed_dict = {grad[0]: mean_grad for (grad, mean_grad) in zip(grads, mean_grads)}
+    sess.run(train, feed_dict=feed_dict)
+    weights = variables.get_weights()
+
+    # Print the current weights. They should converge roughly to the values 0.1
+    # and 0.3 used in generate_fake_x_y_data.
+    if iteration % 20 == 0:
+        print("Iteration {}: weights are {}".format(iteration, weights))
+```
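
The pattern added above (evaluate only the gradient tensors on the workers, then feed the averaged values back into those same symbolic gradient tensors when running `train`) can also be exercised without Ray. The sketch below is a single-process illustration under that assumption; the batch data and helper names here are made up for illustration and are not part of the commit.

```python
import numpy as np
import tensorflow as tf

# Same linear model as the tutorial.
x_data = tf.placeholder(tf.float32, shape=[100])
y_data = tf.placeholder(tf.float32, shape=[100])
w = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
b = tf.Variable(tf.zeros([1]))
loss = tf.reduce_mean(tf.square(w * x_data + b - y_data))

optimizer = tf.train.GradientDescentOptimizer(0.5)
grads = optimizer.compute_gradients(loss)   # [(grad_tensor, variable), ...]
train = optimizer.apply_gradients(grads)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

def fake_batch(seed):
    # Generate one batch of synthetic data, as in the tutorial.
    np.random.seed(seed)
    x = np.random.rand(100)
    return x, x * 0.1 + 0.3

# Pretend these two batches came from two different workers.
batches = [fake_batch(seed) for seed in range(2)]

# Evaluate only the gradient tensors (grad[0] of each pair) for each batch.
per_batch_grads = [sess.run([g for g, _ in grads], feed_dict={x_data: x, y_data: y})
                   for x, y in batches]

# Average each variable's gradient across the batches.
mean_grads = [np.mean([batch[i] for batch in per_batch_grads], axis=0)
              for i in range(len(grads))]

# Feed the averaged numerical gradients into the symbolic gradient tensors
# and apply them; no input data is needed for this step.
feed_dict = {g: mean_grad for (g, _), mean_grad in zip(grads, mean_grads)}
sess.run(train, feed_dict=feed_dict)
```
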

python/ray/experimental/tfutils.py (+5 -5)
@@ -63,13 +63,13 @@ def __init__(self, loss, sess=None):
         self.variables = OrderedDict()
         for v in [v for v in tf.global_variables() if v.op.node_def.name in variable_names]:
             self.variables[v.op.node_def.name] = v
-        self.assignment_placeholders = dict()
+        self.placeholders = dict()
         self.assignment_nodes = []
 
         # Create new placeholders to put in custom weights.
         for k, var in self.variables.items():
-            self.assignment_placeholders[k] = tf.placeholder(var.value().dtype, var.get_shape().as_list())
-            self.assignment_nodes.append(var.assign(self.assignment_placeholders[k]))
+            self.placeholders[k] = tf.placeholder(var.value().dtype, var.get_shape().as_list())
+            self.assignment_nodes.append(var.assign(self.placeholders[k]))
 
     def set_session(self, sess):
         """Modifies the current session used by the class."""
@@ -92,7 +92,7 @@ def set_flat(self, new_weights):
         self._check_sess()
         shapes = [v.get_shape().as_list() for v in self.variables.values()]
         arrays = unflatten(new_weights, shapes)
-        placeholders = [self.assignment_placeholders[k] for k, v in self.variables.items()]
+        placeholders = [self.placeholders[k] for k, v in self.variables.items()]
         self.sess.run(self.assignment_nodes, feed_dict=dict(zip(placeholders,arrays)))
 
     def get_weights(self):
@@ -103,4 +103,4 @@ def get_weights(self):
     def set_weights(self, new_weights):
         """Sets the weights to new_weights."""
         self._check_sess()
-        self.sess.run(self.assignment_nodes, feed_dict={self.assignment_placeholders[name]: value for (name, value) in new_weights.items()})
+        self.sess.run(self.assignment_nodes, feed_dict={self.placeholders[name]: value for (name, value) in new_weights.items()})
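
The rename above (`assignment_placeholders` to `placeholders`) is internal to `TensorFlowVariables`; callers interact with the class through `get_weights` and `set_weights`, which drive those per-variable placeholders and assignment ops. A small usage sketch, assuming a linear network like the tutorial's (the variable names here are illustrative):

```python
import tensorflow as tf
import ray

# Build a small network whose weights we want to get and set.
x_data = tf.placeholder(tf.float32, shape=[100])
y_data = tf.placeholder(tf.float32, shape=[100])
w = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
b = tf.Variable(tf.zeros([1]))
loss = tf.reduce_mean(tf.square(w * x_data + b - y_data))

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# TensorFlowVariables discovers the variables that `loss` depends on and builds
# one placeholder and one assign op per variable (kept in `placeholders` after
# this commit).
variables = ray.experimental.TensorFlowVariables(loss, sess)

# get_weights returns a dictionary mapping variable names to numpy arrays.
weights = variables.get_weights()

# set_weights feeds each array into its variable's placeholder and runs the
# assignment ops; this is how the remote `step` function installs the driver's
# weights before computing gradients.
variables.set_weights(weights)
```
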

test/tensorflow_test.py (+33 -5)
@@ -39,8 +39,10 @@ def train_vars_initializer():
     loss, init, x_data, y_data = make_linear_network()
     sess = tf.Session()
     variables = ray.experimental.TensorFlowVariables(loss, sess)
-    grad = tf.gradients(loss, list(variables.variables.values()))
-    return variables, init, sess, grad, [x_data, y_data]
+    optimizer = tf.train.GradientDescentOptimizer(0.9)
+    grads = optimizer.compute_gradients(loss)
+    train = optimizer.apply_gradients(grads)
+    return loss, variables, init, sess, grads, train, [x_data, y_data]
 
 
 class TensorFlowTest(unittest.TestCase):
@@ -200,16 +202,42 @@ def testRemoteTrainingStep(self):
 
         @ray.remote
         def training_step(weights):
-            variables, _, sess, grad, placeholders = ray.env.net
+            _, variables, _, sess, grads, _, placeholders = ray.env.net
             variables.set_weights(weights)
-            return sess.run(grad, feed_dict=dict(zip(placeholders, [[1]*100]*2)))
+            return sess.run([grad[0] for grad in grads], feed_dict=dict(zip(placeholders, [[1]*100]*2)))
 
-        variables, init, sess, _, _ = ray.env.net
+        _, variables, init, sess, _, _, _ = ray.env.net
 
         sess.run(init)
         ray.get(training_step.remote(variables.get_weights()))
 
         ray.worker.cleanup()
 
+
+    def testRemoteTrainingLoss(self):
+        ray.init(num_workers=2)
+
+        ray.env.net = ray.EnvironmentVariable(train_vars_initializer, net_vars_reinitializer)
+
+        @ray.remote
+        def training_step(weights):
+            _, variables, _, sess, grads, _, placeholders = ray.env.net
+            variables.set_weights(weights)
+            return sess.run([grad[0] for grad in grads], feed_dict=dict(zip(placeholders, [[1]*100, [2]*100])))
+
+        loss, variables, init, sess, grads, train, placeholders = ray.env.net
+
+        sess.run(init)
+        before_acc = sess.run(loss, feed_dict=dict(zip(placeholders, [[2]*100, [4]*100])))
+
+        for _ in range(3):
+            gradients_list = ray.get([training_step.remote(variables.get_weights()) for _ in range(2)])
+            mean_grads = [sum([gradients[i] for gradients in gradients_list]) / len(gradients_list) for i in range(len(gradients_list[0]))]
+            feed_dict = {grad[0]: mean_grad for (grad, mean_grad) in zip(grads, mean_grads)}
+            sess.run(train, feed_dict=feed_dict)
+        after_acc = sess.run(loss, feed_dict=dict(zip(placeholders, [[2]*100, [4]*100])))
+        self.assertTrue(before_acc < after_acc)
+        ray.worker.cleanup()
+
 if __name__ == "__main__":
     unittest.main(verbosity=2)
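
For clarity, the `mean_grads` expression used in the new test (and in the tutorial's driver loop) averages the i-th gradient across workers, producing one mean gradient per variable. A tiny NumPy-only illustration with made-up numbers (not from the commit):

```python
import numpy as np

# Two workers, two variables (w and b): each worker returns one gradient per variable.
gradients_list = [
    [np.array([0.2]), np.array([0.4])],   # worker 0: [grad_w, grad_b]
    [np.array([0.6]), np.array([0.8])],   # worker 1: [grad_w, grad_b]
]

# Same expression as in the test: average the i-th gradient over all workers.
mean_grads = [sum([gradients[i] for gradients in gradients_list]) / len(gradients_list)
              for i in range(len(gradients_list[0]))]

print(mean_grads)  # [array([0.4]), array([0.6])]
```
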
