@@ -37,38 +37,24 @@ init = tf.initialize_all_variables()
37
37
sess = tf.Session()
38
38
```
39
39
40
- To extract the weights and set the weights, we need to write a couple lines of
41
- boilerplate code.
40
+ To extract the weights and set the weights, you can call
42
41
43
42
``` python
44
- def get_and_set_weights_methods ():
45
- assignment_placeholders = []
46
- assignment_nodes = []
47
- for var in tf.trainable_variables():
48
- assignment_placeholders.append(tf.placeholder(var.value().dtype, var.get_shape().as_list()))
49
- assignment_nodes.append(var.assign(assignment_placeholders[- 1 ]))
50
- # Define a function for getting the network weights.
51
- def get_weights ():
52
- return [v.eval(session = sess) for v in tf.trainable_variables()]
53
- # Define a function for setting the network weights.
54
- def set_weights (new_weights ):
55
- sess.run(assignment_nodes, feed_dict = {p: w for p, w in zip (assignment_placeholders, new_weights)})
56
- # Return the methods.
57
- return get_weights, set_weights
58
-
59
- get_weights, set_weights = get_and_set_weights_methods()
43
+ variables = ray.experimental.TensorFlowVariables(loss, sess)
60
44
```
61
45
46
+ which gives you methods to set and get the weights as well as collecting all of the variables in the model.
47
+
62
48
Now we can use these methods to extract the weights, and place them back in the
63
49
network as follows.
64
50
65
51
``` python
66
52
# First initialize the weights.
67
53
sess.run(init)
68
54
# Get the weights
69
- weights = get_weights() # Returns a list of numpy arrays
55
+ weights = variables. get_weights() # Returns a dictionary of numpy arrays
70
56
# Set the weights
71
- set_weights(weights)
57
+ variables. set_weights(weights)
72
58
```
73
59
74
60
** Note:** If we were to set the weights using the ` assign ` method like below,
@@ -117,20 +103,9 @@ def net_vars_initializer():
117
103
init = tf.initialize_all_variables()
118
104
sess = tf.Session()
119
105
# Additional code for setting and getting the weights.
120
- def get_and_set_weights_methods ():
121
- assignment_placeholders = []
122
- assignment_nodes = []
123
- for var in tf.trainable_variables():
124
- assignment_placeholders.append(tf.placeholder(var.value().dtype, var.get_shape().as_list()))
125
- assignment_nodes.append(var.assign(assignment_placeholders[- 1 ]))
126
- def get_weights ():
127
- return [v.eval(session = sess) for v in tf.trainable_variables()]
128
- def set_weights (new_weights ):
129
- sess.run(assignment_nodes, feed_dict = {p: w for p, w in zip (assignment_placeholders, new_weights)})
130
- return get_weights, set_weights
131
- get_weights, set_weights = get_and_set_weights_methods()
106
+ variables = ray.experimental.TensorFlowVariables(loss, sess)
132
107
# Return all of the data needed to use the network.
133
- return get_weights, set_weights , sess, train, loss, x_data, y_data, init
108
+ return variables , sess, train, loss, x_data, y_data, init
134
109
135
110
def net_vars_reinitializer (net_vars ):
136
111
return net_vars
@@ -142,19 +117,19 @@ ray.reusables.net_vars = ray.Reusable(net_vars_initializer, net_vars_reinitializ
142
117
# new weights.
143
118
@ray.remote
144
119
def step (weights , x , y ):
145
- get_weights, set_weights , sess, train, _, x_data, y_data, _ = ray.reusables.net_vars
120
+ variables , sess, train, _, x_data, y_data, _ = ray.reusables.net_vars
146
121
# Set the weights in the network.
147
- set_weights(weights)
122
+ variables. set_weights(weights)
148
123
# Do one step of training.
149
124
sess.run(train, feed_dict = {x_data: x, y_data: y})
150
125
# Return the new weights.
151
- return get_weights()
126
+ return variables. get_weights()
152
127
153
- get_weights, set_weights , sess, _, loss, x_data, y_data, init = ray.reusables.net_vars
128
+ variables , sess, _, loss, x_data, y_data, init = ray.reusables.net_vars
154
129
# Initialize the network weights.
155
130
sess.run(init)
156
131
# Get the weights as a list of numpy arrays.
157
- weights = get_weights()
132
+ weights = variables. get_weights()
158
133
159
134
# Define a remote function for generating fake data.
160
135
@ray.remote (num_return_vals = 2 )
0 commit comments