Skip to content

Commit 0ac2abe

Browse files
Wapaul1pcmoritz
authored andcommitted
Added helper class for getting tf variables from loss function (ray-project#184)
* Added helper class for getting tf variables from loss function * Updated usage and documentation * Removed try-catches * Added futures * Added documentation * fixes and tests * more tests * install tensorflow in travis
1 parent c13d73b commit 0ac2abe

File tree

6 files changed

+109
-42
lines changed

6 files changed

+109
-42
lines changed

.travis.yml

+1
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ script:
7272

7373
- python test/runtest.py
7474
- python test/array_test.py
75+
- python test/tensorflow_test.py
7576
- python test/failure_test.py
7677
- python test/microbenchmarks.py
7778
- python test/stress_tests.py

.travis/install-dependencies.sh

+4-4
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ fi
2020
if [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "linux" ]]; then
2121
sudo apt-get update
2222
sudo apt-get install -y cmake build-essential autoconf curl libtool python-dev python-numpy python-pip libboost-all-dev unzip
23-
sudo pip install funcsigs colorama psutil redis
23+
sudo pip install funcsigs colorama psutil redis tensorflow
2424
sudo pip install --upgrade git+git://github.com/cloudpipe/cloudpickle.git@0d225a4695f1f65ae1cbb2e0bbc145e10167cce4
2525
elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "linux" ]]; then
2626
sudo apt-get update
@@ -29,7 +29,7 @@ elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "linux" ]]; then
2929
wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh
3030
bash miniconda.sh -b -p $HOME/miniconda
3131
export PATH="$HOME/miniconda/bin:$PATH"
32-
pip install numpy funcsigs colorama psutil redis
32+
pip install numpy funcsigs colorama psutil redis tensorflow
3333
pip install --upgrade git+git://github.com/cloudpipe/cloudpickle.git@0d225a4695f1f65ae1cbb2e0bbc145e10167cce4
3434
elif [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "macosx" ]]; then
3535
# check that brew is installed
@@ -43,7 +43,7 @@ elif [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "macosx" ]]; then
4343
fi
4444
brew install cmake automake autoconf libtool boost
4545
sudo easy_install pip
46-
sudo pip install numpy funcsigs colorama psutil redis --ignore-installed six
46+
sudo pip install numpy funcsigs colorama psutil redis tensorflow --ignore-installed six
4747
sudo pip install --upgrade git+git://github.com/cloudpipe/cloudpickle.git@0d225a4695f1f65ae1cbb2e0bbc145e10167cce4
4848
elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "macosx" ]]; then
4949
# check that brew is installed
@@ -60,7 +60,7 @@ elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "macosx" ]]; then
6060
wget https://repo.continuum.io/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -O miniconda.sh
6161
bash miniconda.sh -b -p $HOME/miniconda
6262
export PATH="$HOME/miniconda/bin:$PATH"
63-
pip install numpy funcsigs colorama psutil redis
63+
pip install numpy funcsigs colorama psutil redis tensorflow
6464
pip install --upgrade git+git://github.com/cloudpipe/cloudpickle.git@0d225a4695f1f65ae1cbb2e0bbc145e10167cce4
6565
else
6666
echo "Unrecognized environment."

doc/using-ray-with-tensorflow.md

