Skip to content

Commit 8633b22

Browse files
TF-XLA Bridge ops, tests, and registrations for ReverseOp and ReverseV2Op.
Change: 149488853
1 parent 95df7a3 commit 8633b22

File tree

7 files changed

+221
-0
lines changed

7 files changed

+221
-0
lines changed

tensorflow/compiler/tests/BUILD

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,17 @@ tf_xla_py_test(
224224
],
225225
)
226226

227+
tf_xla_py_test(
228+
name = "reverse_ops_test",
229+
size = "small",
230+
srcs = ["reverse_ops_test.py"],
231+
deps = [
232+
":xla_test",
233+
"//tensorflow/python:array_ops",
234+
"//tensorflow/python:framework_for_generated_wrappers",
235+
],
236+
)
237+
227238
tf_xla_py_test(
228239
name = "ternary_ops_test",
229240
size = "small",

tensorflow/compiler/tests/randomized_tests.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1747,6 +1747,30 @@ TEST_F(OpTest, Reshape) {
17471747
});
17481748
}
17491749

1750+
TEST_F(OpTest, Reverse) {
1751+
Repeatedly([this]() {
1752+
std::vector<int64> dims = RandomDims(1);
1753+
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
1754+
int64 rank = dims.size();
1755+
ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Reverse")
1756+
.Input(RandomTensor(type, dims))
1757+
.Input(RandomTensor(DT_BOOL, {rank}))
1758+
.Attr("T", DT_FLOAT));
1759+
});
1760+
}
1761+
1762+
TEST_F(OpTest, ReverseV2) {
1763+
Repeatedly([this]() {
1764+
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
1765+
Tensor data = RandomTensor(type);
1766+
Tensor indices = RandomReductionIndices(data.dims());
1767+
ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReverseV2")
1768+
.Input(data)
1769+
.Input(indices)
1770+
.Attr("T", DT_FLOAT));
1771+
});
1772+
}
1773+
17501774
TEST_F(OpTest, Rsqrt) {
17511775
Repeatedly([this]() {
17521776
ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Rsqrt")
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Functional tests for XLA Reverse Ops."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
import itertools
22+
import numpy as np
23+
24+
from tensorflow.compiler.tests.xla_test import XLATestCase
25+
from tensorflow.python.framework import constant_op
26+
from tensorflow.python.framework import dtypes
27+
from tensorflow.python.ops import array_ops
28+
from tensorflow.python.platform import googletest
29+
30+
31+
class ReverseOpsTest(XLATestCase):
32+
33+
def testReverseOneDim(self):
34+
shape = (7, 5, 9, 11)
35+
for revdim in range(len(shape)):
36+
self._AssertReverseEqual([revdim], shape)
37+
38+
def testReverseMoreThanOneDim(self):
39+
shape = (7, 5, 9, 11)
40+
for revdims in itertools.chain.from_iterable(
41+
itertools.combinations(range(len(shape)), k)
42+
for k in range(2, len(shape)+1)):
43+
self._AssertReverseEqual(revdims, shape)
44+
45+
def _AssertReverseEqual(self, revdims, shape):
46+
np.random.seed(120)
47+
pval = np.random.randint(0, 100, size=shape).astype(float)
48+
with self.test_session():
49+
with self.test_scope():
50+
p = array_ops.placeholder(dtypes.int32, shape=shape)
51+
axis = constant_op.constant(
52+
np.array(revdims, dtype=np.int32),
53+
shape=(len(revdims),), dtype=dtypes.int32)
54+
rval = array_ops.reverse(p, axis).eval({p: pval})
55+
56+
slices = [
57+
slice(-1, None, -1) if d in revdims else slice(None)
58+
for d in range(len(shape))]
59+
self.assertEqual(
60+
pval[slices].flatten().tolist(),
61+
rval.flatten().tolist())
62+
63+
64+
if __name__ == '__main__':
65+
googletest.main()

tensorflow/compiler/tf2xla/const_analysis.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ Status BackwardsConstAnalysis(const Graph& g,
6565
{"Range", "limit"},
6666
{"Range", "delta"},
6767
{"Reshape", "shape"},
68+
{"Reverse", "dims"},
69+
{"ReverseV2", "axis"},
6870
{"Slice", "begin"},
6971
{"Slice", "size"},
7072
{"Split", "split_dim"},

tensorflow/compiler/tf2xla/kernels/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ tf_kernel_library(
4343
"relu_op.cc",
4444
"reshape_op.cc",
4545
"retval_op.cc",
46+
"reverse_op.cc",
4647
"select_op.cc",
4748
"sequence_ops.cc",
4849
"shape_op.cc",
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
// XLA-specific reverse Op.
17+
18+
#include "tensorflow/compiler/tf2xla/type_util.h"
19+
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
20+
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
21+
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
22+
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
23+
#include "tensorflow/compiler/xla/literal_util.h"
24+
#include "tensorflow/core/framework/op_kernel.h"
25+
#include "tensorflow/core/framework/register_types.h"
26+
#include "tensorflow/core/framework/tensor.h"
27+
28+
namespace tensorflow {
29+
namespace {
30+
31+
class ReverseOp : public XlaOpKernel {
32+
public:
33+
explicit ReverseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
34+
35+
void Compile(XlaOpKernelContext* ctx) override {
36+
// r = tf.reverse(x, revdims)
37+
const TensorShape x_shape = ctx->InputShape(0);
38+
const TensorShape revd_shape = ctx->InputShape(1);
39+
// Validate input sizes.
40+
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(revd_shape),
41+
errors::InvalidArgument("axes must be a vector, not shape ",
42+
revd_shape.DebugString()));
43+
OP_REQUIRES(ctx, revd_shape.num_elements() == x_shape.dims(),
44+
errors::InvalidArgument("axes ", revd_shape.DebugString(),
45+
" must have same number of elements as"
46+
" than input tensor has dimensions ",
47+
x_shape.DebugString(), "."));
48+
if (revd_shape.num_elements() == 0) {
49+
ctx->SetOutput(0, ctx->Input(0));
50+
return;
51+
}
52+
// ComputationBuilder::Rev() requires concrete values for dimensions arg.
53+
xla::Literal lax;
54+
OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {x_shape.dims()}, &lax));
55+
std::vector<bool> revdims(x_shape.dims());
56+
std::copy(lax.preds().begin(), lax.preds().end(), revdims.begin());
57+
std::vector<int64> dimensions;
58+
59+
for (int d = 0; d < x_shape.dims(); ++d) {
60+
if (revdims[d]) {
61+
dimensions.push_back(d);
62+
}
63+
}
64+
65+
ctx->SetOutput(0, ctx->builder()->Rev(ctx->Input(0), dimensions));
66+
}
67+
};
68+
69+
REGISTER_XLA_OP("Reverse", ReverseOp);
70+
71+
class ReverseV2Op : public XlaOpKernel {
72+
public:
73+
explicit ReverseV2Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
74+
75+
void Compile(XlaOpKernelContext* ctx) override {
76+
// r = tf.reverse(x, axes)
77+
const TensorShape x_shape = ctx->InputShape(0);
78+
const TensorShape axes_shape = ctx->InputShape(1);
79+
// Validate input sizes.
80+
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(axes_shape),
81+
errors::InvalidArgument("axes must be a vector, not shape ",
82+
axes_shape.DebugString()));
83+
OP_REQUIRES(ctx, axes_shape.num_elements() <= x_shape.dims(),
84+
errors::InvalidArgument("axes ", axes_shape.DebugString(),
85+
" can not have more elements"
86+
" than input tensor has dimensions ",
87+
x_shape.DebugString(), "."));
88+
// Reverse is a no-op if axes argument is empty.
89+
if (axes_shape.num_elements() == 0) {
90+
ctx->SetOutput(0, ctx->Input(0));
91+
return;
92+
}
93+
// ComputationBuilder::Rev() requires concrete values for dimensions arg.
94+
std::vector<int64> axes;
95+
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &axes));
96+
97+
for (int d = 0; d < axes.size(); ++d) {
98+
OP_REQUIRES(ctx, (0 <= axes[d]) && (axes[d] < x_shape.dims()),
99+
errors::InvalidArgument(axes[d], " is out of range [0, ",
100+
x_shape.dims(), ")."));
101+
}
102+
103+
ctx->SetOutput(0, ctx->builder()->Rev(ctx->Input(0), axes));
104+
}
105+
};
106+
107+
REGISTER_XLA_OP("ReverseV2", ReverseV2Op);
108+
109+
} // namespace
110+
} // namespace tensorflow

tensorflow/compiler/tf2xla/op_registrations.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,10 @@ REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
218218
Name("Reshape").TypeConstraint("T", kCpuAllTypes));
219219
REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("ResourceApplyGradientDescent")
220220
.TypeConstraint("T", kCpuAllTypes));
221+
REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
222+
Name("Reverse").TypeConstraint("T", kCpuAllTypes));
223+
REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
224+
Name("ReverseV2").TypeConstraint("T", kCpuAllTypes));
221225
REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
222226
Name("Rsqrt").TypeConstraint("T", kCpuFloatTypes));
223227
REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
@@ -493,6 +497,10 @@ REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
493497
Name("Reshape").TypeConstraint("T", kGpuAllTypes));
494498
REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("ResourceApplyGradientDescent")
495499
.TypeConstraint("T", kGpuAllTypes));
500+
REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
501+
Name("Reverse").TypeConstraint("T", kGpuAllTypes))
502+
REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
503+
Name("ReverseV2").TypeConstraint("T", kGpuAllTypes))
496504
REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
497505
Name("Rsqrt").TypeConstraint("T", kGpuFloatTypes));
498506
REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,

0 commit comments

Comments
 (0)