@@ -86,7 +86,7 @@ def testVariableNameCollision(self):
86
86
87
87
ray .env .net1 = ray .EnvironmentVariable (net_vars_initializer , net_vars_reinitializer )
88
88
ray .env .net2 = ray .EnvironmentVariable (net_vars_initializer , net_vars_reinitializer )
89
-
89
+
90
90
net_vars1 , init1 , sess1 = ray .env .net1
91
91
net_vars2 , init2 , sess2 = ray .env .net2
92
92
@@ -108,7 +108,7 @@ def testNetworksIndependent(self):
108
108
109
109
ray .env .net1 = ray .EnvironmentVariable (net_vars_initializer , net_vars_reinitializer )
110
110
ray .env .net2 = ray .EnvironmentVariable (net_vars_initializer , net_vars_reinitializer )
111
-
111
+
112
112
net_vars1 , init1 , sess1 = ray .env .net1
113
113
net_vars2 , init2 , sess2 = ray .env .net2
114
114
@@ -117,41 +117,32 @@ def testNetworksIndependent(self):
117
117
sess2 .run (init2 )
118
118
119
119
@ray .remote
120
- def get_vars1 ():
121
- return ray .env .net1 [0 ].get_weights ()
122
-
123
- @ray .remote
124
- def get_vars2 ():
125
- return ray .env .net2 [0 ].get_weights ()
126
-
127
- @ray .remote
128
- def set_vars1 (weights ):
129
- ray .env .net1 [0 ].set_weights (weights )
130
-
131
- @ray .remote
132
- def set_vars2 (weights ):
133
- ray .env .net2 [0 ].set_weights (weights )
134
-
135
- # Get the weights.
120
+ def set_and_get_weights (weights1 , weights2 ):
121
+ ray .env .net1 [0 ].set_weights (weights1 )
122
+ ray .env .net2 [0 ].set_weights (weights2 )
123
+ return ray .env .net1 [0 ].get_weights (), ray .env .net2 [0 ].get_weights ()
124
+
125
+ # Make sure the two networks have different weights. TODO(rkn): Note that
126
+ # equality comparisons of numpy arrays normally does not work. This only
127
+ # works because at the moment they have size 1.
136
128
weights1 = net_vars1 .get_weights ()
137
129
weights2 = net_vars2 .get_weights ()
138
130
self .assertNotEqual (weights1 , weights2 )
139
131
140
- # Swap the weights.
141
- set_vars2 .remote (weights1 )
142
- set_vars1 .remote (weights2 )
143
-
144
- # Get the new weights.
145
- new_weights1 = ray .get (get_vars1 .remote ())
146
- new_weights2 = ray .get (get_vars2 .remote ())
147
- self .assertNotEqual (new_weights1 , new_weights2 )
132
+ # Set the weights and get the weights, and make sure they are unchanged.
133
+ new_weights1 , new_weights2 = ray .get (set_and_get_weights .remote (weights1 , weights2 ))
134
+ self .assertEqual (weights1 , new_weights1 )
135
+ self .assertEqual (weights2 , new_weights2 )
148
136
149
- # Check that the weights were swapped.
150
- self .assertEqual (weights1 , new_weights2 )
151
- self .assertEqual (weights2 , new_weights1 )
137
+ # Swap the weights.
138
+ new_weights2 , new_weights1 = ray .get (set_and_get_weights .remote (weights2 , weights1 ))
139
+ self .assertEqual (weights1 , new_weights1 )
140
+ self .assertEqual (weights2 , new_weights2 )
152
141
153
142
ray .worker .cleanup ()
154
143
144
+ # This test creates an additional network on the driver so that the tensorflow
145
+ # variables on the driver and the worker differ.
155
146
def testNetworkDriverWorkerIndependent (self ):
156
147
ray .init (num_workers = 1 )
157
148
@@ -167,23 +158,15 @@ def testNetworkDriverWorkerIndependent(self):
167
158
net_vars2 , init2 , sess2 = ray .env .net
168
159
sess2 .run (init2 )
169
160
170
- # Get the weights.
171
- weights1 = net_vars1 .get_weights ()
172
161
weights2 = net_vars2 .get_weights ()
173
- self .assertNotEqual (weights1 , weights2 )
174
-
175
- # Swap the weights.
176
- net_vars1 .set_weights (weights2 )
177
- net_vars2 .set_weights (weights1 )
178
162
179
- # Get the new weights.
180
- new_weights1 = net_vars1 . get_weights ()
181
- new_weights2 = net_vars2 . get_weights ( )
182
- self . assertNotEqual ( new_weights1 , new_weights2 )
163
+ @ ray . remote
164
+ def set_and_get_weights ( weights ):
165
+ ray . env . net [ 0 ]. set_weights ( weights )
166
+ return ray . env . net [ 0 ]. get_weights ( )
183
167
184
- # Check that the weights were swapped.
185
- self .assertEqual (weights1 , new_weights2 )
186
- self .assertEqual (weights2 , new_weights1 )
168
+ new_weights2 = ray .get (set_and_get_weights .remote (net_vars2 .get_weights ()))
169
+ self .assertEqual (weights2 , new_weights2 )
187
170
188
171
ray .worker .cleanup ()
189
172
0 commit comments