3
3
from __future__ import print_function
4
4
5
5
import unittest
6
+ import uuid
6
7
import tensorflow as tf
7
8
import ray
8
9
from numpy .testing import assert_almost_equal
9
10
11
+ def make_linear_network (w_name = None , b_name = None ):
12
+ # Define the inputs.
13
+ x_data = tf .placeholder (tf .float32 , shape = [100 ])
14
+ y_data = tf .placeholder (tf .float32 , shape = [100 ])
15
+ # Define the weights and computation.
16
+ w = tf .Variable (tf .random_uniform ([1 ], - 1.0 , 1.0 ), name = w_name )
17
+ b = tf .Variable (tf .zeros ([1 ]), name = b_name )
18
+ y = w * x_data + b
19
+ # Return the loss and weight initializer.
20
+ return tf .reduce_mean (tf .square (y - y_data )), tf .global_variables_initializer ()
21
+
22
+ def net_vars_initializer ():
23
+ # Random prefix so variable names do not clash if we use nets with
24
+ # the same name.
25
+ prefix = str (uuid .uuid1 ().hex )
26
+ # Use the tensorflow variable_scope to prefix all of the variables
27
+ with tf .variable_scope (prefix ):
28
+ # Create the network.
29
+ loss , init = make_linear_network ()
30
+ sess = tf .Session ()
31
+ # Additional code for setting and getting the weights.
32
+ variables = ray .experimental .TensorFlowVariables (loss , sess , prefix = True )
33
+ # Return all of the data needed to use the network.
34
+ return variables , init , sess
35
+
36
+ def net_vars_reinitializer (net_vars ):
37
+ return net_vars
38
+
10
39
class TensorFlowTest (unittest .TestCase ):
11
40
12
41
def testTensorFlowVariables (self ):
13
42
ray .init (num_workers = 2 )
14
43
15
- x_data = tf .placeholder (tf .float32 , shape = [100 ])
16
- y_data = tf .placeholder (tf .float32 , shape = [100 ])
17
-
18
- w = tf .Variable (tf .random_uniform ([1 ], - 1.0 , 1.0 ))
19
- b = tf .Variable (tf .zeros ([1 ]))
20
- y = w * x_data + b
21
- loss = tf .reduce_mean (tf .square (y - y_data ))
22
-
23
44
sess = tf .Session ()
24
- sess .run (tf .global_variables_initializer ())
45
+ loss , init = make_linear_network ()
46
+ sess .run (init )
25
47
26
48
variables = ray .experimental .TensorFlowVariables (loss , sess )
27
49
weights = variables .get_weights ()
@@ -32,12 +54,8 @@ def testTensorFlowVariables(self):
32
54
variables .set_weights (weights )
33
55
self .assertEqual (weights , variables .get_weights ())
34
56
35
- w2 = tf .Variable (tf .random_uniform ([1 ], - 1.0 , 1.0 ), name = "w" )
36
- b2 = tf .Variable (tf .zeros ([1 ]), name = "b" )
37
- y2 = w2 * x_data + b2
38
- loss2 = tf .reduce_mean (tf .square (y2 - y_data ))
39
-
40
- sess .run (tf .global_variables_initializer ())
57
+ loss2 , init2 = make_linear_network ("w" , "b" )
58
+ sess .run (init2 )
41
59
42
60
variables2 = ray .experimental .TensorFlowVariables (loss2 , sess )
43
61
weights2 = variables2 .get_weights ()
@@ -60,5 +78,114 @@ def testTensorFlowVariables(self):
60
78
61
79
ray .worker .cleanup ()
62
80
81
+ # Test that the variable names for the two different nets are not
82
+ # modified by TensorFlow to be unique (i.e. they should already
83
+ # be unique because of the variable prefix).
84
+ def testVariableNameCollision (self ):
85
+ ray .init (num_workers = 2 )
86
+
87
+ ray .env .net1 = ray .EnvironmentVariable (net_vars_initializer , net_vars_reinitializer )
88
+ ray .env .net2 = ray .EnvironmentVariable (net_vars_initializer , net_vars_reinitializer )
89
+
90
+ net_vars1 , init1 , sess1 = ray .env .net1
91
+ net_vars2 , init2 , sess2 = ray .env .net2
92
+
93
+ # Initialize the networks
94
+ sess1 .run (init1 )
95
+ sess2 .run (init2 )
96
+
97
+ # This is checking that the variable names of the two nets are the same,
98
+ # i.e. that the names in the weight dictionaries are the same
99
+ ray .env .net1 [0 ].set_weights (ray .env .net2 [0 ].get_weights ())
100
+
101
+ ray .worker .cleanup ()
102
+
103
+ # Test that different networks on the same worker are independent and
104
+ # we can get/set their weights without any interaction.
105
+ def testNetworksIndependent (self ):
106
+ # Note we use only one worker to ensure that all of the remote functions run on the same worker.
107
+ ray .init (num_workers = 1 )
108
+
109
+ ray .env .net1 = ray .EnvironmentVariable (net_vars_initializer , net_vars_reinitializer )
110
+ ray .env .net2 = ray .EnvironmentVariable (net_vars_initializer , net_vars_reinitializer )
111
+
112
+ net_vars1 , init1 , sess1 = ray .env .net1
113
+ net_vars2 , init2 , sess2 = ray .env .net2
114
+
115
+ # Initialize the networks
116
+ sess1 .run (init1 )
117
+ sess2 .run (init2 )
118
+
119
+ @ray .remote
120
+ def get_vars1 ():
121
+ return ray .env .net1 [0 ].get_weights ()
122
+
123
+ @ray .remote
124
+ def get_vars2 ():
125
+ return ray .env .net2 [0 ].get_weights ()
126
+
127
+ @ray .remote
128
+ def set_vars1 (weights ):
129
+ ray .env .net1 [0 ].set_weights (weights )
130
+
131
+ @ray .remote
132
+ def set_vars2 (weights ):
133
+ ray .env .net2 [0 ].set_weights (weights )
134
+
135
+ # Get the weights.
136
+ weights1 = net_vars1 .get_weights ()
137
+ weights2 = net_vars2 .get_weights ()
138
+ self .assertNotEqual (weights1 , weights2 )
139
+
140
+ # Swap the weights.
141
+ set_vars2 .remote (weights1 )
142
+ set_vars1 .remote (weights2 )
143
+
144
+ # Get the new weights.
145
+ new_weights1 = ray .get (get_vars1 .remote ())
146
+ new_weights2 = ray .get (get_vars2 .remote ())
147
+ self .assertNotEqual (new_weights1 , new_weights2 )
148
+
149
+ # Check that the weights were swapped.
150
+ self .assertEqual (weights1 , new_weights2 )
151
+ self .assertEqual (weights2 , new_weights1 )
152
+
153
+ ray .worker .cleanup ()
154
+
155
+ def testNetworkDriverWorkerIndependent (self ):
156
+ ray .init (num_workers = 1 )
157
+
158
+ # Create a network on the driver locally.
159
+ sess1 = tf .Session ()
160
+ loss1 , init1 = make_linear_network ()
161
+ net_vars1 = ray .experimental .TensorFlowVariables (loss1 , sess1 )
162
+ sess1 .run (init1 )
163
+
164
+ # Create a network on the driver via an environment variable.
165
+ ray .env .net = ray .EnvironmentVariable (net_vars_initializer , net_vars_reinitializer )
166
+
167
+ net_vars2 , init2 , sess2 = ray .env .net
168
+ sess2 .run (init2 )
169
+
170
+ # Get the weights.
171
+ weights1 = net_vars1 .get_weights ()
172
+ weights2 = net_vars2 .get_weights ()
173
+ self .assertNotEqual (weights1 , weights2 )
174
+
175
+ # Swap the weights.
176
+ net_vars1 .set_weights (weights2 )
177
+ net_vars2 .set_weights (weights1 )
178
+
179
+ # Get the new weights.
180
+ new_weights1 = net_vars1 .get_weights ()
181
+ new_weights2 = net_vars2 .get_weights ()
182
+ self .assertNotEqual (new_weights1 , new_weights2 )
183
+
184
+ # Check that the weights were swapped.
185
+ self .assertEqual (weights1 , new_weights2 )
186
+ self .assertEqual (weights2 , new_weights1 )
187
+
188
+ ray .worker .cleanup ()
189
+
63
190
if __name__ == "__main__" :
64
191
unittest .main (verbosity = 2 )
0 commit comments