+13-38
Original file line numberDiff line numberDiff line change
@@ -37,38 +37,24 @@ init = tf.initialize_all_variables()
3737
sess = tf.Session()
3838
```
3939

40-
To extract the weights and set the weights, we need to write a couple lines of
41-
boilerplate code.
40+
To extract the weights and set the weights, you can call
4241

4342
```python
44-
def get_and_set_weights_methods():
45-
assignment_placeholders = []
46-
assignment_nodes = []
47-
for var in tf.trainable_variables():
48-
assignment_placeholders.append(tf.placeholder(var.value().dtype, var.get_shape().as_list()))
49-
assignment_nodes.append(var.assign(assignment_placeholders[-1]))
50-
# Define a function for getting the network weights.
51-
def get_weights():
52-
return [v.eval(session=sess) for v in tf.trainable_variables()]
53-
# Define a function for setting the network weights.
54-
def set_weights(new_weights):
55-
sess.run(assignment_nodes, feed_dict={p: w for p, w in zip(assignment_placeholders, new_weights)})
56-
# Return the methods.
57-
return get_weights, set_weights
58-
59-
get_weights, set_weights = get_and_set_weights_methods()
43+
variables = ray.experimental.TensorFlowVariables(loss, sess)
6044
```
6145

46+
which gives you methods to set and get the weights as well as collecting all of the variables in the model.
47+
6248
Now we can use these methods to extract the weights, and place them back in the
6349
network as follows.
6450

6551
```python
6652
# First initialize the weights.
6753
sess.run(init)
6854
# Get the weights
69-
weights = get_weights() # Returns a list of numpy arrays
55+
weights = variables.get_weights() # Returns a dictionary of numpy arrays
7056
# Set the weights
71-
set_weights(weights)
57+
variables.set_weights(weights)
7258
```
7359

7460
**Note:** If we were to set the weights using the `assign` method like below,
@@ -117,20 +103,9 @@ def net_vars_initializer():
117103
init = tf.initialize_all_variables()
118104
sess = tf.Session()
119105
# Additional code for setting and getting the weights.
120-
def get_and_set_weights_methods():
121-
assignment_placeholders = []
122-
assignment_nodes = []
123-
for var in tf.trainable_variables():
124-
assignment_placeholders.append(tf.placeholder(var.value().dtype, var.get_shape().as_list()))
125-
assignment_nodes.append(var.assign(assignment_placeholders[-1]))
126-
def get_weights():
127-
return [v.eval(session=sess) for v in tf.trainable_variables()]
128-
def set_weights(new_weights):
129-
sess.run(assignment_nodes, feed_dict={p: w for p, w in zip(assignment_placeholders, new_weights)})
130-
return get_weights, set_weights
131-
get_weights, set_weights = get_and_set_weights_methods()
106+
variables = ray.experimental.TensorFlowVariables(loss, sess)
132107
# Return all of the data needed to use the network.
133-
return get_weights, set_weights, sess, train, loss, x_data, y_data, init
108+
return variables, sess, train, loss, x_data, y_data, init
134109

135110
def net_vars_reinitializer(net_vars):
136111
return net_vars
@@ -142,19 +117,19 @@ ray.reusables.net_vars = ray.Reusable(net_vars_initializer, net_vars_reinitializ
142117
# new weights.
143118
@ray.remote
144119
def step(weights, x, y):
145-
get_weights, set_weights, sess, train, _, x_data, y_data, _ = ray.reusables.net_vars
120+
variables, sess, train, _, x_data, y_data, _ = ray.reusables.net_vars
146121
# Set the weights in the network.
147-
set_weights(weights)
122+
variables.set_weights(weights)
148123
# Do one step of training.
149124
sess.run(train, feed_dict={x_data: x, y_data: y})
150125
# Return the new weights.
151-
return get_weights()
126+
return variables.get_weights()
152127

153-
get_weights, set_weights, sess, _, loss, x_data, y_data, init = ray.reusables.net_vars
128+
variables, sess, _, loss, x_data, y_data, init = ray.reusables.net_vars
154129
# Initialize the network weights.
155130
sess.run(init)
156131
# Get the weights as a list of numpy arrays.
157-
weights = get_weights()
132+
weights = variables.get_weights()
158133

159134
# Define a remote function for generating fake data.
160135
@ray.remote(num_return_vals=2)

lib/python/ray/experimental/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
from __future__ import print_function
44

55
from .utils import copy_directory
6+
from .tfutils import TensorFlowVariables
+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from __future__ import absolute_import
2+
from __future__ import division
3+
from __future__ import print_function
4+
5+
class TensorFlowVariables(object):
6+
"""An object used to extract variables from a loss function, and provide
7+
methods for getting and setting the weights of said variables.
8+
9+
Attributes:
10+
sess (tf.Session): The tensorflow session used to run assignment.
11+
loss: The loss function passed in by the user.
12+
variables (List[tf.Variable]): Extracted variables from the loss.
13+
assignment_placeholders (List[tf.placeholders]): The nodes that weights get passed to.
14+
assignment_nodes (List[tf.Tensor]): The nodes that assign the weights.
15+
"""
16+
def __init__(self, loss, sess):
17+
"""Creates a TensorFlowVariables instance."""
18+
import tensorflow as tf
19+
self.sess = sess
20+
self.loss = loss
21+
variable_names = [op.node_def.name for op in loss.graph.get_operations() if op.node_def.op == "Variable"]
22+
self.variables = [v for v in tf.trainable_variables() if v.op.node_def.name in variable_names]
23+
self.assignment_placeholders = dict()
24+
self.assignment_nodes = []
25+
26+
# Create new placeholders to put in custom weights.
27+
for var in self.variables:
28+
self.assignment_placeholders[var.op.node_def.name] = tf.placeholder(var.value().dtype, var.get_shape().as_list())
29+
self.assignment_nodes.append(var.assign(self.assignment_placeholders[var.op.node_def.name]))
30+
31+
def get_weights(self):
32+
"""Returns the weights of the variables of the loss function in a list."""
33+
return {v.op.node_def.name: v.eval(session=self.sess) for v in self.variables}
34+
35+
def set_weights(self, new_weights):
36+
"""Sets the weights to new_weights."""
37+
self.sess.run(self.assignment_nodes, feed_dict={self.assignment_placeholders[name]: value for (name, value) in new_weights.items()})

test/tensorflow_test.py

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from __future__ import absolute_import
2+
from __future__ import division
3+
from __future__ import print_function
4+
5+
import unittest
6+
import tensorflow as tf
7+
import ray
8+
9+
class TensorFlowTest(unittest.TestCase):
10+
11+
def testTensorFlowVariables(self):
12+
ray.init(start_ray_local=True, num_workers=2)
13+
14+
x_data = tf.placeholder(tf.float32, shape=[100])
15+
y_data = tf.placeholder(tf.float32, shape=[100])
16+
17+
w = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
18+
b = tf.Variable(tf.zeros([1]))
19+
y = w * x_data + b
20+
loss = tf.reduce_mean(tf.square(y - y_data))
21+
22+
sess = tf.Session()
23+
sess.run(tf.global_variables_initializer())
24+
25+
variables = ray.experimental.TensorFlowVariables(loss, sess)
26+
weights = variables.get_weights()
27+
28+
for (name, val) in weights.items():
29+
weights[name] += 1.0
30+
31+
variables.set_weights(weights)
32+
self.assertEqual(weights, variables.get_weights())
33+
34+
w2 = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name="w")
35+
b2 = tf.Variable(tf.zeros([1]), name="b")
36+
y2 = w2 * x_data + b2
37+
loss2 = tf.reduce_mean(tf.square(y2 - y_data))
38+
39+
sess.run(tf.global_variables_initializer())
40+
41+
variables2 = ray.experimental.TensorFlowVariables(loss2, sess)
42+
weights2 = variables2.get_weights()
43+
44+
for (name, val) in weights2.items():
45+
weights2[name] += 2.0
46+
47+
variables2.set_weights(weights2)
48+
self.assertEqual(weights2, variables2.get_weights())
49+
50+
ray.worker.cleanup()
51+
52+
if __name__ == "__main__":
53+
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)