Skip to content

Commit 18581a6

Browse files
Tflite floor adds missing datatype support (#75)
* Tflite floor missing datatype support -Adds f16,bf16 for floor -Adds f16,bf16 for floor unit test Co-authored-by: nitheshsrikanth-mcw <nithesh.srikanth@multicorewareinc.com>
1 parent e108d46 commit 18581a6

File tree

4 files changed

+107
-19
lines changed

4 files changed

+107
-19
lines changed

tensorflow/lite/kernels/floor.cc

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
4242
GetOutputSafe(context, node, kOutputTensor, &output));
4343
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
4444
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
45-
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
45+
TF_LITE_ENSURE(context, input->type == kTfLiteFloat32 ||
46+
input->type == kTfLiteFloat16 ||
47+
input->type == kTfLiteBFloat16);
4648
output->type = input->type;
4749
TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims);
4850
return context->ResizeTensor(context, output, output_size);
@@ -55,13 +57,38 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
5557
TfLiteTensor* output;
5658
TF_LITE_ENSURE_OK(context,
5759
GetOutputSafe(context, node, kOutputTensor, &output));
58-
59-
if (type == kGenericOptimized) {
60-
optimized_ops::Floor(GetTensorShape(input), GetTensorData<float>(input),
61-
GetTensorShape(output), GetTensorData<float>(output));
62-
} else {
63-
reference_ops::Floor(GetTensorShape(input), GetTensorData<float>(input),
64-
GetTensorShape(output), GetTensorData<float>(output));
60+
if (input->type == kTfLiteFloat32) {
61+
if (type == kGenericOptimized) {
62+
optimized_ops::Floor(GetTensorShape(input), GetTensorData<float>(input),
63+
GetTensorShape(output),
64+
GetTensorData<float>(output));
65+
} else {
66+
reference_ops::Floor(GetTensorShape(input), GetTensorData<float>(input),
67+
GetTensorShape(output),
68+
GetTensorData<float>(output));
69+
}
70+
}
71+
if (input->type == kTfLiteFloat16) {
72+
if (type == kGenericOptimized) {
73+
optimized_ops::Floor(
74+
GetTensorShape(input), GetTensorData<Eigen::half>(input),
75+
GetTensorShape(output), GetTensorData<Eigen::half>(output));
76+
} else {
77+
reference_ops::Floor(
78+
GetTensorShape(input), GetTensorData<Eigen::half>(input),
79+
GetTensorShape(output), GetTensorData<Eigen::half>(output));
80+
}
81+
}
82+
if (input->type == kTfLiteBFloat16) {
83+
if (type == kGenericOptimized) {
84+
optimized_ops::Floor(
85+
GetTensorShape(input), GetTensorData<Eigen::bfloat16>(input),
86+
GetTensorShape(output), GetTensorData<Eigen::bfloat16>(output));
87+
} else {
88+
reference_ops::Floor(
89+
GetTensorShape(input), GetTensorData<Eigen::bfloat16>(input),
90+
GetTensorShape(output), GetTensorData<Eigen::bfloat16>(output));
91+
}
6592
}
6693

6794
return kTfLiteOk;

tensorflow/lite/kernels/floor_test.cc

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,19 @@ using ::testing::ElementsAreArray;
2929
class FloorOpModel : public SingleOpModel {
3030
public:
3131
FloorOpModel(std::initializer_list<int> input_shape, TensorType input_type) {
32-
input_ = AddInput(TensorType_FLOAT32);
33-
output_ = AddOutput(TensorType_FLOAT32);
32+
input_ = AddInput(input_type);
33+
output_ = AddOutput(input_type);
3434
SetBuiltinOp(BuiltinOperator_FLOOR, BuiltinOptions_NONE, 0);
3535
BuildInterpreter({
3636
input_shape,
3737
});
3838
}
3939

4040
int input() { return input_; }
41-
42-
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
41+
template <typename T>
42+
std::vector<T> GetOutput() {
43+
return ExtractVector<T>(output_);
44+
}
4345
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
4446

4547
private:
@@ -51,7 +53,7 @@ TEST(FloorOpTest, SingleDim) {
5153
FloorOpModel model({2}, TensorType_FLOAT32);
5254
model.PopulateTensor<float>(model.input(), {8.5, 0.0});
5355
ASSERT_EQ(model.Invoke(), kTfLiteOk);
54-
EXPECT_THAT(model.GetOutput(), ElementsAreArray({8, 0}));
56+
EXPECT_THAT(model.GetOutput<float>(), ElementsAreArray({8, 0}));
5557
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
5658
}
5759

@@ -70,10 +72,67 @@ TEST(FloorOpTest, MultiDims) {
7072
-0.5,
7173
});
7274
ASSERT_EQ(model.Invoke(), kTfLiteOk);
73-
EXPECT_THAT(model.GetOutput(),
75+
EXPECT_THAT(model.GetOutput<float>(),
76+
ElementsAreArray({0, 8, 0, 9, 0, -1, -9, -1, -10, -1}));
77+
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 1, 5}));
78+
}
79+
80+
TEST(FloorOpTest, SingleDimFloat16) {
81+
FloorOpModel model({2}, TensorType_FLOAT16);
82+
model.PopulateTensor<>(model.input(), {Eigen::half(8.5), Eigen::half(0.0)});
83+
ASSERT_EQ(model.Invoke(), kTfLiteOk);
84+
EXPECT_THAT(model.GetOutput<Eigen::half>(), ElementsAreArray({8, 0}));
85+
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
86+
}
87+
88+
TEST(FloorOpTest, MultiDimsFloat16) {
89+
FloorOpModel model({2, 1, 1, 5}, TensorType_FLOAT16);
90+
model.PopulateTensor<Eigen::half>(model.input(), {
91+
Eigen::half(0.75),
92+
Eigen::half(8.25),
93+
Eigen::half(0.49),
94+
Eigen::half(9.99),
95+
Eigen::half(0.5),
96+
Eigen::half(-0.25),
97+
Eigen::half(-8.75),
98+
Eigen::half(-0.99),
99+
Eigen::half(-9.49),
100+
Eigen::half(-0.5),
101+
});
102+
ASSERT_EQ(model.Invoke(), kTfLiteOk);
103+
EXPECT_THAT(model.GetOutput<Eigen::half>(),
74104
ElementsAreArray({0, 8, 0, 9, 0, -1, -9, -1, -10, -1}));
75105
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 1, 5}));
76106
}
77107

