Fix mirrored_strategy_test failure after constant tensors are always on CPU

We use constant tensors a lot in our tests, and NCCL will complain when all of its
inputs end up on the same device, which for constants is now always the CPU.

PiperOrigin-RevId: 302121666
Change-Id: I6872ac2d63fbdfdac253f6e7c8b8602a8cd2fe7e
crccw authored and tensorflower-gardener committed Mar 20, 2020
1 parent c7fb55c commit 3caa238
Showing 4 changed files with 21 additions and 24 deletions.
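
The common thread across all four files is routing small constants through array_ops.identity: a bare constant is now always materialized on the CPU, while an identity op goes through normal device placement and therefore lands on the replica's device, so NCCL no longer sees every all-reduce input on the same device. Below is a minimal TF 2.x sketch of the pattern (illustrative only, not code from this commit; it assumes a MirroredStrategy over two or more GPUs so an NCCL all-reduce is actually exercised):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # assumes >= 2 GPUs for an NCCL all-reduce

def replica_fn():
  # tf.constant(4.) would be pinned to CPU on every replica; tf.identity is a
  # normal op, so it is placed on the replica device active inside strategy.run().
  return tf.identity(tf.constant(4.))

per_replica = strategy.run(replica_fn)
reduced = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
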
2 changes: 1 addition & 1 deletion tensorflow/python/distribute/BUILD
@@ -1206,7 +1206,7 @@ cuda_py_test(
srcs = ["mirrored_strategy_test.py"],
shard_count = 5,
tags = [
# "multi_and_single_gpu", # b/151862653
"multi_and_single_gpu",
"no_windows_gpu", # TODO(b/130551176)
],
deps = [
15 changes: 8 additions & 7 deletions tensorflow/python/distribute/distribute_lib.py
@@ -1035,15 +1035,16 @@ def mean_reduce_helper(v, axis=axis):
if dim is not None:
# By returning a python value in the static shape case, we can
# maybe get a fast path for reducing the denominator.
- return numer, array_ops.constant(dim, dtype=dtypes.int64)
+ # TODO(b/151871486): Remove array_ops.identity after we fallback to
+ # simple reduction if inputs are all on CPU.
+ return numer, array_ops.identity(
+     constant_op.constant(dim, dtype=dtypes.int64))
elif axis < 0:
axis = axis + array_ops.rank(v)
if v.shape.rank == 1:
# TODO(b/139422050): Currently tf.shape is not supported in TPU dynamic
# padder, use tf.size instead to workaround if the rank is 1.
denom = array_ops.size(v, out_type=dtypes.int64)
else:
- denom = array_ops.shape_v2(v, out_type=dtypes.int64)[axis]
+ # TODO(b/151871486): Remove array_ops.identity after we fallback to simple
+ # reduction if inputs are all on CPU.
+ denom = array_ops.identity(
+     array_ops.shape_v2(v, out_type=dtypes.int64)[axis])
# TODO(josh11b): Should we cast denom to v.dtype here instead of after the
# reduce is complete?
return numer, denom
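
For context, mean_reduce_helper feeds a MEAN-over-axis reduction: each replica returns a summed numerator and an element-count denominator, and both are all-reduced before dividing. The denominator is exactly the kind of tiny constant or shape value that would otherwise be pinned to the CPU, hence the identity wrapper. A rough sketch of the overall computation (names and structure are illustrative, not the library's internals):

import tensorflow as tf

def mean_over_axis(strategy, value, axis=0):
  def local(v):
    numer = tf.reduce_sum(v, axis=axis)
    # In the static-shape case a plain Python int would do; the identity keeps
    # the dynamic-shape denominator off the CPU-pinned constant path.
    denom = tf.identity(tf.shape(v, out_type=tf.int64)[axis])
    return numer, tf.cast(denom, numer.dtype)

  numer, denom = strategy.run(local, args=(value,))
  total = strategy.reduce(tf.distribute.ReduceOp.SUM, numer, axis=None)
  count = strategy.reduce(tf.distribute.ReduceOp.SUM, denom, axis=None)
  return total / count
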
9 changes: 3 additions & 6 deletions tensorflow/python/distribute/mirrored_strategy_test.py
@@ -127,7 +127,8 @@ def testReduceAxisToCpu(self, distribution):
def replica_squared_fn(dtype=dtype):
# Lists with different lengths on different replicas.
replica_id = _replica_id_as_int()
- return math_ops.cast([replica_id] * (replica_id + 1), dtype)
+ return array_ops.identity(
+     math_ops.cast([replica_id] * (replica_id + 1), dtype))

self.reduce_axis_helper(distribution, replica_squared_fn)

@@ -1406,11 +1407,7 @@ def _replica_id():
replica_id = ds_context.get_replica_context().replica_id_in_sync_group
if not isinstance(replica_id, ops.Tensor):
replica_id = constant_op.constant(replica_id)
- # TODO(b/149852830): Workaround for small Tensor caching (which is only on
- # CPU) to ensure the value is on the correct device.
- replica_id = math_ops.cast(replica_id, dtypes.float32)
- replica_id = math_ops.cast(replica_id, dtypes.int32)
- return replica_id
+ return array_ops.identity(replica_id)


def _replica_id_as_int():
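
The test helpers above follow the same idea: replica_squared_fn builds a vector whose length depends on the replica id (so the axis-mean has a different denominator on each replica), and _replica_id no longer needs the float32/int32 round trip that used to dodge the CPU-only small-tensor cache; a single identity re-places the id on the replica device. A simplified sketch of the helper (not the test file's exact code), assuming it is called inside strategy.run() where a replica context is active:

import tensorflow as tf

def replica_id_tensor():
  # Only valid inside strategy.run(); outside a replica context this returns None.
  ctx = tf.distribute.get_replica_context()
  rid = ctx.replica_id_in_sync_group
  if not isinstance(rid, tf.Tensor):
    rid = tf.constant(rid)
  # Old workaround (b/149852830): cast to float32 and back to int32 to escape the
  # small-constant cache; tf.identity now achieves the same device placement.
  return tf.identity(rid)
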
19 changes: 9 additions & 10 deletions tensorflow/python/distribute/strategy_test_lib.py
@@ -34,7 +34,6 @@
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
- from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -120,7 +119,7 @@ def _test_minimize_loss_eager(self, d):
l = core.Dense(1, use_bias=False)

def loss(x):
- y = array_ops.reshape(l(x), []) - constant_op.constant(1.)
+ y = array_ops.reshape(l(x), []) - array_ops.identity(1.)
return y * y
# TODO(isaprykin): Extract implicit_grad+get_filtered_grad_fn into a
# common `implicit_grad` function and put it in DistributionStrategy.
@@ -130,7 +129,7 @@ def loss(x):
def update(v, g):
return v.assign_sub(0.2 * g)

- one = constant_op.constant([[1.]])
+ one = array_ops.identity([[1.]])

def step():
"""Perform one optimization step."""
@@ -177,15 +176,15 @@ def _test_minimize_loss_graph(self,
l = core.Dense(1, use_bias=False)

def loss(x):
- y = array_ops.reshape(l(x), []) - constant_op.constant(1.)
+ y = array_ops.reshape(l(x), []) - array_ops.identity(1.)
return y * y

grad_fn = backprop.implicit_grad(loss)

def update(v, g):
return v.assign_sub(learning_rate * g)

- one = constant_op.constant([[1.]])
+ one = array_ops.identity([[1.]])

def step():
"""Perform one optimization step."""
@@ -453,7 +452,7 @@ class OneDeviceDistributionTestBase(test.TestCase):
"""Some tests that should work with any one-device DistributionStrategy."""

def _test_run(self, strategy):
- out1 = strategy.run(lambda: constant_op.constant(4.))
+ out1 = strategy.run(lambda: array_ops.identity(4.))
self.assertAllEqual([4.], self.evaluate(strategy.unwrap(out1)))

out2 = strategy.run(lambda x: {"a": x * 2, "b": x * x}, args=(out1,))
@@ -506,7 +505,7 @@ def _test_collective_comms_gradients(self, strategy, comm_fn, inputs,
self.skipTest("`tf.gradients` is not supported with eager execution.")

def step(c):
- x = constant_op.constant(42.)
+ x = array_ops.identity(42.)
y = comm_fn(x) * c
return gradients_impl.gradients(y, [x])[0]

@@ -524,7 +523,7 @@ def _test_collective_comms_gradient_tape(self, strategy, comm_fn, inputs,
expected_grads):

def step(c):
- x = constant_op.constant(42.)
+ x = array_ops.identity(42.)
with backprop.GradientTape() as tape:
tape.watch(x)
y = comm_fn(x) * c
@@ -634,7 +633,7 @@ def _test_collective_comms_gradients(self, strategy, comm_fn, inputs,
self.skipTest("`tf.gradients` is not supported with eager execution.")

def step(c):
- x = constant_op.constant(42.)
+ x = array_ops.identity(42.)
y = comm_fn(x) * c
return gradients_impl.gradients(y, [x])[0]

@@ -652,7 +651,7 @@ def _test_collective_comms_gradient_tape(self, strategy, comm_fn, inputs,
expected_grads):

def step(c):
- x = constant_op.constant(42.)
+ x = array_ops.identity(42.)
with backprop.GradientTape() as tape:
tape.watch(x)
y = comm_fn(x) * c
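
In strategy_test_lib.py the change is mechanical: every test input previously built with constant_op.constant (the loss target, the [[1.]] feed, and the 42. probed in the collective-gradient tests) is now built with array_ops.identity so it follows the active device scope. A small self-contained sketch of the gradient-tape pattern these tests exercise (comm_fn stands in for a collective op; a plain pass-through is used here so the example runs on a single device):

import tensorflow as tf

def grad_through_comm(comm_fn, c):
  x = tf.identity(42.)  # was tf.constant(42.); identity obeys the device scope
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = comm_fn(x) * c
  return tape.gradient(y, x)

# With an identity "communication" the gradient of y = x * c w.r.t. x is just c:
print(grad_through_comm(lambda t: t, tf.constant(3.)))  # tf.Tensor(3.0, ...)
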
