@@ -57,6 +57,10 @@ constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
     *mod = dividend_copy % divisor;                  \
   } while (0)
 
+#define DIVUP(x, y) (((x) + (y)-1) / (y))
+
+#define ROUNDUP(x, y) (DIVUP((x), (y)) * (y))
+
 namespace paddle {
 namespace operators {
 
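Note: DIVUP is integer ceiling division, and ROUNDUP rounds its first argument up to the next multiple of the second; the kernel below uses ROUNDUP to pad the width w to a whole number of BLOCK_X-wide tiles so every thread of a block runs the same loop count. A quick compile-time check (a sketch for illustration, not part of the patch):

    static_assert(DIVUP(100, 32) == 4, "ceil(100 / 32)");
    static_assert(ROUNDUP(100, 32) == 128, "next multiple of 32 at or above 100");
    static_assert(ROUNDUP(128, 32) == 128, "exact multiples are unchanged");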
@@ -2581,106 +2585,129 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel(
     const T *x, const T *y, const T *intermediate_out, const T *out,
     const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op,
     DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
-  int j = blockIdx.x;
-  int i = threadIdx.x;
-  int tid = threadIdx.x;
-  T val(0), inter_val(0);
-  int64_t tmp_out_idx, x_idx, y_idx;
+  __shared__ T sdata[BLOCK_Y][BLOCK_X];
+  size_t idx = threadIdx.x + BLOCK_X * blockIdx.x;
+  size_t width_stride = gridDim.x * BLOCK_X;
+
+  size_t full_w = ROUNDUP(w, BLOCK_X);
+
   T zero = static_cast<T>(0);
 
-  do {
-    int offset = i * w + j;
+  for (size_t j = idx; j < full_w; j += width_stride) {
+    T val(0), inter_val(0);
+    if (j < w) {
+      for (size_t i = threadIdx.y; i < h; i += BLOCK_Y) {
+        size_t offset = i * w + j;
 
-    tmp_out_idx = BcastY ? j : offset;
-    y_idx = BcastY ? j : offset;
-    x_idx = BcastY ? offset : j;
-    T x_val = (x == nullptr) ? zero : x[x_idx];
-    T y_val = (y == nullptr) ? zero : y[y_idx];
+        size_t tmp_out_idx = BcastY ? j : offset;
+        size_t y_idx = BcastY ? j : offset;
+        size_t x_idx = BcastY ? offset : j;
+        T x_val = (x == nullptr) ? zero : x[x_idx];
+        T y_val = (y == nullptr) ? zero : y[y_idx];
 
-    if (SameShapeOfIntermediateOutAndOut) {
-      tmp_out_idx = offset;
-    }
+        if (SameShapeOfIntermediateOutAndOut) {
+          tmp_out_idx = offset;
+        }
 
-    if (dx != nullptr) {
-      T tmp = UseIntermediateOut
+        if (dx != nullptr) {
+          T tmp =
+              UseIntermediateOut
                   ? dx_op.UseIntermediateOut(x_val, y_val,
                                              intermediate_out[tmp_out_idx],
                                              out[offset], dout[offset])
                   : dx_op.Recompute(x_val, y_val, out[offset], dout[offset]);
 
-      if (BcastY) {
-        dx[x_idx] = tmp;
-      } else {
-        val += tmp;
-      }
-    }
-    if (dy != nullptr) {
-      T tmp = UseIntermediateOut
+          if (BcastY) {
+            dx[x_idx] = tmp;
+          } else {
+            val += tmp;
+          }
+        }
+        if (dy != nullptr) {
+          T tmp =
+              UseIntermediateOut
                   ? dy_op.UseIntermediateOut(x_val, y_val,
                                              intermediate_out[tmp_out_idx],
                                              out[offset], dout[offset])
                   : dy_op.Recompute(x_val, y_val, out[offset], dout[offset]);
-      if (BcastY) {
-        val += tmp;
-      } else {
-        dy[y_idx] = tmp;
-      }
-    }
-    if (d_intermediate != nullptr) {
-      T tmp = UseIntermediateOut
-                  ? dintermediate_op.UseIntermediateOut(
-                        y[y_idx], intermediate_out[tmp_out_idx], out[offset],
-                        dout[offset])
-                  : dintermediate_op.Recompute(x_val, y_val, out[offset],
-                                               dout[offset]);
-      if (SameShapeOfIntermediateOutAndOut) {
-        d_intermediate[tmp_out_idx] = tmp;
-      } else {
-        inter_val += tmp;
+          if (BcastY) {
+            val += tmp;
+          } else {
+            dy[y_idx] = tmp;
+          }
+        }
+        if (d_intermediate != nullptr) {
+          T tmp = UseIntermediateOut
+                      ? dintermediate_op.UseIntermediateOut(
+                            y[y_idx], intermediate_out[tmp_out_idx],
+                            out[offset], dout[offset])
+                      : dintermediate_op.Recompute(x_val, y_val, out[offset],
+                                                   dout[offset]);
+          if (SameShapeOfIntermediateOutAndOut) {
+            d_intermediate[tmp_out_idx] = tmp;
+          } else {
+            inter_val += tmp;
+          }
+        }
       }
     }
 
-    i += ELEMWISE_MAX_BLOCK_DIM;
-  } while (i < h);
+    // transpose, for ReduceSum with warp
+    sdata[threadIdx.y][threadIdx.x] = val;
+    __syncthreads();
+    val = sdata[threadIdx.x][threadIdx.y];
+#pragma unroll
+    for (int i = BLOCK_X >> 1; i > 0; i >>= 1) {
+      // reduce sum with warp shuffle
+      val += platform::CudaShuffleXorSync(0xFFFFFFFF, val, i);
+    }
 
-  h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
-  if (BcastY) {
-    if (dy) {
-      val = paddle::platform::reduceSum(val, tid, h);
-      if (threadIdx.x == 0) {
-        dy[j] = val;
+    size_t idx_j = j + threadIdx.y;
+    if (BcastY) {
+      if (dy) {
+        if (threadIdx.x == 0 && (idx_j < w)) dy[idx_j] = val;
       }
-    }
-  } else {
-    if (dx) {
-      val = paddle::platform::reduceSum(val, tid, h);
-      if (threadIdx.x == 0) {
-        dx[j] = val;
+    } else {
+      if (dx) {
+        if (threadIdx.x == 0 && (idx_j < w)) dx[idx_j] = val;
       }
     }
-  }
-  if (!SameShapeOfIntermediateOutAndOut) {
-    if (d_intermediate) {
-      inter_val = paddle::platform::reduceSum(inter_val, tid, h);
-      if (threadIdx.x == 0) {
-        d_intermediate[j] = inter_val;
+
+    if (!SameShapeOfIntermediateOutAndOut) {
+      if (d_intermediate) {
+        sdata[threadIdx.y][threadIdx.x] = inter_val;
+        __syncthreads();
+        inter_val = sdata[threadIdx.x][threadIdx.y];
+#pragma unroll
+        for (int i = BLOCK_X >> 1; i > 0; i >>= 1) {
+          // reduce sum with warp shuffle
+          inter_val += platform::CudaShuffleXorSync(0xFFFFFFFF, inter_val, i);
+        }
+        if (threadIdx.x == 0 && (idx_j < w)) d_intermediate[idx_j] = inter_val;
       }
     }
-  }
+  }  // end for
 }
 
 template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
           bool UseIntermediateOut, bool BcastY,
           bool SameShapeOfIntermediateOutAndOut>
 static void FusedElemwiseAndActGradBroadcast1CUDA(
-    gpuStream_t stream, const T *x, const T *y, const T *intermediate_out,
-    const T *out, const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op,
-    DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
-  int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
-  int gird_size = w;
+    const framework::ExecutionContext &ctx, const T *x, const T *y,
+    const T *intermediate_out, const T *out, const T *dout, int h, int w,
+    DX_OP dx_op, DY_OP dy_op, DIntermediate_OP dintermediate_op, T *dx, T *dy,
+    T *d_intermediate) {
+  gpuStream_t stream = ctx.cuda_device_context().stream();
+
+  dim3 blocks(BLOCK_X, BLOCK_Y);
+  int max_gpu_threads = ctx.cuda_device_context().GetMaxPhysicalThreadCount();
+  int max_blocks = std::max(max_gpu_threads / (BLOCK_X * BLOCK_Y), 1);
+  int theory_block = (w + BLOCK_X - 1) / BLOCK_X;
+  dim3 grids(std::min(theory_block, max_blocks));
+
   FusedElemwiseAndActGradBroadcast1CUDAKernel<
       T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut, BcastY,
-      SameShapeOfIntermediateOutAndOut><<<gird_size, block_size, 0, stream>>>(
+      SameShapeOfIntermediateOutAndOut><<<grids, blocks, 0, stream>>>(
       x, y, intermediate_out, out, dout, h, w, dx_op, dy_op, dintermediate_op,
       dx, dy, d_intermediate);
 }
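For readers unfamiliar with the pattern: the rewritten kernel drops the old one-column-per-block reduceSum in favor of a 2D tile that accumulates partial column sums down the rows, transposes the partials through shared memory so the 32 partials of each column land in the 32 lanes of one warp, and finishes with a shuffle-XOR butterfly reduction. Below is a minimal self-contained sketch of just that reduction (column sums of an h x w matrix). It assumes the 32x32 tile that the shuffle width implies, uses the raw __shfl_xor_sync intrinsic where the patch goes through the platform::CudaShuffleXorSync wrapper, and the name column_sum is invented for the sketch. One deliberate addition here is the trailing __syncthreads() before sdata is reused by the next tile.

    #include <cstdio>
    #include <vector>
    #include <cuda_runtime.h>

    constexpr int BLOCK_X = 32;  // one warp per tile row (assumed, as in the patch)
    constexpr int BLOCK_Y = 32;

    // Column-sum of an h x w row-major matrix: partial sums down the rows,
    // a shared-memory transpose, then a warp shuffle-XOR butterfly.
    __global__ void column_sum(const float *in, float *out, int h, int w) {
      __shared__ float sdata[BLOCK_Y][BLOCK_X];
      int full_w = ((w + BLOCK_X - 1) / BLOCK_X) * BLOCK_X;  // ROUNDUP(w, BLOCK_X)
      // Width-stride loop over tiles; padding to full_w keeps every thread in
      // the loop so the __syncthreads() below is reached by the whole block.
      for (int j = threadIdx.x + BLOCK_X * blockIdx.x; j < full_w;
           j += gridDim.x * BLOCK_X) {
        float val = 0.f;
        if (j < w) {
          for (int i = threadIdx.y; i < h; i += BLOCK_Y) val += in[i * w + j];
        }
        // Transpose the partials: afterwards the 32 lanes of warp threadIdx.y
        // hold the BLOCK_Y partial sums of one column.
        sdata[threadIdx.y][threadIdx.x] = val;
        __syncthreads();
        val = sdata[threadIdx.x][threadIdx.y];
    #pragma unroll
        for (int offset = BLOCK_X >> 1; offset > 0; offset >>= 1) {
          val += __shfl_xor_sync(0xFFFFFFFF, val, offset);  // butterfly step
        }
        // Lane 0 of warp y writes column (tile base + y); for lane 0, j is the
        // tile base, mirroring idx_j = j + threadIdx.y in the patch.
        int idx_j = j + threadIdx.y;
        if (threadIdx.x == 0 && idx_j < w) out[idx_j] = val;
        __syncthreads();  // sdata is reused by the next tile
      }
    }

    int main() {
      const int h = 1000, w = 100;
      std::vector<float> ones(h * w, 1.f);
      float *d_in, *d_out;
      cudaMalloc(&d_in, h * w * sizeof(float));
      cudaMalloc(&d_out, w * sizeof(float));
      cudaMemcpy(d_in, ones.data(), h * w * sizeof(float), cudaMemcpyHostToDevice);
      column_sum<<<1, dim3(BLOCK_X, BLOCK_Y)>>>(d_in, d_out, h, w);
      float first;
      cudaMemcpy(&first, d_out, sizeof(float), cudaMemcpyDeviceToHost);
      std::printf("column 0 sums to %.0f (expected %d)\n", first, h);
      cudaFree(d_in);
      cudaFree(d_out);
      return 0;
    }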
@@ -2832,7 +2859,7 @@ void FusedElemwiseAndActGradComputeWithBroadcast(
     FusedElemwiseAndActGradBroadcast1CUDA<T, DX_OP, DY_OP, DIntermediate_OP,
                                           UseIntermediateOut, BcastY,
                                           SameShapeOfIntermediateOutAndOut>(
-        ctx.template device_context<DeviceContext>().stream(), x_data, y_data,
+        ctx, x_data, y_data,
         intermediate_out == nullptr ? nullptr : intermediate_out->data<T>(),
         out->data<T>(), dout->data<T>(), h, w, dx_op, dy_op, dintermediate_op,
         dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
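The call site now hands over the whole ExecutionContext rather than a bare stream because the launcher needs the device context to query GetMaxPhysicalThreadCount() and cap the grid at what the GPU can keep resident; the kernel's width-stride loop then covers any column tiles beyond that grid. Illustrative arithmetic only (the device thread count is made up, and BLOCK_X = BLOCK_Y = 32 is assumed):

    #include <algorithm>
    #include <cstdio>

    int main() {
      const int BLOCK_X = 32, BLOCK_Y = 32;  // assumed tile shape
      int max_gpu_threads = 65536;           // stand-in for GetMaxPhysicalThreadCount()
      int w = 10000;                         // broadcast width
      int max_blocks = std::max(max_gpu_threads / (BLOCK_X * BLOCK_Y), 1);  // 64
      int theory_block = (w + BLOCK_X - 1) / BLOCK_X;                       // 313 tiles
      std::printf("grid = %d\n", std::min(theory_block, max_blocks));       // 64
      return 0;
    }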