@@ -29,17 +29,19 @@ using ::testing::ElementsAreArray;
29
29
class FloorOpModel : public SingleOpModel {
30
30
public:
31
31
FloorOpModel (std::initializer_list<int > input_shape, TensorType input_type) {
32
- input_ = AddInput (TensorType_FLOAT32 );
33
- output_ = AddOutput (TensorType_FLOAT32 );
32
+ input_ = AddInput (input_type );
33
+ output_ = AddOutput (input_type );
34
34
SetBuiltinOp (BuiltinOperator_FLOOR, BuiltinOptions_NONE, 0 );
35
35
BuildInterpreter ({
36
36
input_shape,
37
37
});
38
38
}
39
39
40
40
int input () { return input_; }
41
-
42
- std::vector<float > GetOutput () { return ExtractVector<float >(output_); }
41
+ template <typename T>
42
+ std::vector<T> GetOutput () {
43
+ return ExtractVector<T>(output_);
44
+ }
43
45
std::vector<int > GetOutputShape () { return GetTensorShape (output_); }
44
46
45
47
private:
@@ -51,7 +53,7 @@ TEST(FloorOpTest, SingleDim) {
51
53
FloorOpModel model ({2 }, TensorType_FLOAT32);
52
54
model.PopulateTensor <float >(model.input (), {8.5 , 0.0 });
53
55
ASSERT_EQ (model.Invoke (), kTfLiteOk );
54
- EXPECT_THAT (model.GetOutput (), ElementsAreArray ({8 , 0 }));
56
+ EXPECT_THAT (model.GetOutput < float > (), ElementsAreArray ({8 , 0 }));
55
57
EXPECT_THAT (model.GetOutputShape (), ElementsAreArray ({2 }));
56
58
}
57
59
@@ -70,10 +72,67 @@ TEST(FloorOpTest, MultiDims) {
70
72
-0.5 ,
71
73
});
72
74
ASSERT_EQ (model.Invoke (), kTfLiteOk );
73
- EXPECT_THAT (model.GetOutput (),
75
+ EXPECT_THAT (model.GetOutput <float >(),
76
+ ElementsAreArray ({0 , 8 , 0 , 9 , 0 , -1 , -9 , -1 , -10 , -1 }));
77
+ EXPECT_THAT (model.GetOutputShape (), ElementsAreArray ({2 , 1 , 1 , 5 }));
78
+ }
79
+
80
+ TEST (FloorOpTest, SingleDimFloat16) {
81
+ FloorOpModel model ({2 }, TensorType_FLOAT16);
82
+ model.PopulateTensor <>(model.input (), {Eigen::half (8.5 ), Eigen::half (0.0 )});
83
+ ASSERT_EQ (model.Invoke (), kTfLiteOk );
84
+ EXPECT_THAT (model.GetOutput <Eigen::half>(), ElementsAreArray ({8 , 0 }));
85
+ EXPECT_THAT (model.GetOutputShape (), ElementsAreArray ({2 }));
86
+ }
87
+
88
+ TEST (FloorOpTest, MultiDimsFloat16) {
89
+ FloorOpModel model ({2 , 1 , 1 , 5 }, TensorType_FLOAT16);
90
+ model.PopulateTensor <Eigen::half>(model.input (), {
91
+ Eigen::half (0.75 ),
92
+ Eigen::half (8.25 ),
93
+ Eigen::half (0.49 ),
94
+ Eigen::half (9.99 ),
95
+ Eigen::half (0.5 ),
96
+ Eigen::half (-0.25 ),
97
+ Eigen::half (-8.75 ),
98
+ Eigen::half (-0.99 ),
99
+ Eigen::half (-9.49 ),
100
+ Eigen::half (-0.5 ),
101
+ });
102
+ ASSERT_EQ (model.Invoke (), kTfLiteOk );
103
+ EXPECT_THAT (model.GetOutput <Eigen::half>(),
74
104
ElementsAreArray ({0 , 8 , 0 , 9 , 0 , -1 , -9 , -1 , -10 , -1 }));
75
105
EXPECT_THAT (model.GetOutputShape (), ElementsAreArray ({2 , 1 , 1 , 5 }));
76
106
}
77
107
108
+
109
+ TEST (FloorOpTest, SingleDimBFloat16) {
110
+ FloorOpModel model ({2 }, TensorType_BFLOAT16);
111
+ model.PopulateTensor <>(model.input (), {Eigen::bfloat16 (8.5 ),Eigen::bfloat16 (0.0 )});
112
+ ASSERT_EQ (model.Invoke (), kTfLiteOk );
113
+ EXPECT_THAT (model.GetOutput <Eigen::bfloat16>(), ElementsAreArray ({8 , 0 }));
114
+ EXPECT_THAT (model.GetOutputShape (), ElementsAreArray ({2 }));
115
+ }
116
+
117
+ TEST (FloorOpTest, MultiDimsBFloat16) {
118
+ FloorOpModel model ({2 , 1 , 1 , 5 }, TensorType_BFLOAT16);
119
+ model.PopulateTensor <Eigen::bfloat16>(model.input (), {
120
+ Eigen::bfloat16 (1.75 ),
121
+ Eigen::bfloat16 (8.5 ),
122
+ Eigen::bfloat16 (1.49 ),
123
+ Eigen::bfloat16 (9.01 ),
124
+ Eigen::bfloat16 (1.5 ),
125
+ Eigen::bfloat16 (-1.25 ),
126
+ Eigen::bfloat16 (-8.99 ),
127
+ Eigen::bfloat16 (-1.99 ),
128
+ Eigen::bfloat16 (-9.5 ),
129
+ Eigen::bfloat16 (-1.5 ),
130
+ });
131
+ ASSERT_EQ (model.Invoke (), kTfLiteOk );
132
+ EXPECT_THAT (model.GetOutput <Eigen::bfloat16>(),
133
+ ElementsAreArray ({1 , 8 , 1 , 9 , 1 , -2 , -9 , -2 , -10 , -2 }));
134
+ EXPECT_THAT (model.GetOutputShape (), ElementsAreArray ({2 , 1 , 1 , 5 }));
135
+ }
136
+
78
137
} // namespace
79
138
} // namespace tflite
0 commit comments