108+
109+
TEST(FloorOpTest, SingleDimBFloat16) {
110+
FloorOpModel model({2}, TensorType_BFLOAT16);
111+
model.PopulateTensor<>(model.input(), {Eigen::bfloat16(8.5),Eigen::bfloat16(0.0)});
112+
ASSERT_EQ(model.Invoke(), kTfLiteOk);
113+
EXPECT_THAT(model.GetOutput<Eigen::bfloat16>(), ElementsAreArray({8, 0}));
114+
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
115+
}
116+
117+
TEST(FloorOpTest, MultiDimsBFloat16) {
118+
FloorOpModel model({2, 1, 1, 5}, TensorType_BFLOAT16);
119+
model.PopulateTensor<Eigen::bfloat16>(model.input(), {
120+
Eigen::bfloat16(1.75),
121+
Eigen::bfloat16(8.5),
122+
Eigen::bfloat16(1.49),
123+
Eigen::bfloat16(9.01),
124+
Eigen::bfloat16(1.5),
125+
Eigen::bfloat16(-1.25),
126+
Eigen::bfloat16(-8.99),
127+
Eigen::bfloat16(-1.99),
128+
Eigen::bfloat16(-9.5),
129+
Eigen::bfloat16(-1.5),
130+
});
131+
ASSERT_EQ(model.Invoke(), kTfLiteOk);
132+
EXPECT_THAT(model.GetOutput<Eigen::bfloat16>(),
133+
ElementsAreArray({1, 8, 1, 9, 1, -2, -9, -2, -10, -2}));
134+
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 1, 5}));
135+
}
136+
78137
} // namespace
79138
} // namespace tflite

tensorflow/lite/kernels/internal/optimized/optimized_ops.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4224,8 +4224,9 @@ inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
42244224
output_map.array() = input_map.array().template cast<DstT>();
42254225
}
42264226

4227-
inline void Floor(const RuntimeShape& input_shape, const float* input_data,
4228-
const RuntimeShape& output_shape, float* output_data) {
4227+
template <typename T>
4228+
inline void Floor(const RuntimeShape& input_shape, const T* input_data,
4229+
const RuntimeShape& output_shape, T* output_data) {
42294230
ruy::profiler::ScopeLabel label("Floor");
42304231
auto input_map = MapAsVector(input_data, input_shape);
42314232
auto output_map = MapAsVector(output_data, output_shape);

tensorflow/lite/kernels/internal/reference/floor.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,14 @@ namespace tflite {
2323

2424
namespace reference_ops {
2525

26-
inline void Floor(const RuntimeShape& input_shape, const float* input_data,
27-
const RuntimeShape& output_shape, float* output_data) {
26+
template <typename T>
27+
inline void Floor(const RuntimeShape& input_shape, const T* input_data,
28+
const RuntimeShape& output_shape, T* output_data) {
2829
const int flat_size = MatchingFlatSize(input_shape, output_shape);
2930

3031
for (int i = 0; i < flat_size; i++) {
3132
int offset = i;
32-
output_data[offset] = std::floor(input_data[offset]);
33+
output_data[offset] = static_cast<T>(std::floor(static_cast<float>(input_data[offset])));
3334
}
3435
}
3536

0 commit comments

Comments
 (0)