Skip to content

Commit b044b59

Browse files
akuegeltensorflower-gardener
authored andcommitted
Support S16 and U16 in the HLO evaluator
This allows to enable ComparatorsTest.CompareLtF16 and ComparatorsTest.CompareGtF16 tests. While at it, also enable other evaluator tests which are already passing. PiperOrigin-RevId: 228869860
1 parent b86e0ff commit b044b59

File tree

9 files changed

+89
-23
lines changed

9 files changed

+89
-23
lines changed

tensorflow/compiler/xla/client/lib/comparators_test.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,15 +103,14 @@ XLA_TEST_F(ComparatorsTest, CompareGtBF16) {
103103
ComputeAndCompareR1<bool>(builder(), expected, {});
104104
}
105105

106-
// The interpreter doesn't support S16, and F16 results in comparisons with S16.
107-
XLA_TEST_F(ComparatorsTest, DISABLED_ON_INTERPRETER(CompareLtF16)) {
106+
XLA_TEST_F(ComparatorsTest, CompareLtF16) {
108107
absl::InlinedVector<bool, 10> expected;
109108
BuildComparatorAndComparisons<F16>(this, /*compare_less_than=*/true,
110109
&expected);
111110
ComputeAndCompareR1<bool>(builder(), expected, {});
112111
}
113112

114-
XLA_TEST_F(ComparatorsTest, DISABLED_ON_INTERPRETER(CompareGtF16)) {
113+
XLA_TEST_F(ComparatorsTest, CompareGtF16) {
115114
absl::InlinedVector<bool, 10> expected;
116115
BuildComparatorAndComparisons<F16>(this, /*compare_less_than=*/false,
117116
&expected);

tensorflow/compiler/xla/literal.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1319,9 +1319,11 @@ StatusOr<Literal> ConvertIfDestTypeMatches(const LiteralBase& src_literal,
13191319
bitcast);
13201320
CONVERT_IF_TYPES_MATCH(PRED)
13211321
CONVERT_IF_TYPES_MATCH(S8)
1322+
CONVERT_IF_TYPES_MATCH(S16)
13221323
CONVERT_IF_TYPES_MATCH(S32)
13231324
CONVERT_IF_TYPES_MATCH(S64)
13241325
CONVERT_IF_TYPES_MATCH(U8)
1326+
CONVERT_IF_TYPES_MATCH(U16)
13251327
CONVERT_IF_TYPES_MATCH(U32)
13261328
CONVERT_IF_TYPES_MATCH(U64)
13271329
CONVERT_IF_TYPES_MATCH(F16)
@@ -1357,9 +1359,11 @@ StatusOr<Literal> ConvertSwitch(const LiteralBase& literal,
13571359
bitcast);
13581360
CONVERT_IF_DEST_TYPE_MATCHES(PRED)
13591361
CONVERT_IF_DEST_TYPE_MATCHES(S8)
1362+
CONVERT_IF_DEST_TYPE_MATCHES(S16)
13601363
CONVERT_IF_DEST_TYPE_MATCHES(S32)
13611364
CONVERT_IF_DEST_TYPE_MATCHES(S64)
13621365
CONVERT_IF_DEST_TYPE_MATCHES(U8)
1366+
CONVERT_IF_DEST_TYPE_MATCHES(U16)
13631367
CONVERT_IF_DEST_TYPE_MATCHES(U32)
13641368
CONVERT_IF_DEST_TYPE_MATCHES(U64)
13651369
CONVERT_IF_DEST_TYPE_MATCHES(F16)

tensorflow/compiler/xla/literal_test.cc

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1237,11 +1237,21 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
12371237
{{0, 19, 0, 21}, {22, 0, 24, 0}},
12381238
{{26, 0, 28, 0}, {0, 31, 0, 33}},
12391239
}}, layout_r4_dim0major_);
1240+
auto s16 = LiteralUtil::CreateR4WithLayout<int16>({{
1241+
{{10, 0, 12, 0}, {0, 15, 0, 17}},
1242+
{{0, 19, 0, 21}, {22, 0, 24, 0}},
1243+
{{26, 0, 28, 0}, {0, 31, 0, 33}},
1244+
}}, layout_r4_dim0major_);
12401245
auto s32 = LiteralUtil::CreateR4WithLayout<int32>({{
12411246
{{10, 0, 12, 0}, {0, 15, 0, 17}},
12421247
{{0, 19, 0, 21}, {22, 0, 24, 0}},
12431248
{{26, 0, 28, 0}, {0, 31, 0, 33}},
12441249
}}, layout_r4_dim0major_);
1250+
auto u16 = LiteralUtil::CreateR4WithLayout<uint16>({{
1251+
{{10, 0, 12, 0}, {0, 15, 0, 17}},
1252+
{{0, 19, 0, 21}, {22, 0, 24, 0}},
1253+
{{26, 0, 28, 0}, {0, 31, 0, 33}},
1254+
}}, layout_r4_dim0major_);
12451255
auto u32 = LiteralUtil::CreateR4WithLayout<uint32>({{
12461256
{{10, 0, 12, 0}, {0, 15, 0, 17}},
12471257
{{0, 19, 0, 21}, {22, 0, 24, 0}},
@@ -1301,6 +1311,12 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
13011311
// clang-format on
13021312
Literal conv;
13031313

1314+
conv = s8.Convert(U16).ConsumeValueOrDie();
1315+
EXPECT_EQ(conv, u16);
1316+
1317+
conv = s8.Convert(S16).ConsumeValueOrDie();
1318+
EXPECT_EQ(conv, s16);
1319+
13041320
conv = s8.Convert(U32).ConsumeValueOrDie();
13051321
EXPECT_EQ(conv, u32);
13061322

@@ -1352,10 +1368,14 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
13521368
conv = f16.Convert(C64).ConsumeValueOrDie();
13531369
EXPECT_EQ(conv, c64);
13541370

1371+
conv = s32.Convert(S16).ConsumeValueOrDie();
1372+
EXPECT_EQ(conv, s16);
1373+
1374+
conv = s32.Convert(U16).ConsumeValueOrDie();
1375+
EXPECT_EQ(conv, u16);
1376+
13551377
EXPECT_EQ(s32.Convert(TUPLE).status().code(),
13561378
tensorflow::error::UNIMPLEMENTED);
1357-
EXPECT_EQ(s32.Convert(S16).status().code(), tensorflow::error::UNIMPLEMENTED);
1358-
EXPECT_EQ(s32.Convert(U16).status().code(), tensorflow::error::UNIMPLEMENTED);
13591379
EXPECT_EQ(c64.Convert(F32).status().code(), tensorflow::error::UNIMPLEMENTED);
13601380
EXPECT_EQ(c64.Convert(S32).status().code(), tensorflow::error::UNIMPLEMENTED);
13611381
}

tensorflow/compiler/xla/service/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,9 +227,11 @@ cc_library(
227227
"hlo_evaluator_typed_visitor_double.cc",
228228
"hlo_evaluator_typed_visitor_float.cc",
229229
"hlo_evaluator_typed_visitor_half.cc",
230+
"hlo_evaluator_typed_visitor_int16.cc",
230231
"hlo_evaluator_typed_visitor_int32.cc",
231232
"hlo_evaluator_typed_visitor_int64.cc",
232233
"hlo_evaluator_typed_visitor_int8.cc",
234+
"hlo_evaluator_typed_visitor_uint16.cc",
233235
"hlo_evaluator_typed_visitor_uint32.cc",
234236
"hlo_evaluator_typed_visitor_uint64.cc",
235237
"hlo_evaluator_typed_visitor_uint8.cc",

tensorflow/compiler/xla/service/hlo_evaluator.cc

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -150,22 +150,14 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations)
150150
typed_visitors_[U8] =
151151
absl::make_unique<HloEvaluatorTypedVisitor<uint8>>(this);
152152
typed_visitors_[U16] =
153-
absl::make_unique<FunctionVisitor>([](HloInstruction*) {
154-
return Unimplemented(
155-
"HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: "
156-
"U16.");
157-
});
153+
absl::make_unique<HloEvaluatorTypedVisitor<uint16>>(this);
158154
typed_visitors_[U32] =
159155
absl::make_unique<HloEvaluatorTypedVisitor<uint32>>(this);
160156
typed_visitors_[U64] =
161157
absl::make_unique<HloEvaluatorTypedVisitor<uint64>>(this);
162158
typed_visitors_[S8] = absl::make_unique<HloEvaluatorTypedVisitor<int8>>(this);
163159
typed_visitors_[S16] =
164-
absl::make_unique<FunctionVisitor>([](HloInstruction*) {
165-
return Unimplemented(
166-
"HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: "
167-
"S16.");
168-
});
160+
absl::make_unique<HloEvaluatorTypedVisitor<int16>>(this);
169161
typed_visitors_[S32] =
170162
absl::make_unique<HloEvaluatorTypedVisitor<int32>>(this);
171163
typed_visitors_[S64] =
@@ -595,8 +587,11 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) {
595587
evaluated_[compare],
596588
Compare<uint8>(compare->shape(), opcode, lhs_literal, rhs_literal));
597589
} break;
598-
case U16:
599-
return Unimplemented("unhandled primitive type: U16.");
590+
case U16: {
591+
TF_ASSIGN_OR_RETURN(
592+
evaluated_[compare],
593+
Compare<uint16>(compare->shape(), opcode, lhs_literal, rhs_literal));
594+
} break;
600595
case U32: {
601596
TF_ASSIGN_OR_RETURN(
602597
evaluated_[compare],
@@ -612,8 +607,11 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) {
612607
evaluated_[compare],
613608
Compare<int8>(compare->shape(), opcode, lhs_literal, rhs_literal));
614609
} break;
615-
case S16:
616-
return Unimplemented("unhandled primitive type: S16.");
610+
case S16: {
611+
TF_ASSIGN_OR_RETURN(
612+
evaluated_[compare],
613+
Compare<int16>(compare->shape(), opcode, lhs_literal, rhs_literal));
614+
} break;
617615
case S32: {
618616
TF_ASSIGN_OR_RETURN(
619617
evaluated_[compare],
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
17+
18+
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
19+
20+
namespace xla {
21+
template class HloEvaluatorTypedVisitor<int16>;
22+
} // namespace xla
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
17+
18+
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
19+
20+
namespace xla {
21+
template class HloEvaluatorTypedVisitor<uint16>;
22+
} // namespace xla

tensorflow/compiler/xla/tests/tuple_test.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -513,8 +513,7 @@ XLA_TEST_F(TupleTest, ComplexTuples) {
513513

514514
class TupleHloTest : public HloTestBase {};
515515

516-
// Disabled on the interpreter because bitcast doesn't exist on the interpreter.
517-
XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) {
516+
XLA_TEST_F(TupleHloTest, BitcastAfterGTE) {
518517
const char* testcase = R"(
519518
HloModule m, is_scheduled=true
520519

tensorflow/compiler/xla/tests/while_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -901,7 +901,7 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) {
901901
// Per backend the values generated can be different as the different backends
902902
// use different random number generators.
903903
// TODO(b/32240857): Extend test to verify outputs.
904-
XLA_TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) {
904+
XLA_TEST_F(WhileTest, WhileWithPrngScalarResult) {
905905
auto v6s32 = ShapeUtil::MakeShape(S32, {6});
906906

907907
// Create a computation for the condition: repeat for count iterations.
@@ -1146,7 +1146,7 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) {
11461146
// while (f(result).get<0>()) {
11471147
// result = result + 1;
11481148
// }
1149-
XLA_TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) {
1149+
XLA_TEST_F(WhileTest, WhileWithCallInsideCondition) {
11501150
auto result_shape = ShapeUtil::MakeShape(S32, {});
11511151

11521152
// Create a computation for the condition: repeat for 5 iterations.

0 commit comments

Comments
 (0)