Skip to content

Commit d64754c

Browse files
authored
Merge pull request tensorflow#20601 from yifeif/branch_203518000
Branch 203518000
2 parents b2fe2a8 + 50e72c4 commit d64754c

File tree

319 files changed

+20110
-6498
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

319 files changed

+20110
-6498
lines changed

tensorflow/compiler/tests/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -924,7 +924,7 @@ tf_xla_py_test(
924924

925925
tf_xla_py_test(
926926
name = "sort_ops_test",
927-
size = "small",
927+
size = "medium",
928928
srcs = ["sort_ops_test.py"],
929929
# Times out in fastbuild mode.
930930
tags = ["optonly"],

tensorflow/compiler/tests/random_ops_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,10 +140,10 @@ def probit(x, sess=sess):
140140
def testShuffle1d(self):
141141
with self.test_session() as sess:
142142
with self.test_scope():
143-
x = math_ops.range(20)
143+
x = math_ops.range(1 << 16)
144144
shuffle = random_ops.random_shuffle(x)
145145
result = sess.run(shuffle)
146-
expected = range(20)
146+
expected = range(1 << 16)
147147
# Compare sets to avoid randomness behavior changes but make sure still
148148
# have all the values.
149149
self.assertAllEqual(set(result), set(expected))

tensorflow/compiler/tf2xla/BUILD

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ cc_library(
162162
":sharding_util",
163163
":tf2xla_util",
164164
"//tensorflow/compiler/tf2xla/lib:util",
165-
"//tensorflow/compiler/xla:literal_util",
165+
"//tensorflow/compiler/xla:literal",
166166
"//tensorflow/compiler/xla:shape_util",
167167
"//tensorflow/compiler/xla:status_macros",
168168
"//tensorflow/compiler/xla:statusor",
@@ -202,7 +202,7 @@ cc_library(
202202
],
203203
visibility = [":friends"],
204204
deps = [
205-
"//tensorflow/compiler/xla:literal_util",
205+
"//tensorflow/compiler/xla:literal",
206206
"//tensorflow/compiler/xla:shape_util",
207207
"//tensorflow/compiler/xla:xla_data_proto",
208208
"//tensorflow/core:core_cpu_internal",
@@ -285,6 +285,7 @@ tf_cc_test(
285285
deps = [
286286
":tf2xla",
287287
":tf2xla_proto",
288+
"//tensorflow/compiler/xla:literal",
288289
"//tensorflow/compiler/xla:literal_util",
289290
"//tensorflow/compiler/xla:statusor",
290291
"//tensorflow/compiler/xla/client:client_library",
@@ -327,7 +328,7 @@ tf_cc_test(
327328
"//tensorflow/cc:ops",
328329
"//tensorflow/cc:resource_variable_ops",
329330
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
330-
"//tensorflow/compiler/xla:literal_util",
331+
"//tensorflow/compiler/xla:literal",
331332
"//tensorflow/compiler/xla:shape_util",
332333
"//tensorflow/compiler/xla:status_macros",
333334
"//tensorflow/compiler/xla/client:client_library",
@@ -364,6 +365,7 @@ tf_cc_test(
364365
],
365366
deps = [
366367
":common",
368+
"//tensorflow/compiler/xla:literal",
367369
"//tensorflow/compiler/xla:literal_util",
368370
"//tensorflow/core:framework",
369371
"//tensorflow/core:test",

tensorflow/compiler/tf2xla/kernels/BUILD

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ tf_kernel_library(
114114
"//tensorflow/compiler/tf2xla/lib:while_loop",
115115
"//tensorflow/compiler/tf2xla/ops:xla_ops",
116116
"//tensorflow/compiler/xla:array4d",
117+
"//tensorflow/compiler/xla:literal",
117118
"//tensorflow/compiler/xla:literal_util",
118119
"//tensorflow/compiler/xla:shape_util",
119120
"//tensorflow/compiler/xla:status_macros",
@@ -159,7 +160,7 @@ tf_kernel_library(
159160
"//tensorflow/compiler/tf2xla:common",
160161
"//tensorflow/compiler/tf2xla:xla_compiler",
161162
"//tensorflow/compiler/tf2xla/ops:xla_ops",
162-
"//tensorflow/compiler/xla:literal_util",
163+
"//tensorflow/compiler/xla:literal",
163164
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
164165
"//tensorflow/core:framework",
165166
"//tensorflow/core:lib",
@@ -175,7 +176,7 @@ tf_kernel_library(
175176
"//tensorflow/compiler/tf2xla:common",
176177
"//tensorflow/compiler/tf2xla:xla_compiler",
177178
"//tensorflow/compiler/tf2xla/ops:xla_ops",
178-
"//tensorflow/compiler/xla:literal_util",
179+
"//tensorflow/compiler/xla:literal",
179180
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
180181
"//tensorflow/core:framework",
181182
"//tensorflow/core:lib",
@@ -210,6 +211,7 @@ tf_kernel_library(
210211
":index_ops_kernel_argmax_float_2d",
211212
"//tensorflow/compiler/tf2xla:common",
212213
"//tensorflow/compiler/tf2xla:xla_compiler",
214+
"//tensorflow/compiler/xla:literal",
213215
"//tensorflow/compiler/xla:literal_util",
214216
"//tensorflow/compiler/xla/client:client_library",
215217
"//tensorflow/compiler/xla/client/lib:arithmetic",

tensorflow/compiler/tf2xla/kernels/bcast_ops.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ limitations under the License.
1919
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
2020
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
2121
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
22-
#include "tensorflow/compiler/xla/literal_util.h"
22+
#include "tensorflow/compiler/xla/literal.h"
2323
#include "tensorflow/core/platform/macros.h"
2424
#include "tensorflow/core/platform/types.h"
2525
#include "tensorflow/core/util/bcast.h"

tensorflow/compiler/tf2xla/kernels/elu_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ limitations under the License.
1919
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
2020
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
2121
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
22-
#include "tensorflow/compiler/xla/literal_util.h"
22+
#include "tensorflow/compiler/xla/literal.h"
2323
#include "tensorflow/core/framework/kernel_def_builder.h"
2424
#include "tensorflow/core/framework/types.h"
2525
#include "tensorflow/core/kernels/no_op.h"

tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ limitations under the License.
1919
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
2020
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
2121
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
22-
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
22+
#include "tensorflow/compiler/xla/literal_util.h"
2323
#include "tensorflow/core/framework/kernel_def_builder.h"
2424
#include "tensorflow/core/framework/op_kernel.h"
2525
#include "tensorflow/core/framework/register_types.h"
@@ -78,14 +78,14 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
7878
std::vector<xla::XlaOp> args;
7979
args.push_back(ctx->Input(0));
8080
args.push_back(xla::ConstantLiteral(
81-
&b, *xla::Literal::CreateR1<int64>(input_shape.dim_sizes())));
81+
&b, *xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
8282
if (input_shape.dims() > 1) {
8383
// Don't bother passing the output shape and dim for the 1d case, since
8484
// the shape is always a scalar and the dim is always 0.
8585
args.push_back(xla::ConstantLiteral(
86-
&b, *xla::Literal::CreateR1<int64>(output_shape.dim_sizes())));
86+
&b, *xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
8787
args.push_back(
88-
xla::ConstantLiteral(&b, *xla::Literal::CreateR0<int32>(dim)));
88+
xla::ConstantLiteral(&b, *xla::LiteralUtil::CreateR0<int32>(dim)));
8989
}
9090

9191
xla::Shape xla_shape =

tensorflow/compiler/tf2xla/kernels/pooling_ops.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ limitations under the License.
2222
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
2323
#include "tensorflow/compiler/xla/client/lib/constants.h"
2424
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
25-
#include "tensorflow/compiler/xla/literal_util.h"
25+
#include "tensorflow/compiler/xla/literal.h"
2626
#include "tensorflow/compiler/xla/util.h"
2727
#include "tensorflow/core/framework/op_kernel.h"
2828
#include "tensorflow/core/framework/register_types.h"

tensorflow/compiler/tf2xla/kernels/random_ops.cc

Lines changed: 112 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -74,56 +74,121 @@ class RandomShuffleOp : public XlaOpKernel {
7474
for (tensorflow::TensorShapeDim dimension : input_shape) {
7575
num_elements *= dimension.size;
7676
}
77+
7778
if (num_elements <= 1 || n <= 1) {
7879
// No shuffling is required, so copy input directly to output
7980
ctx->SetOutput(0, input);
80-
} else {
81-
// Generate the random swaps for the indices.
82-
auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n});
83-
auto swaps =
84-
xla::RngUniform(xla::ConstantR0<int32>(builder, 0),
85-
xla::ConstantR0<int32>(builder, n), swaps_shape);
86-
87-
// Generate range(n) as the initial value for the indices to be swapped.
88-
xla::XlaOp indices = xla::Iota(builder, xla::S32, n);
89-
90-
// Swap the indices at i and swaps[i].
91-
auto swap_body_fn = [&](xla::XlaOp i,
92-
gtl::ArraySlice<xla::XlaOp> loop_vars,
93-
xla::XlaBuilder* builder)
94-
-> xla::StatusOr<std::vector<xla::XlaOp>> {
95-
auto swaps = loop_vars[0];
96-
auto indices = loop_vars[1];
97-
i = xla::Reshape(i, {1});
98-
// temp = indices[i]
99-
auto temp = xla::DynamicSlice(indices, i, {1});
100-
// swap_index = swaps[i]
101-
auto swap_index = xla::DynamicSlice(swaps, i, {1});
102-
// swap_value = indices[swaps[i]]
103-
auto swap_value = xla::DynamicSlice(indices, swap_index, {1});
104-
// indices[i] = indices[swaps[i]]
105-
indices = xla::DynamicUpdateSlice(indices, swap_value, i);
106-
// indices[swaps[i]] = temp
107-
indices = xla::DynamicUpdateSlice(indices, temp, swap_index);
108-
return std::vector<xla::XlaOp>{swaps, indices};
109-
};
110-
// for i in range(n):
111-
auto swap_loop_result =
112-
XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices},
113-
"indices_swap_loop", builder)
114-
.ValueOrDie();
115-
auto swapped_indices = swap_loop_result[1];
116-
117-
// Gather the data using the swapped indices as the shuffled order.
118-
auto indices_tensor_shape = TensorShape({n});
119-
DataType type = ctx->expected_output_dtype(0);
120-
xla::XlaOp gather;
121-
OP_REQUIRES_OK(ctx, XlaGather(input, input_shape, swapped_indices,
122-
indices_tensor_shape,
123-
/*axis=*/0, /*indices_are_nd=*/false, type,
124-
DT_INT32, builder, &gather));
125-
ctx->SetOutput(0, gather);
81+
return;
82+
}
83+
84+
if (input_shape.dims() == 1) {
85+
// For R1s, shuffle values by sorting instead of the obvious Fisher-Yates
86+
// algorithm. Fisher-Yates is simple to implement and correct, but not
87+
// easily parallelizable. For a sufficiently parallel architecture, it is
88+
// faster to sort many times, than Fisher-Yates shuffle once.
89+
90+
// Shuffle values by assigning each value a random key and sorting the
91+
// keys. Keys can collide causing detectable patterns in the shuffled
92+
// output. Collisions translates into more ascending sub-sequences in the
93+
// shuffled output than would be expected by chance. To avoid collisions,
94+
// the number of possible key values must be sufficiently large.
95+
96+
// How are more than 2^32 keys created? In each loop iteration, the
97+
// algorithm sorts by random keys. Conceptually, the earlier iterations
98+
// are sorting on the lower-order bits of larger keys that are never
99+
// actually assembled.
100+
101+
// The expected number of collisions is n - d + d(1 - 1/d)^n, where d is
102+
// the number of possible keys and n is the number of values. If d = n^2,
103+
// then the limit as n goes to infinity is 1/2. If d = n^3, then the limit
104+
// as n goes to infinity is zero.
105+
106+
// This implementation ensures that the key-space is greater than or equal
107+
// to the cube of the number of values. The risk of collisions can be
108+
// further reduced by increasing Exponent at the expense of
109+
// performance.
110+
111+
// For Exponent = 2, the expected number of collisions per shuffle is
112+
// maximized at n = floor((2^32-1)^(1/2)) = 65535 where the expectation is
113+
// about 1/2.
114+
115+
// For Exponent = 3, the expected number of collisions per shuffle is
116+
// maximized at n = floor((2^32-1)^(1/3)) = 1625 where the expectation is
117+
// about 1/3255.
118+
119+
// For Exponent = 4, the expected number of collisions per shuffle is
120+
// maximized at n = floor((2^32-1)^(1/4)) = 255 where the expectation is
121+
// about 1/132622.
122+
constexpr int Exponent = 3;
123+
const int rounds = static_cast<int>(
124+
std::ceil(Exponent * std::log(num_elements) / std::log(kuint32max)));
125+
126+
const xla::Shape key_shape =
127+
xla::ShapeUtil::MakeShape(xla::U32, {num_elements});
128+
xla::XlaOp zero = xla::ConstantR0(builder, 0U);
129+
130+
// Unfortunately, xla::RngUniform gives values in the half open interval
131+
// rather than the closed interval, so instead of 2^32 possible keys there
132+
// are only 2^32 - 1 (kuint32max).
133+
xla::XlaOp max_value = xla::ConstantR0(builder, kuint32max);
134+
135+
xla::XlaOp curr = input;
136+
for (int i = 0; i < rounds; ++i) {
137+
xla::XlaOp keys = xla::RngUniform(zero, max_value, key_shape);
138+
xla::XlaOp sorted = xla::Sort(keys, curr);
139+
curr = xla::GetTupleElement(sorted, 1);
140+
}
141+
142+
ctx->SetOutput(0, curr);
143+
return;
126144
}
145+
146+
// The Fisher-Yates algorithm.
147+
148+
// Generate the random swaps for the indices.
149+
auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n});
150+
auto swaps =
151+
xla::RngUniform(xla::ConstantR0<int32>(builder, 0),
152+
xla::ConstantR0<int32>(builder, n), swaps_shape);
153+
154+
// Generate range(n) as the initial value for the indices to be swapped.
155+
xla::XlaOp indices = xla::Iota(builder, xla::S32, n);
156+
157+
// Swap the indices at i and swaps[i].
158+
auto swap_body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars,
159+
xla::XlaBuilder* builder)
160+
-> xla::StatusOr<std::vector<xla::XlaOp>> {
161+
auto swaps = loop_vars[0];
162+
auto indices = loop_vars[1];
163+
i = xla::Reshape(i, {1});
164+
// temp = indices[i]
165+
auto temp = xla::DynamicSlice(indices, i, {1});
166+
// swap_index = swaps[i]
167+
auto swap_index = xla::DynamicSlice(swaps, i, {1});
168+
// swap_value = indices[swaps[i]]
169+
auto swap_value = xla::DynamicSlice(indices, swap_index, {1});
170+
// indices[i] = indices[swaps[i]]
171+
indices = xla::DynamicUpdateSlice(indices, swap_value, i);
172+
// indices[swaps[i]] = temp
173+
indices = xla::DynamicUpdateSlice(indices, temp, swap_index);
174+
return std::vector<xla::XlaOp>{swaps, indices};
175+
};
176+
// for i in range(n):
177+
auto swap_loop_result =
178+
XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices},
179+
"indices_swap_loop", builder)
180+
.ValueOrDie();
181+
auto swapped_indices = swap_loop_result[1];
182+
183+
// Gather the data using the swapped indices as the shuffled order.
184+
auto indices_tensor_shape = TensorShape({n});
185+
DataType type = ctx->expected_output_dtype(0);
186+
xla::XlaOp gather;
187+
OP_REQUIRES_OK(ctx, XlaGather(input, input_shape, swapped_indices,
188+
indices_tensor_shape,
189+
/*axis=*/0, /*indices_are_nd=*/false, type,
190+
DT_INT32, builder, &gather));
191+
ctx->SetOutput(0, gather);
127192
}
128193

129194
private:
@@ -220,5 +285,5 @@ REGISTER_XLA_OP(Name("TruncatedNormal")
220285
.TypeConstraint("dtype", DT_FLOAT),
221286
TruncatedNormalOp);
222287

223-
} // anonymous namespace
288+
} // namespace
224289
} // namespace tensorflow

tensorflow/compiler/tf2xla/kernels/reduction_ops.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ limitations under the License.
2121
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
2222
#include "tensorflow/compiler/xla/client/lib/constants.h"
2323
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
24-
#include "tensorflow/compiler/xla/literal_util.h"
24+
#include "tensorflow/compiler/xla/literal.h"
2525
#include "tensorflow/core/framework/kernel_def_builder.h"
2626

2727
namespace tensorflow {

0 commit comments

Comments
 (0)