
Commit aaf3be3

Wapaul1 authored and pcmoritz committed

Fixed lbfgs for ray-cluster (ray-project#180)

* Updated lbfgs example to include TensorFlowVariables
* Whitespace.

1 parent be4a37b · commit aaf3be3

File tree

examples/lbfgs/driver.py
lib/python/ray/experimental/tfutils.py
test/tensorflow_test.py

3 files changed: +141 -97 lines changed

examples/lbfgs/driver.py (+96 -95)
@@ -10,103 +10,102 @@
 
 from tensorflow.examples.tutorials.mnist import input_data
 
-if __name__ == "__main__":
-  ray.init(num_workers=10)
-
-  # Define the dimensions of the data and of the model.
-  image_dimension = 784
-  label_dimension = 10
-  w_shape = [image_dimension, label_dimension]
-  w_size = np.prod(w_shape)
-  b_shape = [label_dimension]
-  b_size = np.prod(b_shape)
-  dim = w_size + b_size
-
-  # Define a function for initializing the network. Note that this code does not
-  # initialize the network weights. If it did, the weights would be
-  # randomly initialized on each worker and would differ from worker to worker.
-  # We pass the weights into the remote functions loss and grad so that the
-  # weights are the same on each worker.
-  def net_initialization():
-    x = tf.placeholder(tf.float32, [None, image_dimension])
-    w = tf.Variable(tf.zeros(w_shape))
-    b = tf.Variable(tf.zeros(b_shape))
+class LinearModel(object):
+  """Simple class for a one layer neural network.
+
+  Note that this code does not initialize the network weights. Instead weights
+  are set via self.variables.set_weights.
+
+  Example:
+    net = LinearModel([10, 10])
+    weights = [np.random.normal(size=[10, 10]), np.random.normal(size=[10])]
+    variable_names = [v.name for v in net.variables]
+    net.variables.set_weights(dict(zip(variable_names, weights)))
+
+  Attributes:
+    x (tf.placeholder): Input vector.
+    w (tf.Variable): Weight matrix.
+    b (tf.Variable): Bias vector.
+    y_ (tf.placeholder): Input result vector.
+    cross_entropy (tf.Operation): Final layer of network.
+    cross_entropy_grads (tf.Operation): Gradient computation.
+    sess (tf.Session): Session used for training.
+    variables (TensorFlowVariables): Extracted variables and methods to
+      manipulate them.
+  """
+  def __init__(self, shape):
+    """Creates a LinearModel object."""
+    x = tf.placeholder(tf.float32, [None, shape[0]])
+    w = tf.Variable(tf.zeros(shape))
+    b = tf.Variable(tf.zeros(shape[1]))
+    self.x = x
+    self.w = w
+    self.b = b
     y = tf.nn.softmax(tf.matmul(x, w) + b)
-    y_ = tf.placeholder(tf.float32, [None, label_dimension])
+    y_ = tf.placeholder(tf.float32, [None, shape[1]])
+    self.y_ = y_
     cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
-    cross_entropy_grads = tf.gradients(cross_entropy, [w, b])
-
-    sess = tf.Session()
-
-    # In order to set the weights of the TensorFlow graph on a worker, we add
-    # assignment nodes. To get the network weights (as a list of numpy arrays)
-    # and to set the network weights (from a list of numpy arrays), use the
-    # methods get_weights and set_weights. This can be done from within a remote
-    # function or on the driver.
-    def get_and_set_weights_methods():
-      assignment_placeholders = []
-      assignment_nodes = []
-      for var in tf.trainable_variables():
-        assignment_placeholders.append(tf.placeholder(var.value().dtype, var.get_shape().as_list()))
-        assignment_nodes.append(var.assign(assignment_placeholders[-1]))
-
-      def get_weights():
-        return [v.eval(session=sess) for v in tf.trainable_variables()]
-
-      def set_weights(new_weights):
-        sess.run(assignment_nodes, feed_dict={p: w for p, w in zip(assignment_placeholders, new_weights)})
-
-      return get_weights, set_weights
-
-    get_weights, set_weights = get_and_set_weights_methods()
-
-    return sess, cross_entropy, cross_entropy_grads, x, y_, get_weights, set_weights
-
-  # By default, when a reusable variable is used by a remote function, the
-  # initialization code will be rerun at the end of the remote task to ensure
-  # that the state of the variable is not changed by the remote task. However,
-  # the initialization code may be expensive. This case is one example, because
-  # a TensorFlow network is constructed. In this case, we pass in a special
-  # reinitialization function which gets run instead of the original
-  # initialization code. As users, if we pass in custom reinitialization code,
-  # we must ensure that no state is leaked between tasks.
-  def net_reinitialization(net_vars):
-    return net_vars
-
-  # Create a reusable variable for the network.
-  ray.reusables.net_vars = ray.Reusable(net_initialization, net_reinitialization)
-
-  # Load the weights into the network.
-  def load_weights(theta):
-    sess, _, _, _, _, get_weights, set_weights = ray.reusables.net_vars
-    set_weights([theta[:w_size].reshape(w_shape), theta[w_size:].reshape(b_shape)])
-
-  # Compute the loss on a batch of data.
-  @ray.remote
-  def loss(theta, xs, ys):
-    sess, cross_entropy, _, x, y_, _, _ = ray.reusables.net_vars
-    load_weights(theta)
-    return float(sess.run(cross_entropy, feed_dict={x: xs, y_: ys}))
-
-  # Compute the gradient of the loss on a batch of data.
-  @ray.remote
-  def grad(theta, xs, ys):
-    sess, _, cross_entropy_grads, x, y_, _, _ = ray.reusables.net_vars
-    load_weights(theta)
-    gradients = sess.run(cross_entropy_grads, feed_dict={x: xs, y_: ys})
-    return np.concatenate([g.flatten() for g in gradients])
-
-  # Compute the loss on the entire dataset.
-  def full_loss(theta):
-    theta_id = ray.put(theta)
-    loss_ids = [loss.remote(theta_id, xs_id, ys_id) for (xs_id, ys_id) in batch_ids]
-    return sum(ray.get(loss_ids))
-
-  # Compute the gradient of the loss on the entire dataset.
-  def full_grad(theta):
-    theta_id = ray.put(theta)
-    grad_ids = [grad.remote(theta_id, xs_id, ys_id) for (xs_id, ys_id) in batch_ids]
-    return sum(ray.get(grad_ids)).astype("float64")  # This conversion is necessary for use with fmin_l_bfgs_b.
+    self.cross_entropy = cross_entropy
+    self.cross_entropy_grads = tf.gradients(cross_entropy, [w, b])
+    self.sess = tf.Session()
+    # In order to get and set the weights, we pass in the loss function to Ray's
+    # TensorFlowVariables to automatically create methods to modify the weights.
+    self.variables = ray.experimental.TensorFlowVariables(cross_entropy, self.sess)
+
+  def loss(self, xs, ys):
+    """Computes the loss of the network."""
+    return float(self.sess.run(self.cross_entropy, feed_dict={self.x: xs, self.y_: ys}))
+
+  def grad(self, xs, ys):
+    """Computes the gradients of the network."""
+    return self.sess.run(self.cross_entropy_grads, feed_dict={self.x: xs, self.y_: ys})
+
+def net_initialization():
+  return LinearModel([784, 10])
+
+# By default, when a reusable variable is used by a remote function, the
+# initialization code will be rerun at the end of the remote task to ensure
+# that the state of the variable is not changed by the remote task. However,
+# the initialization code may be expensive. This case is one example, because
+# a TensorFlow network is constructed. In this case, we pass in a special
+# reinitialization function which gets run instead of the original
+# initialization code. As users, if we pass in custom reinitialization code,
+# we must ensure that no state is leaked between tasks.
+def net_reinitialization(net):
+  return net
+
+# Register the network with Ray and create a reusable variable for it.
+ray.reusables.net = ray.Reusable(net_initialization, net_reinitialization)
+
+# Compute the loss on a batch of data.
+@ray.remote
+def loss(theta, xs, ys):
+  net = ray.reusables.net
+  net.variables.set_flat(theta)
+  return net.loss(xs, ys)
+
+# Compute the gradient of the loss on a batch of data.
+@ray.remote
+def grad(theta, xs, ys):
+  net = ray.reusables.net
+  net.variables.set_flat(theta)
+  gradients = net.grad(xs, ys)
+  return np.concatenate([g.flatten() for g in gradients])
+
+# Compute the loss on the entire dataset.
+def full_loss(theta):
+  theta_id = ray.put(theta)
+  loss_ids = [loss.remote(theta_id, xs_id, ys_id) for (xs_id, ys_id) in batch_ids]
+  return sum(ray.get(loss_ids))
+
+# Compute the gradient of the loss on the entire dataset.
+def full_grad(theta):
+  theta_id = ray.put(theta)
+  grad_ids = [grad.remote(theta_id, xs_id, ys_id) for (xs_id, ys_id) in batch_ids]
+  return sum(ray.get(grad_ids)).astype("float64")  # This conversion is necessary for use with fmin_l_bfgs_b.
+
+if __name__ == "__main__":
+  ray.init(num_workers=10)
 
   # From the perspective of scipy.optimize.fmin_l_bfgs_b, full_loss is simply a
   # function which takes some parameters theta, and computes a loss. Similarly,
@@ -128,7 +127,9 @@ def full_grad(theta):
   batch_ids = [(ray.put(xs), ray.put(ys)) for (xs, ys) in batches]
 
   # Initialize the weights for the network to the vector of all zeros.
+  dim = ray.reusables.net.variables.get_flat_size()
   theta_init = 1e-2 * np.random.normal(size=dim)
+
   # Use L-BFGS to minimize the loss function.
   print("Running L-BFGS.")
   result = scipy.optimize.fmin_l_bfgs_b(full_loss, theta_init, maxiter=10, fprime=full_grad, disp=True)
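
The substance of this change shows in the remote functions: scipy.optimize.fmin_l_bfgs_b only understands a single flat 1-D parameter vector, and the new TensorFlowVariables set_flat/get_flat methods translate between that vector and the network's individual TensorFlow variables. As an illustration (not part of the commit), here is a minimal numpy-only sketch of the round trip those methods perform, using the variable shapes of LinearModel([784, 10]):

import numpy as np

# Shapes of the trainable variables in LinearModel([784, 10]):
# a 784x10 weight matrix and a length-10 bias vector.
shapes = [[784, 10], [10]]
dim = sum(int(np.prod(s)) for s in shapes)  # 7850, what get_flat_size() computes

# fmin_l_bfgs_b hands the remote functions one flat vector theta;
# set_flat slices and reshapes it into one array per variable,
# mirroring unflatten in tfutils.py.
theta = 1e-2 * np.random.normal(size=dim)
arrays = []
i = 0
for shape in shapes:
  size = int(np.prod(shape))
  arrays.append(theta[i:(i + size)].reshape(shape))
  i += size

# get_flat performs the inverse: flatten each variable and concatenate.
assert np.array_equal(np.concatenate([a.flatten() for a in arrays]), theta)

This is also why the driver can now obtain dim from the reusable network via get_flat_size() instead of computing w_size + b_size by hand.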

lib/python/ray/experimental/tfutils.py (+34 -2)
@@ -1,6 +1,18 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+import numpy as np
+
+def unflatten(vector, shapes):
+  i = 0
+  arrays = []
+  for shape in shapes:
+    size = np.prod(shape)
+    array = vector[i:(i + size)].reshape(shape)
+    arrays.append(array)
+    i += size
+  assert len(vector) == i, "Passed weight does not have the correct shape."
+  return arrays
 
 class TensorFlowVariables(object):
   """An object used to extract variables from a loss function.
@@ -35,12 +47,32 @@ def set_session(self, sess):
     """Modifies the current session used by the class."""
     self.sess = sess
 
+  def get_flat_size(self):
+    return sum([np.prod(v.get_shape().as_list()) for v in self.variables])
+
+  def _check_sess(self):
+    """Checks that the session is set, and throws an error if it is not."""
+    assert self.sess is not None, "The session is not set. Set the session either by passing it into the TensorFlowVariables constructor or by calling set_session(sess)."
+
+  def get_flat(self):
+    """Gets the weights and returns them as a flat array."""
+    self._check_sess()
+    return np.concatenate([v.eval(session=self.sess).flatten() for v in self.variables])
+
+  def set_flat(self, new_weights):
+    """Sets the weights to new_weights, converting from a flat array."""
+    self._check_sess()
+    shapes = [v.get_shape().as_list() for v in self.variables]
+    arrays = unflatten(new_weights, shapes)
+    placeholders = [self.assignment_placeholders[v.op.node_def.name] for v in self.variables]
+    self.sess.run(self.assignment_nodes, feed_dict=dict(zip(placeholders, arrays)))
+
   def get_weights(self):
     """Returns the weights of the variables of the loss function in a list."""
-    assert self.sess is not None, "The session is not set. Set the session either by passing it into the TensorFlowVariables constructor or by calling set_session(sess)."
+    self._check_sess()
     return {v.op.node_def.name: v.eval(session=self.sess) for v in self.variables}
 
   def set_weights(self, new_weights):
     """Sets the weights to new_weights."""
-    assert self.sess is not None, "The session is not set. Set the session either by passing it into the TensorFlowVariables constructor or by calling set_session(sess)."
+    self._check_sess()
     self.sess.run(self.assignment_nodes, feed_dict={self.assignment_placeholders[name]: value for (name, value) in new_weights.items()})
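
The new unflatten helper is the inverse of flatten-and-concatenate: it slices a 1-D vector into arrays of the given shapes and asserts that every element was consumed. A small illustrative sketch (assuming the module is importable as ray.experimental.tfutils, as the file path suggests):

import numpy as np
from ray.experimental.tfutils import unflatten

# Six values split into a 2x2 matrix followed by a length-2 vector.
vector = np.arange(6.0)
arrays = unflatten(vector, [[2, 2], [2]])
# arrays[0] == [[0., 1.], [2., 3.]] and arrays[1] == [4., 5.]

# Flattening and concatenating recovers the original vector.
assert np.array_equal(np.concatenate([a.flatten() for a in arrays]), vector)

# A vector that is too long for the given shapes trips the assertion.
try:
  unflatten(np.arange(7.0), [[2, 2], [2]])
except AssertionError:
  print("length mismatch detected")

set_flat applies exactly this to the shapes of the trainable variables before feeding the resulting arrays to the assignment nodes.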

test/tensorflow_test.py (+11 -0)
@@ -5,6 +5,7 @@
 import unittest
 import tensorflow as tf
 import ray
+from numpy.testing import assert_almost_equal
 
 class TensorFlowTest(unittest.TestCase):
 
@@ -47,6 +48,16 @@ def testTensorFlowVariables(self):
     variables2.set_weights(weights2)
     self.assertEqual(weights2, variables2.get_weights())
 
+    flat_weights = variables2.get_flat() + 2.0
+    variables2.set_flat(flat_weights)
+    assert_almost_equal(flat_weights, variables2.get_flat())
+
+    variables3 = ray.experimental.TensorFlowVariables(loss2)
+    self.assertEqual(variables3.sess, None)
+    sess = tf.Session()
+    variables3.set_session(sess)
+    self.assertEqual(variables3.sess, sess)
+
     ray.worker.cleanup()
 
 if __name__ == "__main__":
