diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index afacef7acd9eaa..3cdfe0b227fc34 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -463,15 +463,15 @@ def sufficient_statistics(x, axes, shift=None, keep_dims=False, name=None): with ops.name_scope(name, "sufficient_statistics", [x, shift]): x = ops.convert_to_tensor(x, name="x") x_shape = x.get_shape() - if x_shape.is_fully_defined(): + if all(x_shape[d].value is not None for d in axes): counts = 1 for d in axes: counts *= x_shape[d].value counts = constant_op.constant(counts, dtype=x.dtype) else: # shape needs to be inferred at runtime. - x_dims = array_ops.gather(array_ops.shape(x), axes) - counts = math_ops.cast( - math_ops.reduce_prod(x_dims), x.dtype, name="count") + x_dims = array_ops.gather( + math_ops.cast(array_ops.shape(x), x.dtype), axes) + counts = math_ops.reduce_prod(x_dims, name="count") if shift is not None: shift = ops.convert_to_tensor(shift, name="shift") m_ss = math_ops.sub(x, shift)