@@ -90,10 +90,13 @@ constexpr int kMaxThread = 256;
9090#else
9191constexpr int kMaxThread = 128;
9292#endif
93+ constexpr int warp_size = 32;
9394
9495// get blockDim for reduceLastDim and reduceAny
9596static inline int GetBlockDim (int block_dim) {
96-  return block_dim >= kMaxThread ? kMaxThread : GetLastPow2(block_dim);
97+  return block_dim >= kMaxThread
98+             ? kMaxThread
99+             : (block_dim <= warp_size ? warp_size : GetLastPow2(block_dim));
97100}
98101
99102// check reduce rand is valid
@@ -393,26 +396,62 @@ struct ReduceConfig {
393396 dim3 grid;
394397};
395398
399+ // version 1
396400template <typename T, typename ReduceOp>
397-__device__ __forceinline__ T BlockReduce(T* shared, T val, ReduceOp reducer) {
401+__device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) {
402+  unsigned mask = 0u;
403+  CREATE_SHFL_MASK(mask, true);
404+  for (int stride = warpSize / 2; stride > 0; stride >>= 1) {
405+    T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
406+    val = reducer(val, temp);
407+  }
408+  return val;
409+}
410+
411+ template <typename T, typename ReduceOp>
412+__device__ __forceinline__ T BlockReduce(T val, T init, ReduceOp reducer) {
413+  __shared__ T shared[32];
414+  int lane = threadIdx.x % warpSize;
415+  int wid = threadIdx.x / warpSize;
416+
417+  val = WarpReduce(val, reducer);
418+
419+  if (lane == 0) {
420+    shared[wid] = val;
421+  }
422+
423+  __syncthreads();
424+
425+  val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : init;
426+
427+  if (wid == 0) {
428+    val = WarpReduce(val, reducer);
429+  }
430+  return val;
431+}
432+
433+ // version 2
434+ template <typename T, typename ReduceOp>
435+__device__ __forceinline__ T BlockReduce(T val, ReduceOp reducer) {
436+  __shared__ T shared[detail::kMaxThread];
398437  constexpr int warp_size = 32;
399438  if (blockDim.x > warp_size) {
400439    shared[threadIdx.x] = val;
401-  }
402-  for (int offset = blockDim.x / 2; offset >= warp_size; offset >>= 1) {
403-    __syncthreads();
404-    if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) {
405-      T temp = shared[threadIdx.x + offset];
406-      val = reducer(val, temp);
407-      shared[threadIdx.x] = val;
440+    for (int stride = blockDim.x / 2; stride >= warp_size; stride >>= 1) {
441+      __syncthreads();
442+      if (threadIdx.x < stride) {
443+        T temp = shared[threadIdx.x + stride];
444+        val = reducer(val, temp);
445+        shared[threadIdx.x] = val;
446+      }
408447    }
409448  }
410449  __syncthreads();
411450
412451  unsigned mask = 0u;
413452  CREATE_SHFL_MASK(mask, true);
414-  for (int offset = warp_size / 2; offset > 0; offset >>= 1) {
415-    T temp = paddle::platform::CudaShuffleDownSync(mask, val, offset);
453+  for (int stride = warp_size / 2; stride > 0; stride >>= 1) {
454+    T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
416455    val = reducer(val, temp);
417456  }
418457  return val;
@@ -426,7 +465,6 @@ __device__ __forceinline__ void ReduceLastDim(const Tx* x, Ty* y,
426465 ReduceOp reducer,
427466 TransformOp transformer, Ty init,
428467 int reduce_num) {
429-  __shared__ Ty shared_memory[detail::kMaxThread];
430468 int idx_x = blockIdx.x * reduce_num;
431469 int idx_y = threadIdx.x ;
432470 Ty reduce_var = init;
@@ -436,7 +474,7 @@ __device__ __forceinline__ void ReduceLastDim(const Tx* x, Ty* y,
436474 }
437475 __syncthreads ();
438476
439- reduce_var = BlockReduce (shared_memory, reduce_var, reducer);
477+ reduce_var = BlockReduce (reduce_var, reducer);
440478
441479 if (threadIdx.x == 0 ) {
442480 y[blockIdx.x ] = reduce_var;
@@ -485,7 +523,6 @@ __device__ __forceinline__ void ReduceAny(
485523 paddle::framework::Array<int , ReduceRank> reduce_strides,
486524 paddle::framework::Array<int , Rank - ReduceRank> left_dim,
487525 paddle::framework::Array<int , Rank - ReduceRank> left_strides) {
488-  __shared__ Ty shared_memory[detail::kMaxThread];
489526 int sub_index[Rank];
490527 int left_idx = blockIdx.x ;
491528 for (int i = 0 ; i < Rank - ReduceRank; ++i) {
@@ -523,7 +560,7 @@ __device__ __forceinline__ void ReduceAny(
523560 }
524561 __syncthreads ();
525562
526- reduce_var = BlockReduce (shared_memory, reduce_var, reducer);
563+ reduce_var = BlockReduce (reduce_var, reducer);
527564 if (threadIdx.x == 0 ) {
528565 y[blockIdx.x ] = reduce_var;
529566 }
0 commit comments