@@ -17,32 +17,39 @@ def make_linear_network(w_name=None, b_name=None):
17
17
b = tf .Variable (tf .zeros ([1 ]), name = b_name )
18
18
y = w * x_data + b
19
19
# Return the loss and weight initializer.
20
- return tf .reduce_mean (tf .square (y - y_data )), tf .global_variables_initializer ()
20
+ return tf .reduce_mean (tf .square (y - y_data )), tf .global_variables_initializer (), x_data , y_data
21
21
22
22
def net_vars_initializer ():
23
- # Random prefix so variable names do not clash if we use nets with
24
- # the same name.
25
- prefix = str (uuid .uuid1 ().hex )
26
- # Use the tensorflow variable_scope to prefix all of the variables
27
- with tf .variable_scope (prefix ):
23
+ # Uses a separate graph for each network.
24
+ with tf .Graph ().as_default ():
28
25
# Create the network.
29
- loss , init = make_linear_network ()
26
+ loss , init , _ , _ = make_linear_network ()
30
27
sess = tf .Session ()
31
28
# Additional code for setting and getting the weights.
32
- variables = ray .experimental .TensorFlowVariables (loss , sess , prefix = True )
29
+ variables = ray .experimental .TensorFlowVariables (loss , sess )
33
30
# Return all of the data needed to use the network.
34
31
return variables , init , sess
35
32
36
33
def net_vars_reinitializer (net_vars ):
37
34
return net_vars
38
35
36
+ def train_vars_initializer ():
37
+ # Almost the same as above, but now returns the placeholders and gradient.
38
+ with tf .Graph ().as_default ():
39
+ loss , init , x_data , y_data = make_linear_network ()
40
+ sess = tf .Session ()
41
+ variables = ray .experimental .TensorFlowVariables (loss , sess )
42
+ grad = tf .gradients (loss , list (variables .variables .values ()))
43
+ return variables , init , sess , grad , [x_data , y_data ]
44
+
45
+
39
46
class TensorFlowTest (unittest .TestCase ):
40
47
41
48
def testTensorFlowVariables (self ):
42
49
ray .init (num_workers = 2 )
43
50
44
51
sess = tf .Session ()
45
- loss , init = make_linear_network ()
52
+ loss , init , _ , _ = make_linear_network ()
46
53
sess .run (init )
47
54
48
55
variables = ray .experimental .TensorFlowVariables (loss , sess )
@@ -54,7 +61,7 @@ def testTensorFlowVariables(self):
54
61
variables .set_weights (weights )
55
62
self .assertEqual (weights , variables .get_weights ())
56
63
57
- loss2 , init2 = make_linear_network ("w" , "b" )
64
+ loss2 , init2 , _ , _ = make_linear_network ("w" , "b" )
58
65
sess .run (init2 )
59
66
60
67
variables2 = ray .experimental .TensorFlowVariables (loss2 , sess )
@@ -148,7 +155,7 @@ def testNetworkDriverWorkerIndependent(self):
148
155
149
156
# Create a network on the driver locally.
150
157
sess1 = tf .Session ()
151
- loss1 , init1 = make_linear_network ()
158
+ loss1 , init1 , _ , _ = make_linear_network ()
152
159
net_vars1 = ray .experimental .TensorFlowVariables (loss1 , sess1 )
153
160
sess1 .run (init1 )
154
161
@@ -170,5 +177,39 @@ def set_and_get_weights(weights):
170
177
171
178
ray .worker .cleanup ()
172
179
180
+ def testVariablesControlDependencies (self ):
181
+ ray .init (num_workers = 1 )
182
+
183
+ # Creates a network and appends a momentum optimizer.
184
+ sess = tf .Session ()
185
+ loss , init , _ , _ = make_linear_network ()
186
+ minimizer = tf .train .MomentumOptimizer (0.9 , 0.9 ).minimize (loss )
187
+ net_vars = ray .experimental .TensorFlowVariables (minimizer , sess )
188
+ sess .run (init )
189
+
190
+ # Tests if all variables are properly retrieved, 2 variables and 2 momentum
191
+ # variables.
192
+ self .assertEqual (len (net_vars .variables .items ()), 4 )
193
+
194
+ ray .worker .cleanup ()
195
+
196
+ def testRemoteTrainingStep (self ):
197
+ ray .init (num_workers = 1 )
198
+
199
+ ray .env .net = ray .EnvironmentVariable (train_vars_initializer , net_vars_reinitializer )
200
+
201
+ @ray .remote
202
+ def training_step (weights ):
203
+ variables , _ , sess , grad , placeholders = ray .env .net
204
+ variables .set_weights (weights )
205
+ return sess .run (grad , feed_dict = dict (zip (placeholders , [[1 ]* 100 ]* 2 )))
206
+
207
+ variables , init , sess , _ , _ = ray .env .net
208
+
209
+ sess .run (init )
210
+ ray .get (training_step .remote (variables .get_weights ()))
211
+
212
+ ray .worker .cleanup ()
213
+
173
214
if __name__ == "__main__" :
174
215
unittest .main (verbosity = 2 )
0 commit comments