Commit ca4df33

[bf16] add bf16 kernel: elementwise_div (#39602)
* add elementwise_div
* refine rocm
* refine code
* refine op register
* solve conflict
* refine unittest
* refine unittest precision
* add rocm
1 parent 1fcaab4

File tree

5 files changed (+64 lines, -1 line)

paddle/fluid/operators/elementwise/elementwise_div_op.cu

Lines changed: 6 additions & 0 deletions

@@ -53,6 +53,8 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, float>,
     ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
                               paddle::platform::float16>,
+    ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
+                              paddle::platform::bfloat16>,
     ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, double>,
     ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int>,
     ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int64_t>,
@@ -65,6 +67,8 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, float>,
     ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
                                   paddle::platform::float16>,
+    ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
+                                  paddle::platform::bfloat16>,
     ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, double>,
     ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int>,
     ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
@@ -78,6 +82,8 @@ REGISTER_OP_CUDA_KERNEL(
                                         float>,
     ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                         paddle::platform::float16>,
+    ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
+                                        paddle::platform::bfloat16>,
     ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                         double>,
     ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
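
For context, a minimal usage sketch (not part of this commit) of how the newly registered bfloat16 kernel would be exercised from Python. It assumes a Paddle GPU wheel built with CUDA >= 11 / cuDNN >= 8.1 and that the build accepts 'bfloat16' as a cast target; treat it as illustrative only.

# Hypothetical usage sketch, not part of this commit.
import paddle

paddle.set_device('gpu')
x = paddle.cast(paddle.rand([12, 13]) + 0.1, 'bfloat16')
y = paddle.cast(paddle.rand([12, 13]) + 0.1, 'bfloat16')
out = paddle.divide(x, y)  # should dispatch to the bf16 elementwise_div CUDA kernel
print(out.dtype)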

paddle/fluid/platform/device/gpu/cuda/cuda_device_function.h

Lines changed: 12 additions & 0 deletions

@@ -105,6 +105,18 @@ __forceinline__ __device__ float16 CudaShuffleXorSync(unsigned mask,
   return float16(__shfl_xor_sync(mask, val.to_half(), width));
 }
 
+template <>
+__forceinline__ __device__ bfloat16 CudaShuffleXorSync(unsigned mask,
+                                                        bfloat16 val,
+                                                        int width) {
+#if defined(PADDLE_CUDA_BF16)
+  return bfloat16(__shfl_xor_sync(mask, static_cast<nv_bfloat16>(val), width));
+#else
+  PADDLE_ENFORCE(
+      false, "__shfl_xor_sync with bfloat16 is not supported on cuda <= 11.");
+#endif
+}
+
 template <>
 __forceinline__ __device__ paddle::platform::complex<float> CudaShuffleXorSync(
     unsigned mask, paddle::platform::complex<float> val, int width) {

paddle/fluid/platform/device/gpu/rocm/rocm_device_function.h

Lines changed: 7 additions & 0 deletions

@@ -91,6 +91,13 @@ __forceinline__ __device__ float16 CudaShuffleXorSync(unsigned mask,
   return float16(__shfl_xor(static_cast<float>(val), width));
 }
 
+template <>
+__forceinline__ __device__ bfloat16 CudaShuffleXorSync(unsigned mask,
+                                                       bfloat16 val,
+                                                       int width) {
+  return bfloat16(__shfl_xor(static_cast<float>(val), width));
+}
+
 template <>
 __forceinline__ __device__ paddle::platform::complex<float> CudaShuffleXorSync(
     unsigned mask, paddle::platform::complex<float> val, int width) {

paddle/phi/kernels/gpu/math_kernel.cu

Lines changed: 2 additions & 0 deletions

@@ -92,6 +92,7 @@ DEFINE_CUDA_ELEMENTWISE_OP(Divide)
 }  // namespace phi
 
 using float16 = phi::dtype::float16;
+using bfloat16 = phi::dtype::bfloat16;
 using complex64 = ::phi::dtype::complex<float>;
 using complex128 = ::phi::dtype::complex<double>;
 
@@ -128,6 +129,7 @@ PD_REGISTER_KERNEL(divide_raw,
                    int,
                    int64_t,
                    float16,
+                   bfloat16,
                    complex64,
                    complex128) {}
 PD_REGISTER_KERNEL(multiply_raw,

python/paddle/fluid/tests/unittests/test_elementwise_div_op.py

Lines changed: 37 additions & 1 deletion

@@ -18,7 +18,7 @@
 import paddle
 import paddle.fluid as fluid
 import paddle.fluid.core as core
-from op_test import OpTest, skip_check_grad_ci
+from op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16
 
 
 class ElementwiseDivOp(OpTest):
@@ -55,6 +55,42 @@ def init_dtype(self):
         pass
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda() or core.cudnn_version() < 8100,
+    "core is not compiled with CUDA and cudnn version need larger than 8.1.0")
+class TestElementwiseDivOpBF16(OpTest):
+    def setUp(self):
+        self.op_type = "elementwise_div"
+        self.dtype = np.uint16
+
+        x = np.random.uniform(0.1, 1, [12, 13]).astype(np.float32)
+        y = np.random.uniform(0.1, 1, [12, 13]).astype(np.float32)
+
+        out = np.divide(x, y)
+
+        self.inputs = {
+            'X': convert_float_to_uint16(x),
+            'Y': convert_float_to_uint16(y)
+        }
+        self.outputs = {'Out': convert_float_to_uint16(out)}
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad_normal(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(place, ['X', 'Y'], 'Out')
+
+    def test_check_grad_ingore_x(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(place, ['Y'], 'Out', no_grad_set=set("X"))
+
+    def test_check_grad_ingore_y(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(place, ['X'], 'Out', no_grad_set=set('Y'))
+
+
 @skip_check_grad_ci(
     reason="[skip shape check] Use y_shape(1) to test broadcast.")
 class TestElementwiseDivOp_scalar(ElementwiseDivOp):
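
The test feeds bf16 data as uint16 arrays: conceptually, op_test's convert_float_to_uint16 keeps only the upper 16 bits of each float32 value, which is exactly the bfloat16 bit pattern. A minimal, illustrative re-implementation under that assumption (the real helper may round to nearest rather than truncate):

# Illustrative sketch only; not the actual op_test helper.
import numpy as np

def to_bf16_bits(x):
    # Keep the high 16 bits of the float32 representation: this is the
    # bfloat16 bit pattern, stored as uint16 (truncation, no rounding).
    x = np.asarray(x, dtype=np.float32)
    return np.right_shift(x.view(np.uint32), 16).astype(np.uint16)

def from_bf16_bits(u):
    # Zero-fill the low 16 bits to recover a float32 approximation.
    u = np.asarray(u, dtype=np.uint16)
    return np.left_shift(u.astype(np.uint32), 16).view(np.float32)

x = np.random.uniform(0.1, 1, [12, 13]).astype(np.float32)
y = np.random.uniform(0.1, 1, [12, 13]).astype(np.float32)
ref = np.divide(x, y)
bf16_ref = from_bf16_bits(to_bf16_bits(ref))
print(np.abs(bf16_ref - ref).max())  # bf16 keeps roughly 3 significant decimal digits

This round-trip also shows why the unit test compares in uint16 space and why a looser precision threshold is needed for bf16 than for float32.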
