Commit 28aa0c6

[DCU] Fix NaN problem when training BERT on DCU platform (#44643)
1 parent e7c7280 commit 28aa0c6

File tree

2 files changed: +12 -0 lines changed

paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu

Lines changed: 8 additions & 0 deletions

@@ -166,7 +166,11 @@ static void MultiTensorL2Norm(const platform::CUDAPlace &place,
 
   constexpr int kNumTensor = MaxTensorNumPerLaunch;
   constexpr int kNumChunk = MaxChunkNumPerLaunch;
+#ifdef PADDLE_WITH_HIP
+  constexpr int kBlockDim = 256;
+#else
   constexpr int kBlockDim = 512;
+#endif
 
   int max_chunk_num = -1;
   int vec_size = 8;

@@ -805,7 +809,11 @@ static void MultiTensorUpdateLambParamAndBetaPows(
         platform::errors::InvalidArgument("Beta2Pow should be nullptr."));
   }
 
+#ifdef PADDLE_WITH_HIP
+  const int block_dim = 256;
+#else
   const int block_dim = 512;
+#endif
 
   int vec_size = 8;
   for (int i = 0; i < n; ++i) {
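The two hunks above make the same change: under PADDLE_WITH_HIP (the DCU/ROCm build) the thread-block size used by the fused LAMB kernels drops from 512 to 256. The commit message only says this fixes a NaN during BERT training; a plausible reading is that the smaller block keeps the kernels' per-warp bookkeeping within bounds on hardware with 64-lane wavefronts. The sketch below only illustrates the general pattern of a compile-time, platform-conditional block size driving both the kernel template parameter and the launch configuration; the kernel and launcher names are hypothetical stand-ins (not Paddle's code), and the snippet is written in CUDA syntax.

// A minimal sketch, assuming a squared-sum reduction similar in spirit to
// MultiTensorL2Norm. Names here are illustrative, not Paddle APIs.
#include <cuda_runtime.h>

#ifdef PADDLE_WITH_HIP
constexpr int kBlockDim = 256;  // DCU / ROCm build
#else
constexpr int kBlockDim = 512;  // CUDA build
#endif

template <int BlockDim>
__global__ void SquaredSumKernel(const float *x, float *partial, int n) {
  __shared__ float cache[BlockDim];
  int tid = threadIdx.x;
  float sum = 0.f;
  // Grid-stride loop over the input.
  for (int i = blockIdx.x * BlockDim + tid; i < n; i += gridDim.x * BlockDim) {
    sum += x[i] * x[i];
  }
  cache[tid] = sum;
  __syncthreads();
  // Block-wide tree reduction; makes no assumption about the warp size.
  for (int stride = BlockDim / 2; stride > 0; stride >>= 1) {
    if (tid < stride) cache[tid] += cache[tid + stride];
    __syncthreads();
  }
  if (tid == 0) partial[blockIdx.x] = cache[0];
}

void LaunchSquaredSum(const float *x, float *partial, int n, int grid_dim,
                      cudaStream_t stream) {
  SquaredSumKernel<kBlockDim><<<grid_dim, kBlockDim, 0, stream>>>(x, partial, n);
}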

paddle/fluid/platform/device/gpu/rocm/rocm_device_function.h

Lines changed: 4 additions & 0 deletions

@@ -134,7 +134,11 @@ __device__ T reduceSum(T val, int tid, int len) {
   // I use Warp-Level Parallelism and assume the Warp size
   // is 32 which may be different for different GPU,
   // but most card's warp size is 32.
+#ifdef PADDLE_WITH_HIP
+  const int warpSize = 64;
+#else
   const int warpSize = 32;
+#endif
   __shared__ T shm[warpSize];
   unsigned mask = 0u;
   CREATE_SHFL_MASK(mask, tid < len);
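This hunk widens the lane-count assumption in reduceSum. NVIDIA warps are 32 threads, but AMD GCN/CDNA wavefronts (which DCU hardware reportedly shares) are 64 threads wide; a shuffle reduction that only walks 32 lanes on a 64-lane wavefront leaves half of the partial sums uncombined, the kind of silent corruption that can later surface as NaN losses. Since this header lives under rocm/, the #else branch is presumably never compiled; the guard mainly keeps the constant explicit. Below is a small, self-contained sketch of a wavefront-size-aware shuffle reduction; the helper name and kWarpSize constant are illustrative, not Paddle's reduceSum.

// A minimal sketch, assuming the platform-conditional lane count shown above.
#ifdef PADDLE_WITH_HIP
constexpr int kWarpSize = 64;   // DCU / AMD wavefront
#else
constexpr int kWarpSize = 32;   // NVIDIA warp
#endif

template <typename T>
__device__ T WarpReduceSum(T val) {
  // Halving strides over all kWarpSize lanes; stopping at 32 on a 64-lane
  // wavefront would leave half of the partial sums uncombined.
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
#ifdef PADDLE_WITH_HIP
    val += __shfl_down(val, offset, kWarpSize);
#else
    val += __shfl_down_sync(0xffffffffu, val, offset, kWarpSize);
#endif
  }
  return val;
}

A block-level reduction built on top of such a helper then needs only warpSize-many shared-memory slots per block, which is what the `__shared__ T shm[warpSize]` context line in the diff provides once warpSize matches the hardware.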
