Skip to content

Commit

Permalink
Negative axis support for argmin and argmax (tensorflow#6042)
Browse files Browse the repository at this point in the history
  • Loading branch information
AnishShah authored and martinwicke committed Dec 8, 2016
1 parent ed9ae03 commit 3d64e20
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 19 deletions.
16 changes: 9 additions & 7 deletions tensorflow/core/kernels/argmax_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,27 +57,29 @@ class ArgOp : public OpKernel {
const int32 dim = internal::SubtleMustCopy(dimension.scalar<int32>()());
const int input_dims = input.dims();

OP_REQUIRES(context, dim >= 0, errors::InvalidArgument("dim must be >= 0"));
OP_REQUIRES(context, dim < input_dims,
errors::InvalidArgument("Minimum tensor rank: ", dim + 1,
" but got: ", input_dims));
int axis = dim < 0 ? dim + input_dims : dim;

OP_REQUIRES(context, axis >= 0 && axis < input_dims,
errors::InvalidArgument(
"Expected dimension in the range [", -input_dims, ", ",
input_dims, "), but got ", dim));
OP_REQUIRES(
context, input.dim_size(dim) > 0,
context, input.dim_size(axis) > 0,
errors::InvalidArgument("Reduction axis ", dim, " is empty in shape ",
input.shape().DebugString()));

TensorShape output_shape;
const TensorShape& input_shape = input.shape();
for (int d = 0; d < input_dims - 1; ++d) {
output_shape.AddDim(input_shape.dim_size((d < dim) ? d : d + 1));
output_shape.AddDim(input_shape.dim_size((d < axis) ? d : d + 1));
}
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));

#define HANDLE_DIM(NDIM) \
case NDIM: \
ArgFunctor::Reduce##NDIM(context->eigen_device<Device>(), \
input.tensor<T, NDIM>(), dim, \
input.tensor<T, NDIM>(), axis, \
output->tensor<int64, NDIM - 1>()); \
break;

Expand Down
11 changes: 6 additions & 5 deletions tensorflow/core/ops/math_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1268,17 +1268,18 @@ Status ArgOpShape(shape_inference::InferenceContext* c) {
dimension_val = dim_t->scalar<int64>()();
}

if (dimension_val < 0 || dimension_val >= input_rank) {
int64 axis = dimension_val < 0 ? dimension_val + input_rank : dimension_val;
if (axis < 0 || axis >= input_rank) {
return errors::InvalidArgument("Dimension (", dimension_val,
") must be in the range [0, ", input_rank,
"), where ", input_rank, " is the ",
"number of dimensions in the input.");
") must be in the range [", -input_rank,
", ", input_rank, "), where ", input_rank,
" is the number of dimensions in the input.");
}

// Return the input shape without the dimension being reduced.
std::vector<DimensionHandle> dims;
for (int i = 0; i < input_rank; ++i) {
if (dimension_val != i) {
if (axis != i) {
dims.emplace_back(c->Dim(input_shape, i));
}
}
Expand Down
8 changes: 6 additions & 2 deletions tensorflow/core/ops/math_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -450,11 +450,15 @@ TEST(MathOpsTest, ArgOps_ShapeFn) {
// Dimension value out of bounds
dimension = test::AsScalar(10);
op.input_tensors[1] = &dimension;
INFER_ERROR("must be in the range [0, 3)", op, "[2,3,4];[]");
INFER_ERROR("must be in the range [-3, 3)", op, "[2,3,4];[]");

dimension = test::AsScalar(-10);
op.input_tensors[1] = &dimension;
INFER_ERROR("must be in the range [0, 3)", op, "[2,3,4];[]");
INFER_ERROR("must be in the range [-3, 3)", op, "[2,3,4];[]");

dimension = test::AsScalar(-1);
op.input_tensors[1] = &dimension;
INFER_OK(op, "[2,3,4];[]", "[d0_0,d0_1]");
}

TEST(MathOpsTest, Betainc_ShapeFn) {
Expand Down
10 changes: 6 additions & 4 deletions tensorflow/g3doc/api_docs/python/math_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -3320,8 +3320,9 @@ Returns the index with the smallest value across axes of a tensor.

* <b>`input`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`, `complex128`, `qint8`, `quint8`, `qint32`, `half`.
* <b>`axis`</b>: A `Tensor`. Must be one of the following types: `int32`, `int64`.
int32, 0 <= axis < rank(input). Describes which axis
of the input Tensor to reduce across. For vectors, use axis = 0.
int32, -rank(input) <= axis < rank(input). Describes which axis
of the input Tensor to reduce across. Negative axis are interpreted as
counting from the end of the array. For vectors, use axis = 0.
* <b>`name`</b>: A name for the operation (optional).

##### Returns:
Expand All @@ -3340,8 +3341,9 @@ Returns the index with the largest value across axes of a tensor.

* <b>`input`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`, `complex128`, `qint8`, `quint8`, `qint32`, `half`.
* <b>`axis`</b>: A `Tensor`. Must be one of the following types: `int32`, `int64`.
int32, 0 <= axis < rank(input). Describes which axis
of the input Tensor to reduce across. For vectors, use axis = 0.
int32, -rank(input) <= axis < rank(input). Describes which axis
of the input Tensor to reduce across. Negative axis are interpreted as
counting from the end of the array. For vectors, use axis = 0.
* <b>`name`</b>: A name for the operation (optional).

##### Returns:
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/python/kernel_tests/argmax_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _testDim(self, dtype):
x = np.asarray(100*np.random.randn(3, 2, 4, 5, 6), dtype=dtype)

# Check that argmin and argmax match numpy along all dimensions
for dim in range(5):
for dim in range(-5, 5):
self._testBothArg(tf.argmax, x, dim, x.argmax(dim))
self._testBothArg(tf.argmin, x, dim, x.argmin(dim))

Expand Down

0 comments on commit 3d64e20

Please sign in to comment.