Skip to content

Commit

Permalink
Merge pull request tensorflow#6031 from tensorflow/vincentvanhoucke-p…
Browse files Browse the repository at this point in the history
…atch-3

Use gather on floats instead of int32 to keep the kernel on GPU when possible.
  • Loading branch information
vincentvanhoucke authored Dec 2, 2016
2 parents 4172ca2 + a5b8d6c commit 55973b6
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tensorflow/python/ops/nn_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,15 +463,15 @@ def sufficient_statistics(x, axes, shift=None, keep_dims=False, name=None):
with ops.name_scope(name, "sufficient_statistics", [x, shift]):
x = ops.convert_to_tensor(x, name="x")
x_shape = x.get_shape()
if x_shape.is_fully_defined():
if all(x_shape[d].value is not None for d in axes):
counts = 1
for d in axes:
counts *= x_shape[d].value
counts = constant_op.constant(counts, dtype=x.dtype)
else: # shape needs to be inferred at runtime.
x_dims = array_ops.gather(array_ops.shape(x), axes)
counts = math_ops.cast(
math_ops.reduce_prod(x_dims), x.dtype, name="count")
x_dims = array_ops.gather(
math_ops.cast(array_ops.shape(x), x.dtype), axes)
counts = math_ops.reduce_prod(x_dims, name="count")
if shift is not None:
shift = ops.convert_to_tensor(shift, name="shift")
m_ss = math_ops.sub(x, shift)
Expand Down

0 comments on commit 55973b6

Please sign in to comment.