@@ -31,7 +31,8 @@ y = w * x_data + b
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
- train = optimizer.minimize(loss)
+ grads = optimizer.compute_gradients(loss)
+ train = optimizer.apply_gradients(grads)

init = tf.global_variables_initializer()
sess = tf.Session()
@@ -106,14 +107,15 @@ def net_vars_initializer():
        # Define the loss.
        loss = tf.reduce_mean(tf.square(y - y_data))
        optimizer = tf.train.GradientDescentOptimizer(0.5)
-         train = optimizer.minimize(loss)
+         grads = optimizer.compute_gradients(loss)
+         train = optimizer.apply_gradients(grads)
        # Define the weight initializer and session.
        init = tf.global_variables_initializer()
        sess = tf.Session()
        # Additional code for setting and getting the weights.
        variables = ray.experimental.TensorFlowVariables(loss, sess)
    # Return all of the data needed to use the network.
-     return variables, sess, train, loss, x_data, y_data, init
+     return variables, sess, grads, train, loss, x_data, y_data, init

def net_vars_reinitializer(net_vars):
    return net_vars
@@ -125,15 +127,15 @@ ray.env.net_vars = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinit
# new weights.
@ray.remote
def step(weights, x, y):
-     variables, sess, train, _, x_data, y_data, _ = ray.env.net_vars
+     variables, sess, _, train, _, x_data, y_data, _ = ray.env.net_vars
    # Set the weights in the network.
    variables.set_weights(weights)
    # Do one step of training.
    sess.run(train, feed_dict={x_data: x, y_data: y})
    # Return the new weights.
    return variables.get_weights()

- variables, sess, _, loss, x_data, y_data, init = ray.env.net_vars
+ variables, sess, _, train, loss, x_data, y_data, init = ray.env.net_vars
# Initialize the network weights.
sess.run(init)
# Get the weights as a dictionary of numpy arrays.
@@ -176,3 +178,143 @@ for iteration in range(NUM_ITERS):
    if iteration % 20 == 0:
        print("Iteration {}: weights are {}".format(iteration, weights))
```
+
+ ## How to Train in Parallel using Ray
+
+ In some cases, you may want to do data-parallel training on your network. We use the network
+ above to illustrate how to do this in Ray. The only differences are in the remote function
+ `step` and the driver code.
+
+ In the function `step`, we run the gradient operations rather than the train operation to get
+ the gradients. Since TensorFlow pairs each gradient with its variable in a tuple, we extract
+ the gradient tensors to avoid needless computation.
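+
+ As a quick sketch of that structure (using the names from the example above), each entry of
+ `grads` is a (gradient, variable) pair, so indexing with `[0]` selects the gradient tensor:
+
+ ```python
+ # grads is a list like [(grad_w, w), (grad_b, b)], one pair per variable, so
+ # grad[0] is the symbolic gradient and grad[1] is the corresponding variable.
+ gradient_tensors = [grad[0] for grad in grads]
+ ```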
+
+ ### Extracting numerical gradients
+
+ Code like the following can be used in a remote function to compute numerical gradients.
+
+ ```python
+ x_values = [1] * 100
+ y_values = [2] * 100
+ numerical_grads = sess.run([grad[0] for grad in grads], feed_dict={x_data: x_values, y_data: y_values})
+ ```
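+
+ To sanity-check the result, one option (a sketch; `named_grads` is just an illustrative name)
+ is to key the numerical gradients by variable name:
+
+ ```python
+ # Each numerical gradient lines up with its (gradient, variable) pair in grads,
+ # so grad[1].name recovers the name of the corresponding tf.Variable.
+ named_grads = {grad[1].name: value for grad, value in zip(grads, numerical_grads)}
+ ```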
+
+ ### Using the returned gradients to train the network
+
+ By pairing the symbolic gradients with the numerical gradients in a `feed_dict`, we can update the network.
+
+ ```python
+ # We can feed the gradient values in using the associated symbolic gradient
+ # operation defined in tensorflow.
+ feed_dict = {grad[0]: numerical_grad for (grad, numerical_grad) in zip(grads, numerical_grads)}
+ sess.run(train, feed_dict=feed_dict)
+ ```
+
+ You can then run `variables.get_weights()` to see the updated weights of the network.
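+
+ For example (a minimal usage sketch, assuming the setup above):
+
+ ```python
+ # get_weights returns a dictionary mapping variable names to numpy arrays.
+ print(variables.get_weights())
+ ```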
+
+ For reference, the full code is below:
+
+ ```python
+ import tensorflow as tf
+ import numpy as np
+ import ray
+
+ ray.init(num_workers=5)
+
+ BATCH_SIZE = 100
+ NUM_BATCHES = 1
+ NUM_ITERS = 201
+
+ def net_vars_initializer():
+     # Use a separate graph for each network.
+     with tf.Graph().as_default():
+         # Seed TensorFlow to make the script deterministic.
+         tf.set_random_seed(0)
+         # Define the inputs.
+         x_data = tf.placeholder(tf.float32, shape=[BATCH_SIZE])
+         y_data = tf.placeholder(tf.float32, shape=[BATCH_SIZE])
+         # Define the weights and computation.
+         w = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
+         b = tf.Variable(tf.zeros([1]))
+         y = w * x_data + b
+         # Define the loss.
+         loss = tf.reduce_mean(tf.square(y - y_data))
+         optimizer = tf.train.GradientDescentOptimizer(0.5)
+         grads = optimizer.compute_gradients(loss)
+         train = optimizer.apply_gradients(grads)
+
+         # Define the weight initializer and session.
+         init = tf.global_variables_initializer()
+         sess = tf.Session()
+         # Additional code for setting and getting the weights.
+         variables = ray.experimental.TensorFlowVariables(loss, sess)
+     # Return all of the data needed to use the network.
+     return variables, sess, grads, train, loss, x_data, y_data, init
+
+ def net_vars_reinitializer(net_vars):
+     return net_vars
+
+ # Define an environment variable for the network variables.
+ ray.env.net_vars = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer)
+
+ # Define a remote function that computes the gradients for one batch of data
+ # and returns them.
+ @ray.remote
+ def step(weights, x, y):
+     variables, sess, grads, _, _, x_data, y_data, _ = ray.env.net_vars
+     # Set the weights in the network.
+     variables.set_weights(weights)
+     # Compute the gradients for one batch. We only need the gradient values,
+     # so we fetch grad[0] from each (gradient, variable) pair.
+     actual_grads = sess.run([grad[0] for grad in grads], feed_dict={x_data: x, y_data: y})
+     return actual_grads
+
+
+ variables, sess, grads, train, loss, x_data, y_data, init = ray.env.net_vars
+ # Initialize the network weights.
+ sess.run(init)
+ # Get the weights as a dictionary of numpy arrays.
+ weights = variables.get_weights()
+
+ # Define a remote function for generating fake data.
+ @ray.remote(num_return_vals=2)
+ def generate_fake_x_y_data(num_data, seed=0):
+     # Seed numpy to make the script deterministic.
+     np.random.seed(seed)
+     x = np.random.rand(num_data)
+     y = x * 0.1 + 0.3
+     return x, y
+
+ # Generate some training data.
+ batch_ids = [generate_fake_x_y_data.remote(BATCH_SIZE, seed=i) for i in range(NUM_BATCHES)]
+ x_ids = [x_id for x_id, y_id in batch_ids]
+ y_ids = [y_id for x_id, y_id in batch_ids]
+ # Generate some test data.
+ x_test, y_test = ray.get(generate_fake_x_y_data.remote(BATCH_SIZE, seed=NUM_BATCHES))
+
+
+ # Do some steps of training.
+ for iteration in range(NUM_ITERS):
+     # Put the weights in the object store. This is optional. We could instead pass
+     # the variable weights directly into step.remote, in which case they would be
+     # placed in the object store under the hood. However, in that case multiple
+     # copies of the weights would be put in the object store, so this approach is
+     # more efficient.
+     weights_id = ray.put(weights)
+     # Call the remote function multiple times in parallel.
+     gradients_ids = [step.remote(weights_id, x_ids[i], y_ids[i]) for i in range(NUM_BATCHES)]
+     # Get all of the gradients.
+     gradients_list = ray.get(gradients_ids)
+
+     # Take the mean of the different gradients. Each element of gradients_list is a list
+     # of gradients, and we want to take the mean of each one.
+     mean_grads = [sum([gradients[i] for gradients in gradients_list]) / len(gradients_list) for i in range(len(gradients_list[0]))]
+
+     # Feed the averaged gradients into the train op to update the weights.
+     feed_dict = {grad[0]: mean_grad for (grad, mean_grad) in zip(grads, mean_grads)}
+     sess.run(train, feed_dict=feed_dict)
+     weights = variables.get_weights()
+
+     # Print the current weights. They should converge roughly to the values 0.1
+     # and 0.3 used in generate_fake_x_y_data.
+     if iteration % 20 == 0:
+         print("Iteration {}: weights are {}".format(iteration, weights))
+ ```