Commit 95e1434: Add bfloat16 data type (#25402)
Parent: 3ba7b9b

19 files changed: +832 -63 lines
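A quick orientation before the per-file diffs: bfloat16 keeps float32's sign bit and 8-bit exponent but only 7 mantissa bits, so conversion is essentially taking the upper half of the float32 bit pattern. Below is a minimal standalone sketch of the format, not Paddle's platform::bfloat16 (the real class lives in paddle/fluid/platform/bfloat16.h and may round rather than truncate); the raw-bits member is named x only to mirror the field the new tests compare.

#include <cstdint>
#include <cstring>
#include <iostream>

struct bf16_sketch {
  uint16_t x;  // raw bits: 1 sign, 8 exponent, 7 mantissa

  // Truncating float32 -> bfloat16: drop the low 16 mantissa bits.
  static bf16_sketch from_float(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return bf16_sketch{static_cast<uint16_t>(bits >> 16)};
  }

  // bfloat16 -> float32 is exact: widen and zero-fill the low bits.
  float to_float() const {
    uint32_t bits = static_cast<uint32_t>(x) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
  }
};

int main() {
  bf16_sketch v = bf16_sketch::from_float(3.14159f);
  std::cout << v.to_float() << "\n";  // ~3.14062: fp32 range, reduced precision
  return 0;
}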

paddle/fluid/framework/data_layout_transform.cc (+2)

@@ -116,6 +116,8 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
       return platform::to_void_cast(tensor.data<unsigned char>());
     case mkldnn::memory::data_type::s32:
       return platform::to_void_cast(tensor.data<int32_t>());
+    case mkldnn::memory::data_type::bf16:
+      return platform::to_void_cast(tensor.data<paddle::platform::bfloat16>());
     default:
       PADDLE_THROW(
           platform::errors::InvalidArgument("Wrong mkldnn type provided."));

paddle/fluid/framework/data_layout_transform.h (+5 -1)

@@ -61,7 +61,8 @@ inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
       {DataTypeTrait<float>::DataType(), MKLDNNDataType::f32},
       {DataTypeTrait<int8_t>::DataType(), MKLDNNDataType::s8},
       {DataTypeTrait<uint8_t>::DataType(), MKLDNNDataType::u8},
-      {DataTypeTrait<int32_t>::DataType(), MKLDNNDataType::s32}};
+      {DataTypeTrait<int32_t>::DataType(), MKLDNNDataType::s32},
+      {DataTypeTrait<platform::bfloat16>::DataType(), MKLDNNDataType::bf16}};
   auto iter = dict.find(static_cast<int>(type));
   if (iter != dict.end()) return iter->second;
   return MKLDNNDataType::undef;

@@ -74,6 +75,9 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
 void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
                                const OpKernelType& expected_kernel_type,
                                const Tensor& in, Tensor* out);
+
+void* GetDataFromTensor(const Tensor& tensor, MKLDNNDataType type);
+
 #endif
 
 std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to);
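The f32/s8/u8/s32 dictionary above simply gains a bf16 entry. As a standalone illustration of the same lookup-with-undef-fallback pattern (hypothetical enum and function names, not Paddle's):

#include <cstdint>
#include <typeindex>
#include <unordered_map>

enum class BackendDataType { undef, f32, s8, u8, s32, bf16 };

template <typename T>
BackendDataType ToBackendDataType() {
  // Static map from C++ type to backend enum; unknown types map to undef
  // rather than throwing, mirroring ToMKLDNNDataType's fallback.
  static const std::unordered_map<std::type_index, BackendDataType> dict{
      {typeid(float), BackendDataType::f32},
      {typeid(int8_t), BackendDataType::s8},
      {typeid(uint8_t), BackendDataType::u8},
      {typeid(int32_t), BackendDataType::s32}};
  auto iter = dict.find(typeid(T));
  return iter != dict.end() ? iter->second : BackendDataType::undef;
}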

paddle/fluid/framework/data_layout_transform_test.cc (+14)

@@ -43,3 +43,17 @@ TEST(DataTransform, DataLayoutFunction) {
   EXPECT_TRUE(in.layout() == paddle::framework::DataLayout::kNHWC);
   EXPECT_TRUE(in.dims() == paddle::framework::make_ddim({2, 3, 1, 2}));
 }
+
+#ifdef PADDLE_WITH_MKLDNN
+TEST(DataTransform, GetDataFromTensorDNNL) {
+  auto place = paddle::platform::CPUPlace();
+  paddle::framework::Tensor in = paddle::framework::Tensor();
+  in.mutable_data<paddle::platform::bfloat16>(
+      paddle::framework::make_ddim({2, 3, 1, 2}), place);
+
+  void* in_data =
+      paddle::framework::GetDataFromTensor(in, dnnl::memory::data_type::bf16);
+  EXPECT_EQ(in_data, paddle::platform::to_void_cast(
+                         in.data<paddle::platform::bfloat16>()));
+}
+#endif

paddle/fluid/framework/data_type.cc (+1)

@@ -18,6 +18,7 @@
 #include <unordered_map>
 
 using float16 = paddle::platform::float16;
+using bfloat16 = paddle::platform::bfloat16;
 
 namespace paddle {
 namespace framework {

paddle/fluid/framework/data_type.h (+12 -9)

@@ -17,6 +17,8 @@ limitations under the License. */
 #include <typeindex>
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/platform/enforce.h"
+
+#include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/float16.h"
 
 namespace paddle {

@@ -36,15 +38,16 @@ struct DataTypeTrait<void> {
 #define _ForEachDataTypeHelper_(callback, cpp_type, proto_type) \
   callback(cpp_type, ::paddle::framework::proto::VarType::proto_type);
 
-#define _ForEachDataType_(callback)                                      \
-  _ForEachDataTypeHelper_(callback, float, FP32);                        \
-  _ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16);  \
-  _ForEachDataTypeHelper_(callback, double, FP64);                       \
-  _ForEachDataTypeHelper_(callback, int, INT32);                         \
-  _ForEachDataTypeHelper_(callback, int64_t, INT64);                     \
-  _ForEachDataTypeHelper_(callback, bool, BOOL);                         \
-  _ForEachDataTypeHelper_(callback, uint8_t, UINT8);                     \
-  _ForEachDataTypeHelper_(callback, int16_t, INT16);                     \
+#define _ForEachDataType_(callback)                                       \
+  _ForEachDataTypeHelper_(callback, float, FP32);                         \
+  _ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16);   \
+  _ForEachDataTypeHelper_(callback, ::paddle::platform::bfloat16, BF16);  \
+  _ForEachDataTypeHelper_(callback, double, FP64);                        \
+  _ForEachDataTypeHelper_(callback, int, INT32);                          \
+  _ForEachDataTypeHelper_(callback, int64_t, INT64);                      \
+  _ForEachDataTypeHelper_(callback, bool, BOOL);                          \
+  _ForEachDataTypeHelper_(callback, uint8_t, UINT8);                      \
+  _ForEachDataTypeHelper_(callback, int16_t, INT16);                      \
   _ForEachDataTypeHelper_(callback, int8_t, INT8)
 
 #define _ForEachDataTypeSmall_(callback) \
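This hunk is the crux of the commit: _ForEachDataType_ is an X-macro, so visitors built on it pick up BF16 automatically once the pair is in the list. A standalone sketch of the pattern with simplified, hypothetical names:

#include <cstdint>
#include <iostream>

// The type list is defined once; callers inject behavior via a callback
// macro that is expanded for each (cpp_type, id) pair. Adding a type to
// the list threads it through every expansion site.
#define FOR_EACH_TYPE(callback) \
  callback(float, 0);           \
  callback(double, 1);          \
  callback(int32_t, 2)

#define PRINT_SIZE(cpp_type, id) \
  std::cout << "type " << (id) << ": " << sizeof(cpp_type) << " bytes\n"

int main() {
  FOR_EACH_TYPE(PRINT_SIZE);  // expands into three print statements
  return 0;
}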

paddle/fluid/framework/data_type_test.cc (+22)

@@ -38,3 +38,25 @@ TEST(DataType, float16) {
   std::string type = "::paddle::platform::float16";
   EXPECT_STREQ(f::DataTypeToString(dtype).c_str(), type.c_str());
 }
+
+TEST(DataType, bfloat16) {
+  using paddle::framework::Tensor;
+  using paddle::platform::CPUPlace;
+  using paddle::platform::bfloat16;
+  namespace f = paddle::framework;
+  f::proto::VarType::Type dtype = f::proto::VarType::BF16;
+
+  Tensor tensor;
+  CPUPlace cpu;
+  tensor.mutable_data(cpu, dtype);
+
+  // test bf16 tensor
+  EXPECT_EQ(tensor.type(), f::ToDataType(typeid(bfloat16)));
+
+  // test bf16 size
+  EXPECT_EQ(f::SizeOfType(dtype), 2u);
+
+  // test debug info
+  std::string type = "::paddle::platform::bfloat16";
+  EXPECT_STREQ(f::DataTypeToString(dtype).c_str(), type.c_str());
+}

paddle/fluid/framework/data_type_transform.cc (+4)

@@ -77,6 +77,10 @@ void TransDataType(const OpKernelType& kernel_type_for_var,
       framework::VisitDataType(dst_type,
                                CastDataType<platform::float16>(in, out, ctx));
       break;
+    case proto::VarType::BF16:
+      framework::VisitDataType(dst_type,
+                               CastDataType<platform::bfloat16>(in, out, ctx));
+      break;
     case proto::VarType::FP32:
       framework::VisitDataType(dst_type, CastDataType<float>(in, out, ctx));
       break;
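TransDataType gains a BF16 source case so casts route through CastDataType<platform::bfloat16>. Stripped of device contexts and tensor plumbing, a CastDataType-style functor on CPU reduces to an element-wise static_cast; a simplified standalone sketch, not the real functor's interface:

#include <algorithm>
#include <vector>

template <typename InT, typename OutT>
std::vector<OutT> CastBuffer(const std::vector<InT>& in) {
  std::vector<OutT> out(in.size());
  // Element-wise conversion: for bfloat16 this is where precision is lost
  // (float -> bf16) or exactly recovered (bf16 -> float).
  std::transform(in.begin(), in.end(), out.begin(),
                 [](InT v) { return static_cast<OutT>(v); });
  return out;
}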

paddle/fluid/framework/data_type_transform_test.cc (+121)

@@ -24,6 +24,11 @@ TEST(DataTypeTransform, CPUTransform) {
       paddle::framework::DataLayout::kAnyLayout,
       paddle::framework::LibraryType::kPlain);
 
+  auto kernel_bf16 = paddle::framework::OpKernelType(
+      paddle::framework::proto::VarType::BF16, place,
+      paddle::framework::DataLayout::kAnyLayout,
+      paddle::framework::LibraryType::kPlain);
+
   auto kernel_fp32 = paddle::framework::OpKernelType(
       paddle::framework::proto::VarType::FP32, place,
       paddle::framework::DataLayout::kAnyLayout,

@@ -189,4 +194,120 @@ TEST(DataTypeTransform, CPUTransform) {
                 static_cast<paddle::platform::float16>(in_data_bool[i]).x);
     }
   }
+
+  // data type transform from/to bfloat16
+  {
+    paddle::framework::Tensor in;
+    paddle::framework::Tensor out;
+
+    paddle::platform::bfloat16* ptr =
+        in.mutable_data<paddle::platform::bfloat16>(
+            paddle::framework::make_ddim({2, 3}), place);
+    int data_number = 2 * 3;
+
+    for (int i = 0; i < data_number; ++i) {
+      ptr[i] = i;
+    }
+
+    // transform from bfloat16 to other data types
+    paddle::framework::TransDataType(kernel_bf16, kernel_fp32, in, &out);
+    float* out_data_float = out.data<float>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(out_data_float[i], static_cast<float>(ptr[i]));
+    }
+
+    paddle::framework::TransDataType(kernel_bf16, kernel_fp64, in, &out);
+    double* out_data_double = out.data<double>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(out_data_double[i], static_cast<double>(ptr[i]));
+    }
+
+    paddle::framework::TransDataType(kernel_bf16, kernel_int32, in, &out);
+    int* out_data_int = out.data<int>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(out_data_int[i], static_cast<int>(ptr[i]));
+    }
+
+    paddle::framework::TransDataType(kernel_bf16, kernel_int64, in, &out);
+    int64_t* out_data_int64 = out.data<int64_t>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(out_data_int64[i], static_cast<int64_t>(ptr[i]));
+    }
+
+    paddle::framework::TransDataType(kernel_bf16, kernel_bool, in, &out);
+    bool* out_data_bool = out.data<bool>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(out_data_bool[i], static_cast<bool>(ptr[i]));
+    }
+
+    // transform float to bfloat16
+    float* in_data_float =
+        in.mutable_data<float>(paddle::framework::make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_float[i] = i;
+    }
+
+    paddle::framework::TransDataType(kernel_fp32, kernel_bf16, in, &out);
+    ptr = out.data<paddle::platform::bfloat16>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(ptr[i].x,
+                static_cast<paddle::platform::bfloat16>(in_data_float[i]).x);
+    }
+
+    // transform double to bfloat16
+    double* in_data_double =
+        in.mutable_data<double>(paddle::framework::make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_double[i] = i;
+    }
+
+    paddle::framework::TransDataType(kernel_fp64, kernel_bf16, in, &out);
+    ptr = out.data<paddle::platform::bfloat16>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(ptr[i].x,
+                static_cast<paddle::platform::bfloat16>(in_data_double[i]).x);
+    }
+
+    // transform int to bfloat16
+    int* in_data_int =
+        in.mutable_data<int>(paddle::framework::make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_int[i] = i;
+    }
+
+    paddle::framework::TransDataType(kernel_int32, kernel_bf16, in, &out);
+    ptr = out.data<paddle::platform::bfloat16>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(ptr[i].x,
+                static_cast<paddle::platform::bfloat16>(in_data_int[i]).x);
+    }
+
+    // transform int64 to bfloat16
+    int64_t* in_data_int64 =
+        in.mutable_data<int64_t>(paddle::framework::make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_int64[i] = i;
+    }
+
+    paddle::framework::TransDataType(kernel_int64, kernel_bf16, in, &out);
+    ptr = out.data<paddle::platform::bfloat16>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(ptr[i].x,
+                static_cast<paddle::platform::bfloat16>(in_data_int64[i]).x);
+    }
+
+    // transform bool to bfloat16
+    bool* in_data_bool =
+        in.mutable_data<bool>(paddle::framework::make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_bool[i] = i;
+    }
+
+    paddle::framework::TransDataType(kernel_bool, kernel_bf16, in, &out);
+    ptr = out.data<paddle::platform::bfloat16>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(ptr[i].x,
+                static_cast<paddle::platform::bfloat16>(in_data_bool[i]).x);
+    }
+  }
 }

paddle/fluid/framework/details/nan_inf_utils_detail.cc (+2)

@@ -167,6 +167,8 @@ static void PrintNanInf(const T* value, const size_t numel, int print_num,
 // more detail see: 180 page of
 // https://www.openmp.org/wp-content/uploads/OpenMP4.0.0.pdf
 #pragma omp declare reduction(+ : paddle::platform::float16 : omp_out += omp_in)
+#pragma omp declare reduction(+ : paddle::platform::bfloat16 : omp_out += \
+                                  omp_in)
 #endif
 
 template <typename T>
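OpenMP's built-in `reduction(+ : ...)` only covers arithmetic types, so a user-defined type like bfloat16 needs a `declare reduction`, which is what this hunk adds. A self-contained sketch of the mechanism with a hypothetical struct (compile with -fopenmp; without an initializer clause, OpenMP default-constructs each thread's private copy):

#include <iostream>

struct Accum {
  float v;
  Accum() : v(0.f) {}  // private per-thread copies start at zero
  explicit Accum(float x) : v(x) {}
  Accum& operator+=(const Accum& o) {
    v += o.v;
    return *this;
  }
};

// Teach OpenMP how two partial results combine.
#pragma omp declare reduction(+ : Accum : omp_out += omp_in)

int main() {
  Accum sum;
#pragma omp parallel for reduction(+ : sum)
  for (int i = 0; i < 1000; ++i) {
    sum += Accum(1.f);
  }
  std::cout << sum.v << "\n";  // 1000
  return 0;
}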

paddle/fluid/framework/dlpack_tensor.cc (+1)

@@ -23,6 +23,7 @@ template <typename T>
 static ::DLDataType GetDLDataTypeCode() {
   ::DLDataType dtype;
   if (std::is_same<T, platform::float16>::value ||
+      std::is_same<T, platform::bfloat16>::value ||
       std::is_floating_point<T>::value) {
     dtype.code = kDLFloat;
   } else if (std::is_unsigned<T>::value) {
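The explicit whitelist is needed because std::is_floating_point is false for any non-built-in type, so float16 and bfloat16 must be named to be tagged kDLFloat. A compile-time sketch with a hypothetical stand-in type (the numeric codes here are assumptions, not DLPack's actual values):

#include <cstdint>
#include <type_traits>

struct fake_bf16 { uint16_t x; };

enum TypeCode { kIntLike = 0, kUIntLike = 1, kFloatLike = 2 };

template <typename T>
constexpr TypeCode DTypeCode() {
  return (std::is_same<T, fake_bf16>::value ||
          std::is_floating_point<T>::value)
             ? kFloatLike
             : (std::is_unsigned<T>::value ? kUIntLike : kIntLike);
}

// Without the explicit is_same check, fake_bf16 would fall through to the
// signed-integer branch.
static_assert(DTypeCode<float>() == kFloatLike, "built-in float");
static_assert(DTypeCode<fake_bf16>() == kFloatLike, "needs the whitelist");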

paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc (+1 -27)

@@ -90,32 +90,6 @@ void MemoryOptimizePass::CollectLifeCycle(
   }
 }
 
-// TODO(Superjomn) Make this a general help method.
-int DataTypeToSpace(framework::proto::VarType_Type type) {
-  switch (type) {
-    case framework::proto::VarType_Type_BOOL:
-      return sizeof(bool);
-    case framework::proto::VarType_Type_FP32:
-      return sizeof(float);
-    case framework::proto::VarType_Type_INT32:
-      return sizeof(int32_t);
-    case framework::proto::VarType_Type_INT64:
-      return sizeof(int64_t);
-    case framework::proto::VarType_Type_INT16:
-      return sizeof(int16_t);
-    case framework::proto::VarType_Type_FP16:
-      return sizeof(int16_t);
-    case framework::proto::VarType_Type_FP64:
-      return sizeof(double);
-    case framework::proto::VarType_Type_UINT8:
-      return sizeof(unsigned char);
-    case framework::proto::VarType_Type_INT8:
-      return sizeof(int8_t);
-    default:
-      PADDLE_THROW("Unknown data type");
-  }
-}
-
 void MemoryOptimizePass::CollectVarMemorySize(
     space_table_t* space_table) const {
   const int fake_batch_size = 1;

@@ -163,7 +137,7 @@ void MemoryOptimizePass::CollectVarMemorySize(
     int size = std::accumulate(shape.begin(), shape.end(), 1,
                                std::multiplies<int>());
     (*space_table)[node->Var()->Name()] =
-        size * DataTypeToSpace(node->Var()->GetDataType());
+        size * paddle::framework::SizeOfType(node->Var()->GetDataType());
     }
   }
 }
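Dropping the pass-local DataTypeToSpace switch in favor of framework::SizeOfType means this pass no longer needs editing when a type such as bfloat16 is added; the size table is maintained in one place. The computation it feeds is just numel times element size, restated as a trivial standalone helper:

#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

size_t TensorBytes(const std::vector<int>& shape, size_t element_size) {
  // numel = product of all dimensions (1 for an empty shape).
  int numel = std::accumulate(shape.begin(), shape.end(), 1,
                              std::multiplies<int>());
  return static_cast<size_t>(numel) * element_size;
}
// e.g. TensorBytes({2, 3}, /*sizeof bfloat16*/ 2) == 12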

paddle/fluid/inference/lite/test_engine.cc (+3 -2)

@@ -14,15 +14,16 @@
 
 #include <gtest/gtest.h>
 
-#include "paddle/fluid/inference/lite/engine.h"
 #include "paddle/fluid/inference/utils/singleton.h"
-#include "paddle/fluid/operators/lite/ut_helper.h"
 
 #include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/scope.h"
 
+#include "paddle/fluid/inference/lite/engine.h"
+#include "paddle/fluid/operators/lite/ut_helper.h"
+
 namespace paddle {
 namespace inference {
 namespace lite {

paddle/fluid/operators/math/concat_and_split.h (+11 -10)

@@ -65,13 +65,14 @@ class SplitFunctor {
 }  // namespace operators
 }  // namespace paddle
 
-#define FOR_ALL_TYPES(macro) \
-  macro(int);                \
-  macro(float);              \
-  macro(double);             \
-  macro(bool);               \
-  macro(int64_t);            \
-  macro(int16_t);            \
-  macro(uint8_t);            \
-  macro(int8_t);             \
-  macro(::paddle::platform::float16)
+#define FOR_ALL_TYPES(macro)          \
+  macro(int);                         \
+  macro(float);                       \
+  macro(double);                      \
+  macro(bool);                        \
+  macro(int64_t);                     \
+  macro(int16_t);                     \
+  macro(uint8_t);                     \
+  macro(int8_t);                      \
+  macro(::paddle::platform::float16); \
+  macro(::paddle::platform::bfloat16)

paddle/fluid/operators/math/math_function.cc (+13 -10)

@@ -34,23 +34,26 @@ namespace math {
 using float16 = paddle::platform::float16;
 
 template struct SetConstant<platform::CPUDeviceContext, platform::float16>;
+template struct SetConstant<platform::CPUDeviceContext, platform::bfloat16>;
 template struct SetConstant<platform::CPUDeviceContext, float>;
 template struct SetConstant<platform::CPUDeviceContext, double>;
 template struct SetConstant<platform::CPUDeviceContext, int>;
 template struct SetConstant<platform::CPUDeviceContext, int64_t>;
 template struct SetConstant<platform::CPUDeviceContext, bool>;
 template struct SetConstant<platform::CPUDeviceContext, uint8_t>;
 
-#define DEFINE_CPU_TRANS(RANK)                                             \
-  template struct Transpose<platform::CPUDeviceContext, platform::float16, \
-                            RANK>;                                         \
-  template struct Transpose<platform::CPUDeviceContext, float, RANK>;      \
-  template struct Transpose<platform::CPUDeviceContext, double, RANK>;     \
-  template struct Transpose<platform::CPUDeviceContext, int, RANK>;        \
-  template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>;    \
-  template struct Transpose<platform::CPUDeviceContext, bool, RANK>;       \
-  template struct Transpose<platform::CPUDeviceContext, int16_t, RANK>;    \
-  template struct Transpose<platform::CPUDeviceContext, uint8_t, RANK>;    \
+#define DEFINE_CPU_TRANS(RANK)                                               \
+  template struct Transpose<platform::CPUDeviceContext, platform::float16,   \
+                            RANK>;                                           \
+  template struct Transpose<platform::CPUDeviceContext, platform::bfloat16,  \
+                            RANK>;                                           \
+  template struct Transpose<platform::CPUDeviceContext, float, RANK>;        \
+  template struct Transpose<platform::CPUDeviceContext, double, RANK>;       \
+  template struct Transpose<platform::CPUDeviceContext, int, RANK>;          \
+  template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>;      \
+  template struct Transpose<platform::CPUDeviceContext, bool, RANK>;         \
+  template struct Transpose<platform::CPUDeviceContext, int16_t, RANK>;      \
+  template struct Transpose<platform::CPUDeviceContext, uint8_t, RANK>;      \
   template struct Transpose<platform::CPUDeviceContext, int8_t, RANK>;
 
 DEFINE_CPU_TRANS(1);
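The `template struct ...;` lines above are explicit instantiations: the templates' definitions live in this .cc file, so each supported element type must be instantiated here for other translation units to link against, which is why bfloat16 needs its own SetConstant and Transpose lines. A minimal sketch of the mechanism with a hypothetical functor:

#include <vector>

template <typename T>
struct Fill {
  void operator()(std::vector<T>* out, T value) const {
    for (auto& v : *out) v = value;
  }
};

// Force code generation for the supported types; without these lines a
// caller using Fill<float> from another translation unit would hit an
// undefined-symbol error at link time.
template struct Fill<float>;
template struct Fill<double>;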
