Skip to content

Commit 7151ed5

Browse files
robertnishiharapcmoritz
authored andcommitted
Fix bug in tensorflow tests. (ray-project#218)
* Fix bug in tensorflow tests. * Address comment.
1 parent 9bb8162 commit 7151ed5

File tree

1 file changed

+26
-43
lines changed

1 file changed

+26
-43
lines changed

test/tensorflow_test.py

+26-43
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def testVariableNameCollision(self):
8686

8787
ray.env.net1 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer)
8888
ray.env.net2 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer)
89-
89+
9090
net_vars1, init1, sess1 = ray.env.net1
9191
net_vars2, init2, sess2 = ray.env.net2
9292

@@ -108,7 +108,7 @@ def testNetworksIndependent(self):
108108

109109
ray.env.net1 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer)
110110
ray.env.net2 = ray.EnvironmentVariable(net_vars_initializer, net_vars_reinitializer)
111-
111+
112112
net_vars1, init1, sess1 = ray.env.net1
113113
net_vars2, init2, sess2 = ray.env.net2
114114

@@ -117,41 +117,32 @@ def testNetworksIndependent(self):
117117
sess2.run(init2)
118118

119119
@ray.remote
120-
def get_vars1():
121-
return ray.env.net1[0].get_weights()
122-
123-
@ray.remote
124-
def get_vars2():
125-
return ray.env.net2[0].get_weights()
126-
127-
@ray.remote
128-
def set_vars1(weights):
129-
ray.env.net1[0].set_weights(weights)
130-
131-
@ray.remote
132-
def set_vars2(weights):
133-
ray.env.net2[0].set_weights(weights)
134-
135-
# Get the weights.
120+
def set_and_get_weights(weights1, weights2):
121+
ray.env.net1[0].set_weights(weights1)
122+
ray.env.net2[0].set_weights(weights2)
123+
return ray.env.net1[0].get_weights(), ray.env.net2[0].get_weights()
124+
125+
# Make sure the two networks have different weights. TODO(rkn): Note that
126+
# equality comparisons of numpy arrays normally does not work. This only
127+
# works because at the moment they have size 1.
136128
weights1 = net_vars1.get_weights()
137129
weights2 = net_vars2.get_weights()
138130
self.assertNotEqual(weights1, weights2)
139131

140-
# Swap the weights.
141-
set_vars2.remote(weights1)
142-
set_vars1.remote(weights2)
143-
144-
# Get the new weights.
145-
new_weights1 = ray.get(get_vars1.remote())
146-
new_weights2 = ray.get(get_vars2.remote())
147-
self.assertNotEqual(new_weights1, new_weights2)
132+
# Set the weights and get the weights, and make sure they are unchanged.
133+
new_weights1, new_weights2 = ray.get(set_and_get_weights.remote(weights1, weights2))
134+
self.assertEqual(weights1, new_weights1)
135+
self.assertEqual(weights2, new_weights2)
148136

149-
# Check that the weights were swapped.
150-
self.assertEqual(weights1, new_weights2)
151-
self.assertEqual(weights2, new_weights1)
137+
# Swap the weights.
138+
new_weights2, new_weights1 = ray.get(set_and_get_weights.remote(weights2, weights1))
139+
self.assertEqual(weights1, new_weights1)
140+
self.assertEqual(weights2, new_weights2)
152141

153142
ray.worker.cleanup()
154143

144+
# This test creates an additional network on the driver so that the tensorflow
145+
# variables on the driver and the worker differ.
155146
def testNetworkDriverWorkerIndependent(self):
156147
ray.init(num_workers=1)
157148

@@ -167,23 +158,15 @@ def testNetworkDriverWorkerIndependent(self):
167158
net_vars2, init2, sess2 = ray.env.net
168159
sess2.run(init2)
169160

170-
# Get the weights.
171-
weights1 = net_vars1.get_weights()
172161
weights2 = net_vars2.get_weights()
173-
self.assertNotEqual(weights1, weights2)
174-
175-
# Swap the weights.
176-
net_vars1.set_weights(weights2)
177-
net_vars2.set_weights(weights1)
178162

179-
# Get the new weights.
180-
new_weights1 = net_vars1.get_weights()
181-
new_weights2 = net_vars2.get_weights()
182-
self.assertNotEqual(new_weights1, new_weights2)
163+
@ray.remote
164+
def set_and_get_weights(weights):
165+
ray.env.net[0].set_weights(weights)
166+
return ray.env.net[0].get_weights()
183167

184-
# Check that the weights were swapped.
185-
self.assertEqual(weights1, new_weights2)
186-
self.assertEqual(weights2, new_weights1)
168+
new_weights2 = ray.get(set_and_get_weights.remote(net_vars2.get_weights()))
169+
self.assertEqual(weights2, new_weights2)
187170

188171
ray.worker.cleanup()
189172

0 commit comments

Comments
 (0)