@@ -24,6 +24,11 @@ TEST(DataTypeTransform, CPUTransform) {
24
24
paddle::framework::DataLayout::kAnyLayout ,
25
25
paddle::framework::LibraryType::kPlain );
26
26
27
+ auto kernel_bf16 = paddle::framework::OpKernelType (
28
+ paddle::framework::proto::VarType::BF16, place,
29
+ paddle::framework::DataLayout::kAnyLayout ,
30
+ paddle::framework::LibraryType::kPlain );
31
+
27
32
auto kernel_fp32 = paddle::framework::OpKernelType (
28
33
paddle::framework::proto::VarType::FP32, place,
29
34
paddle::framework::DataLayout::kAnyLayout ,
@@ -189,4 +194,120 @@ TEST(DataTypeTransform, CPUTransform) {
189
194
static_cast <paddle::platform::float16>(in_data_bool[i]).x );
190
195
}
191
196
}
197
+
198
+ // data type transform from/to bfloat16
199
+ {
200
+ paddle::framework::Tensor in;
201
+ paddle::framework::Tensor out;
202
+
203
+ paddle::platform::bfloat16* ptr =
204
+ in.mutable_data <paddle::platform::bfloat16>(
205
+ paddle::framework::make_ddim ({2 , 3 }), place);
206
+ int data_number = 2 * 3 ;
207
+
208
+ for (int i = 0 ; i < data_number; ++i) {
209
+ ptr[i] = i;
210
+ }
211
+
212
+ // transform from bfloat16 to other data types
213
+ paddle::framework::TransDataType (kernel_bf16, kernel_fp32, in, &out);
214
+ float * out_data_float = out.data <float >();
215
+ for (int i = 0 ; i < data_number; ++i) {
216
+ EXPECT_EQ (out_data_float[i], static_cast <float >(ptr[i]));
217
+ }
218
+
219
+ paddle::framework::TransDataType (kernel_bf16, kernel_fp64, in, &out);
220
+ double * out_data_double = out.data <double >();
221
+ for (int i = 0 ; i < data_number; ++i) {
222
+ EXPECT_EQ (out_data_double[i], static_cast <double >(ptr[i]));
223
+ }
224
+
225
+ paddle::framework::TransDataType (kernel_bf16, kernel_int32, in, &out);
226
+ int * out_data_int = out.data <int >();
227
+ for (int i = 0 ; i < data_number; ++i) {
228
+ EXPECT_EQ (out_data_int[i], static_cast <int >(ptr[i]));
229
+ }
230
+
231
+ paddle::framework::TransDataType (kernel_bf16, kernel_int64, in, &out);
232
+ int64_t * out_data_int64 = out.data <int64_t >();
233
+ for (int i = 0 ; i < data_number; ++i) {
234
+ EXPECT_EQ (out_data_int64[i], static_cast <int64_t >(ptr[i]));
235
+ }
236
+
237
+ paddle::framework::TransDataType (kernel_bf16, kernel_bool, in, &out);
238
+ bool * out_data_bool = out.data <bool >();
239
+ for (int i = 0 ; i < data_number; ++i) {
240
+ EXPECT_EQ (out_data_bool[i], static_cast <bool >(ptr[i]));
241
+ }
242
+
243
+ // transform float to bfloat16
244
+ float * in_data_float =
245
+ in.mutable_data <float >(paddle::framework::make_ddim ({2 , 3 }), place);
246
+ for (int i = 0 ; i < data_number; ++i) {
247
+ in_data_float[i] = i;
248
+ }
249
+
250
+ paddle::framework::TransDataType (kernel_fp32, kernel_bf16, in, &out);
251
+ ptr = out.data <paddle::platform::bfloat16>();
252
+ for (int i = 0 ; i < data_number; ++i) {
253
+ EXPECT_EQ (ptr[i].x ,
254
+ static_cast <paddle::platform::bfloat16>(in_data_float[i]).x );
255
+ }
256
+
257
+ // transform double to bfloat16
258
+ double * in_data_double =
259
+ in.mutable_data <double >(paddle::framework::make_ddim ({2 , 3 }), place);
260
+ for (int i = 0 ; i < data_number; ++i) {
261
+ in_data_double[i] = i;
262
+ }
263
+
264
+ paddle::framework::TransDataType (kernel_fp64, kernel_bf16, in, &out);
265
+ ptr = out.data <paddle::platform::bfloat16>();
266
+ for (int i = 0 ; i < data_number; ++i) {
267
+ EXPECT_EQ (ptr[i].x ,
268
+ static_cast <paddle::platform::bfloat16>(in_data_double[i]).x );
269
+ }
270
+
271
+ // transform int to bfloat16
272
+ int * in_data_int =
273
+ in.mutable_data <int >(paddle::framework::make_ddim ({2 , 3 }), place);
274
+ for (int i = 0 ; i < data_number; ++i) {
275
+ in_data_int[i] = i;
276
+ }
277
+
278
+ paddle::framework::TransDataType (kernel_int32, kernel_bf16, in, &out);
279
+ ptr = out.data <paddle::platform::bfloat16>();
280
+ for (int i = 0 ; i < data_number; ++i) {
281
+ EXPECT_EQ (ptr[i].x ,
282
+ static_cast <paddle::platform::bfloat16>(in_data_int[i]).x );
283
+ }
284
+
285
+ // transform int64 to bfloat16
286
+ int64_t * in_data_int64 =
287
+ in.mutable_data <int64_t >(paddle::framework::make_ddim ({2 , 3 }), place);
288
+ for (int i = 0 ; i < data_number; ++i) {
289
+ in_data_int64[i] = i;
290
+ }
291
+
292
+ paddle::framework::TransDataType (kernel_int64, kernel_bf16, in, &out);
293
+ ptr = out.data <paddle::platform::bfloat16>();
294
+ for (int i = 0 ; i < data_number; ++i) {
295
+ EXPECT_EQ (ptr[i].x ,
296
+ static_cast <paddle::platform::bfloat16>(in_data_int64[i]).x );
297
+ }
298
+
299
+ // transform bool to bfloat16
300
+ bool * in_data_bool =
301
+ in.mutable_data <bool >(paddle::framework::make_ddim ({2 , 3 }), place);
302
+ for (int i = 0 ; i < data_number; ++i) {
303
+ in_data_bool[i] = i;
304
+ }
305
+
306
+ paddle::framework::TransDataType (kernel_bool, kernel_bf16, in, &out);
307
+ ptr = out.data <paddle::platform::bfloat16>();
308
+ for (int i = 0 ; i < data_number; ++i) {
309
+ EXPECT_EQ (ptr[i].x ,
310
+ static_cast <paddle::platform::bfloat16>(in_data_bool[i]).x );
311
+ }
312
+ }
192
313
}
0 commit comments