Commit 26bb9a4

optimize fused_elewise_act backward performance
1 parent dad8c72 commit 26bb9a4

File tree

1 file changed: +97 −70 lines changed


paddle/fluid/operators/elementwise/elementwise_op_function.h

Lines changed: 97 additions & 70 deletions
@@ -57,6 +57,10 @@ constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
     *mod = dividend_copy % divisor; \
   } while (0)
 
+#define DIVUP(x, y) (((x) + (y)-1) / (y))
+
+#define ROUNDUP(x, y) (DIVUP((x), (y)) * (y))
+
 namespace paddle {
 namespace operators {
 
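Taken on their own, the two new macros are ordinary ceiling-division helpers. A minimal sketch of what they compute (the main harness is illustrative only, not part of the commit):

#include <cstdio>

#define DIVUP(x, y) (((x) + (y)-1) / (y))
#define ROUNDUP(x, y) (DIVUP((x), (y)) * (y))

int main() {
  // DIVUP is ceiling division: how many 32-wide tiles cover 100 columns.
  printf("DIVUP(100, 32)   = %d\n", DIVUP(100, 32));    // prints 4
  // ROUNDUP pads 100 up to the next multiple of 32.
  printf("ROUNDUP(100, 32) = %d\n", ROUNDUP(100, 32));  // prints 128
  return 0;
}

The kernel below uses ROUNDUP to pad its column loop to a multiple of BLOCK_X, so every thread of a block runs the same number of iterations and the __syncthreads() barriers inside the loop stay safe even in the ragged last tile.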
@@ -2581,106 +2585,129 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel(
     const T *x, const T *y, const T *intermediate_out, const T *out,
     const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op,
     DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
-  int j = blockIdx.x;
-  int i = threadIdx.x;
-  int tid = threadIdx.x;
-  T val(0), inter_val(0);
-  int64_t tmp_out_idx, x_idx, y_idx;
+  __shared__ T sdata[BLOCK_Y][BLOCK_X];
+  size_t idx = threadIdx.x + BLOCK_X * blockIdx.x;
+  size_t width_stride = gridDim.x * BLOCK_X;
+
+  size_t full_w = ROUNDUP(w, BLOCK_X);
+
   T zero = static_cast<T>(0);
 
-  do {
-    int offset = i * w + j;
+  for (size_t j = idx; j < full_w; j += width_stride) {
+    T val(0), inter_val(0);
+    if (j < w) {
+      for (size_t i = threadIdx.y; i < h; i += BLOCK_Y) {
+        size_t offset = i * w + j;
 
-    tmp_out_idx = BcastY ? j : offset;
-    y_idx = BcastY ? j : offset;
-    x_idx = BcastY ? offset : j;
-    T x_val = (x == nullptr) ? zero : x[x_idx];
-    T y_val = (y == nullptr) ? zero : y[y_idx];
+        size_t tmp_out_idx = BcastY ? j : offset;
+        size_t y_idx = BcastY ? j : offset;
+        size_t x_idx = BcastY ? offset : j;
+        T x_val = (x == nullptr) ? zero : x[x_idx];
+        T y_val = (y == nullptr) ? zero : y[y_idx];
 
-    if (SameShapeOfIntermediateOutAndOut) {
-      tmp_out_idx = offset;
-    }
+        if (SameShapeOfIntermediateOutAndOut) {
+          tmp_out_idx = offset;
+        }
 
-    if (dx != nullptr) {
-      T tmp = UseIntermediateOut
+        if (dx != nullptr) {
+          T tmp =
+              UseIntermediateOut
                   ? dx_op.UseIntermediateOut(x_val, y_val,
                                              intermediate_out[tmp_out_idx],
                                              out[offset], dout[offset])
                   : dx_op.Recompute(x_val, y_val, out[offset], dout[offset]);
 
-      if (BcastY) {
-        dx[x_idx] = tmp;
-      } else {
-        val += tmp;
-      }
-    }
-    if (dy != nullptr) {
-      T tmp = UseIntermediateOut
+          if (BcastY) {
+            dx[x_idx] = tmp;
+          } else {
+            val += tmp;
+          }
+        }
+        if (dy != nullptr) {
+          T tmp =
+              UseIntermediateOut
                   ? dy_op.UseIntermediateOut(x_val, y_val,
                                              intermediate_out[tmp_out_idx],
                                              out[offset], dout[offset])
                   : dy_op.Recompute(x_val, y_val, out[offset], dout[offset]);
-      if (BcastY) {
-        val += tmp;
-      } else {
-        dy[y_idx] = tmp;
-      }
-    }
-    if (d_intermediate != nullptr) {
-      T tmp = UseIntermediateOut
-                  ? dintermediate_op.UseIntermediateOut(
-                        y[y_idx], intermediate_out[tmp_out_idx], out[offset],
-                        dout[offset])
-                  : dintermediate_op.Recompute(x_val, y_val, out[offset],
-                                               dout[offset]);
-      if (SameShapeOfIntermediateOutAndOut) {
-        d_intermediate[tmp_out_idx] = tmp;
-      } else {
-        inter_val += tmp;
+          if (BcastY) {
+            val += tmp;
+          } else {
+            dy[y_idx] = tmp;
+          }
+        }
+        if (d_intermediate != nullptr) {
+          T tmp = UseIntermediateOut
+                      ? dintermediate_op.UseIntermediateOut(
+                            y[y_idx], intermediate_out[tmp_out_idx],
+                            out[offset], dout[offset])
+                      : dintermediate_op.Recompute(x_val, y_val, out[offset],
+                                                   dout[offset]);
+          if (SameShapeOfIntermediateOutAndOut) {
+            d_intermediate[tmp_out_idx] = tmp;
+          } else {
+            inter_val += tmp;
+          }
+        }
       }
     }
 
-    i += ELEMWISE_MAX_BLOCK_DIM;
-  } while (i < h);
+    // transpose, for warp-level ReduceSum
+    sdata[threadIdx.y][threadIdx.x] = val;
+    __syncthreads();
+    val = sdata[threadIdx.x][threadIdx.y];
+#pragma unroll
+    for (int i = BLOCK_X >> 1; i > 0; i >>= 1) {
+      // reduce sum with warp shuffles
+      val += platform::CudaShuffleXorSync(0xFFFFFFFF, val, i);
+    }
 
-  h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
-  if (BcastY) {
-    if (dy) {
-      val = paddle::platform::reduceSum(val, tid, h);
-      if (threadIdx.x == 0) {
-        dy[j] = val;
+    size_t idx_j = j + threadIdx.y;
+    if (BcastY) {
+      if (dy) {
+        if (threadIdx.x == 0 && (idx_j < w)) dy[idx_j] = val;
       }
-    }
-  } else {
-    if (dx) {
-      val = paddle::platform::reduceSum(val, tid, h);
-      if (threadIdx.x == 0) {
-        dx[j] = val;
+    } else {
+      if (dx) {
+        if (threadIdx.x == 0 && (idx_j < w)) dx[idx_j] = val;
       }
     }
-  }
-  if (!SameShapeOfIntermediateOutAndOut) {
-    if (d_intermediate) {
-      inter_val = paddle::platform::reduceSum(inter_val, tid, h);
-      if (threadIdx.x == 0) {
-        d_intermediate[j] = inter_val;
+
+    if (!SameShapeOfIntermediateOutAndOut) {
+      if (d_intermediate) {
+        sdata[threadIdx.y][threadIdx.x] = inter_val;
+        __syncthreads();
+        inter_val = sdata[threadIdx.x][threadIdx.y];
+#pragma unroll
+        for (int i = BLOCK_X >> 1; i > 0; i >>= 1) {
+          // reduce sum with warp shuffles
+          inter_val += platform::CudaShuffleXorSync(0xFFFFFFFF, inter_val, i);
+        }
+        if (threadIdx.x == 0 && (idx_j < w)) d_intermediate[idx_j] = inter_val;
       }
     }
-  }
+  }  // end for
 }
 
 template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
           bool UseIntermediateOut, bool BcastY,
           bool SameShapeOfIntermediateOutAndOut>
 static void FusedElemwiseAndActGradBroadcast1CUDA(
-    gpuStream_t stream, const T *x, const T *y, const T *intermediate_out,
-    const T *out, const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op,
-    DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) {
-  int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
-  int gird_size = w;
+    const framework::ExecutionContext &ctx, const T *x, const T *y,
+    const T *intermediate_out, const T *out, const T *dout, int h, int w,
+    DX_OP dx_op, DY_OP dy_op, DIntermediate_OP dintermediate_op, T *dx, T *dy,
+    T *d_intermediate) {
+  gpuStream_t stream = ctx.cuda_device_context().stream();
+
+  dim3 blocks(BLOCK_X, BLOCK_Y);
+  int max_gpu_threads = ctx.cuda_device_context().GetMaxPhysicalThreadCount();
+  int max_blocks = std::max(max_gpu_threads / (BLOCK_X * BLOCK_Y), 1);
+  int theory_block = (w + BLOCK_X - 1) / BLOCK_X;
+  dim3 grids(std::min(theory_block, max_blocks));
+
   FusedElemwiseAndActGradBroadcast1CUDAKernel<
       T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut, BcastY,
-      SameShapeOfIntermediateOutAndOut><<<gird_size, block_size, 0, stream>>>(
+      SameShapeOfIntermediateOutAndOut><<<grids, blocks, 0, stream>>>(
       x, y, intermediate_out, out, dout, h, w, dx_op, dy_op, dintermediate_op,
       dx, dy, d_intermediate);
 }
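The heart of the change is the reduction strategy. The old kernel launched one block per output column and reduced h partials with paddle::platform::reduceSum; the new kernel tiles the width with BLOCK_X x BLOCK_Y blocks in a grid-stride loop, lets each thread accumulate a strided slice of its column, transposes the partials through shared memory so the 32 partials of one column land in a single warp, and finishes with a warp-shuffle butterfly. Below is a stripped-down column-sum kernel showing the same pattern as a sketch: it assumes BLOCK_X == BLOCK_Y == 32 (so one warp spans a full shared-memory row) and uses __shfl_xor_sync, the intrinsic that platform::CudaShuffleXorSync wraps; the ColumnSum kernel and its harness are hypothetical, not Paddle code.

#include <cstdio>
#include <cuda_runtime.h>

#define BLOCK_X 32
#define BLOCK_Y 32
#define DIVUP(x, y) (((x) + (y)-1) / (y))
#define ROUNDUP(x, y) (DIVUP((x), (y)) * (y))

// Column-wise sum of an h x w row-major matrix, structured like the
// kernel above: grid-stride over columns, strided accumulation over rows,
// shared-memory transpose, then a warp-shuffle butterfly reduction.
__global__ void ColumnSum(const float *in, float *out, int h, int w) {
  __shared__ float sdata[BLOCK_Y][BLOCK_X];
  size_t idx = threadIdx.x + BLOCK_X * blockIdx.x;
  size_t width_stride = gridDim.x * BLOCK_X;
  // Round w up so every thread of the block iterates in lockstep; this
  // keeps the __syncthreads() below safe in the ragged last tile.
  size_t full_w = ROUNDUP(w, BLOCK_X);

  for (size_t j = idx; j < full_w; j += width_stride) {
    float val = 0.f;
    if (j < w) {
      // Thread row ty accumulates rows ty, ty+32, ty+64, ... of column j.
      for (size_t i = threadIdx.y; i < h; i += BLOCK_Y) val += in[i * w + j];
    }
    // Transpose: afterwards the 32 partials of one column occupy one warp.
    sdata[threadIdx.y][threadIdx.x] = val;
    __syncthreads();
    val = sdata[threadIdx.x][threadIdx.y];
#pragma unroll
    for (int offset = BLOCK_X >> 1; offset > 0; offset >>= 1) {
      // Butterfly reduction; after the loop every lane holds the column sum.
      val += __shfl_xor_sync(0xFFFFFFFF, val, offset);
    }
    size_t idx_j = j + threadIdx.y;  // the reduced column, valid at lane 0
    if (threadIdx.x == 0 && idx_j < w) out[idx_j] = val;
    __syncthreads();  // sdata is reused by the next grid-stride iteration
  }
}

int main() {
  const int h = 1000, w = 70;
  float *in, *out;
  cudaMallocManaged(&in, sizeof(float) * h * w);
  cudaMallocManaged(&out, sizeof(float) * w);
  for (int i = 0; i < h * w; ++i) in[i] = 1.f;
  ColumnSum<<<DIVUP(w, BLOCK_X), dim3(BLOCK_X, BLOCK_Y)>>>(in, out, h, w);
  cudaDeviceSynchronize();
  printf("out[0] = %.0f (expect %d)\n", out[0], h);
  cudaFree(in);
  cudaFree(out);
  return 0;
}

Compared with a block-wide reduceSum, the transpose plus butterfly keeps the whole reduction inside one warp with no extra shared-memory round trips, which is presumably where the backward speedup comes from.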
@@ -2832,7 +2859,7 @@ void FusedElemwiseAndActGradComputeWithBroadcast(
     FusedElemwiseAndActGradBroadcast1CUDA<T, DX_OP, DY_OP, DIntermediate_OP,
                                           UseIntermediateOut, BcastY,
                                           SameShapeOfIntermediateOutAndOut>(
-        ctx.template device_context<DeviceContext>().stream(), x_data, y_data,
+        ctx, x_data, y_data,
         intermediate_out == nullptr ? nullptr : intermediate_out->data<T>(),
         out->data<T>(), dout->data<T>(), h, w, dx_op, dy_op, dintermediate_op,
         dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
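The call site now hands the launcher the whole ExecutionContext rather than a bare stream, because the launcher needs the device context twice: once for the stream and once for GetMaxPhysicalThreadCount(), which caps the grid at the number of blocks the GPU can keep resident and leaves the rest to the kernel's grid-stride loop. A standalone sketch of that sizing logic against the raw CUDA runtime (PickGrid is a hypothetical helper, and treating GetMaxPhysicalThreadCount() as SM count times max threads per SM is an assumption about Paddle's implementation):

#include <algorithm>
#include <cuda_runtime.h>

#define BLOCK_X 32
#define BLOCK_Y 32

// Hypothetical standalone equivalent of the launcher's grid sizing.
dim3 PickGrid(int w) {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);
  // Assumed stand-in for ctx.cuda_device_context().GetMaxPhysicalThreadCount().
  int max_gpu_threads =
      prop.multiProcessorCount * prop.maxThreadsPerMultiProcessor;
  // Cap at the number of resident BLOCK_X x BLOCK_Y blocks the GPU supports...
  int max_blocks = std::max(max_gpu_threads / (BLOCK_X * BLOCK_Y), 1);
  // ...but never launch more blocks than the width needs.
  int theory_block = (w + BLOCK_X - 1) / BLOCK_X;
  return dim3(std::min(theory_block, max_blocks));
}

The old launcher requested one block per column (the misspelled gird_size = w) regardless of how many blocks the device could actually run at once; capping the grid and striding over columns amortizes the launch across the whole width.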
