Skip to content

Commit 805cc7a

Browse files
authored
[DCU] LLM train surpport bf16 (#65857)
* [DCU] New features for LLM * fix WeightQuantizeInferMeta * [DCU] fix flashattn cmake * fix * fix code style * [DCU] add FLAGS_batch_norm_use_miopen * fix * fix setup.py.in * [DCU] LLM bf16 * fix code style * fix * fix
1 parent 3e95aa9 commit 805cc7a

28 files changed

+61
-37
lines changed

paddle/fluid/operators/collective/alltoall_op.cu.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,8 @@ PD_REGISTER_STRUCT_KERNEL(alltoall,
146146
ops::AllToAllOpCUDAKernel,
147147
float,
148148
double,
149-
#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
149+
#if (NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000) || \
150+
defined(PADDLE_WITH_HIP)
150151
phi::dtype::bfloat16,
151152
#endif
152153
int,

paddle/fluid/operators/collective/c_allgather_op.cu.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,8 @@ PD_REGISTER_STRUCT_KERNEL(c_allgather,
129129
ops::CAllGatherOpCUDAKernel,
130130
float,
131131
double,
132-
#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
132+
#if (NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000) || \
133+
defined(PADDLE_WITH_HIP)
133134
phi::dtype::bfloat16,
134135
#endif
135136
int,

paddle/fluid/operators/collective/c_allreduce_max_op.cu.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ PD_REGISTER_STRUCT_KERNEL(c_allreduce_max,
2727
ALL_LAYOUT,
2828
ops::CAllReduceMaxCUDAKernel,
2929
float,
30-
#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
30+
#if (NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000) || \
31+
defined(PADDLE_WITH_HIP)
3132
phi::dtype::bfloat16,
3233
#endif
3334
double,

paddle/fluid/operators/collective/c_allreduce_op.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,8 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
411411
nccl_red_type = ncclProd;
412412
break;
413413

414-
#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
414+
#if (NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000) || \
415+
defined(PADDLE_WITH_HIP)
415416
case kRedAvg:
416417
nccl_red_type = ncclAvg;
417418
break;

paddle/fluid/operators/collective/c_allreduce_sum_op.cu.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ PD_REGISTER_STRUCT_KERNEL(c_allreduce_sum,
2727
ALL_LAYOUT,
2828
ops::CAllReduceSumCUDAKernel,
2929
float,
30-
#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
30+
#if (NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000) || \
31+
defined(PADDLE_WITH_HIP)
3132
phi::dtype::bfloat16,
3233
#endif
3334
double,

paddle/fluid/operators/collective/c_broadcast_op.cu.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ PD_REGISTER_STRUCT_KERNEL(c_broadcast,
110110
int64_t,
111111
float,
112112
double,
113-
#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
113+
#if (NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000) || \
114+
defined(PADDLE_WITH_HIP)
114115
phi::dtype::bfloat16,
115116
#endif
116117
phi::dtype::float16) {

paddle/fluid/operators/collective/c_concat_op.cu.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,8 @@ PD_REGISTER_STRUCT_KERNEL(c_concat,
174174
double,
175175
int,
176176
int64_t,
177-
#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
177+
#if (NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000) || \
178+
defined(PADDLE_WITH_HIP)
178179
phi::dtype::bfloat16,
179180
#endif
180181
phi::dtype::float16) {

paddle/fluid/operators/collective/c_reduce_op.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,8 @@ class CReduceOpCUDAKernel : public framework::OpKernel<T> {
302302
nccl_red_type = ncclProd;
303303
break;
304304

305-
#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
305+
#if (NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000) || \
306+
defined(PADDLE_WITH_HIP)
306307
case kRedAvg:
307308
nccl_red_type = ncclAvg;
308309
break;

paddle/fluid/operators/collective/c_reducescatter_op.cu.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,8 @@ PD_REGISTER_STRUCT_KERNEL(c_reducescatter,
134134
ops::CReduceScatterOpCUDAKernel,
135135
float,
136136
double,
137-
#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
137+
#if (NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000) || \
138+
defined(PADDLE_WITH_HIP)
138139
phi::dtype::bfloat16,
139140
#endif
140141
int,

paddle/fluid/operators/collective/mp_allreduce_sum_op.cu.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ PD_REGISTER_STRUCT_KERNEL(mp_allreduce_sum,
3131
double,
3232
int,
3333
int64_t,
34-
#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
34+
#if (NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000) || \
35+
defined(PADDLE_WITH_HIP)
3536
phi::dtype::bfloat16,
3637
#endif
3738
phi::dtype::float16) {

0 commit comments

Comments
 (0)