Skip to content

Commit

Permalink
【PaddlePaddle Hackathon 5 No.48】ContiguousKernel、StridedCopyKernel算子C…
Browse files Browse the repository at this point in the history
…PU、GPU性能优化 (PaddlePaddle#57835)

* speed up ContiguousKernel

* fix bugs

* fix bugs

* test origin code

* fix bugs
  • Loading branch information
WintersMontagne10335 authored Oct 16, 2023
1 parent 35ee4e9 commit 9fb27cc
Showing 1 changed file with 312 additions and 53 deletions.
365 changes: 312 additions & 53 deletions paddle/phi/kernels/gpu/contiguous_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,120 @@ limitations under the License. */
#include "paddle/phi/kernels/transpose_kernel.h"

namespace phi {
// Copies a strided tensor of rank N (N <= 6) into a densely packed output,
// one element per launched thread.
//
// Launch layout expected by the index math below:
//   threadIdx.x / .y / .z -> coordinate in the last / 2nd-last / 3rd-last dim
//   blockIdx.x  / .y / .z -> coordinate in the 4th- / 5th- / 6th-last dim
// so blockDim and gridDim must equal the corresponding dimension extents, with
// unused axes set to 1. There is no bounds guard: every launched thread reads
// exactly one input element and writes exactly one output element.
//
// input_data:   source buffer, addressed through input_stride.
// out_data:     destination buffer, written at the flattened thread index.
// input_stride: per-dimension element strides; input_stride[N - 1] is paired
//               with threadIdx.x, input_stride[N - 2] with threadIdx.y, etc.
template <typename T, size_t N>
__global__ void ContiguousCaseZeroFunc(
    const T* input_data,
    T* out_data,
    phi::Array<int64_t, phi::DDim::kMaxRank + 1> input_stride) {
  // Only six launch axes exist (3 thread + 3 block), so `coordinate` has six
  // slots; a larger N would index past the end of it.
  static_assert(N <= 6,
                "ContiguousCaseZeroFunc maps at most 6 dimensions onto the "
                "thread/block indices.");

  // Widen before multiplying: the built-in index variables are 32-bit
  // unsigned, and their product can wrap for large launches.
  const int64_t block_id =
      (static_cast<int64_t>(blockIdx.z) * gridDim.y + blockIdx.y) * gridDim.x +
      blockIdx.x;
  const int64_t threads_per_block =
      static_cast<int64_t>(blockDim.z) * blockDim.y * blockDim.x;
  const int64_t thread_in_block =
      (static_cast<int64_t>(threadIdx.z) * blockDim.y + threadIdx.y) *
          blockDim.x +
      threadIdx.x;
  const int64_t output_offset = block_id * threads_per_block + thread_in_block;

  // Integer coordinates, innermost dimension first. These must not be
  // floating point: coordinate * stride would then be evaluated in single
  // precision and lose exactness once the offset exceeds 2^24 elements.
  const int64_t coordinate[6] = {static_cast<int64_t>(threadIdx.x),
                                 static_cast<int64_t>(threadIdx.y),
                                 static_cast<int64_t>(threadIdx.z),
                                 static_cast<int64_t>(blockIdx.x),
                                 static_cast<int64_t>(blockIdx.y),
                                 static_cast<int64_t>(blockIdx.z)};

  int64_t input_offset = 0;
#pragma unroll
  for (int dim = static_cast<int>(N) - 1; dim >= 0; --dim) {
    // coordinate[0] belongs to the last dimension, hence the reversed index.
    input_offset += coordinate[N - 1 - dim] * input_stride[dim];
  }

  out_data[output_offset] = input_data[input_offset];
}

// Strided-to-contiguous copy kernel for rank-N tensors (the switch below
// handles N = 1..9). Each thread whose flat x index is below x_max copies one
// element from input_data (addressed through input_stride) to out_data.
//
// NOTE(review): this span appears to interleave removed and added lines of a
// diff: there are two consecutive `__global__ void` headers, and statements
// using `i`, `numel` and `index_tmp` sit next to the `x` / `coordinate` logic.
// Confirm which lines belong to the current version before relying on it.
//
// Parameters used by the x / coordinate logic:
//   input_data:   source buffer.
//   out_data:     destination buffer.
//   input_stride: per-dimension element strides of the source.
//   dims:         six extents used to decode coordinates (see the switch).
//   x_max:        number of elements covered along the x launch axis; threads
//                 with x >= x_max do nothing.
template <typename T, size_t N>
__global__ void ContiguousFunc(
__global__ void ContiguousCaseOneFunc(
const T* input_data,
T* out_data,
phi::Array<int64_t, phi::DDim::kMaxRank + 1> input_stride,
phi::Array<int64_t, phi::DDim::kMaxRank + 1> dims,
const int64_t numel) {
int64_t gid = blockIdx.x * blockDim.x + threadIdx.x;
#pragma unroll
for (int64_t i = gid; i < numel; i += blockDim.x * gridDim.x) {
phi::Array<int64_t, 6> dims,
const int64_t x_max) {
// Flat index along the x launch axis; coordinate[0..2] are decoded from it.
// NOTE(review): blockIdx.x * blockDim.x is evaluated in 32-bit unsigned
// arithmetic before the widening assignment — confirm it cannot wrap.
int64_t x = blockIdx.x * blockDim.x + threadIdx.x;
if (x < x_max) {
int64_t input_offset = 0;
int64_t index_tmp = i;
// Each (blockIdx.y, blockIdx.z) pair owns x_max consecutive output elements.
int64_t output_offset = (blockIdx.z * gridDim.y + blockIdx.y) * x_max + x;

// Local copy of the six extents passed by value in dims.
int64_t reg_dims[6] = {
dims[0], dims[1], dims[2], dims[3], dims[4], dims[5]};
int64_t coordinate[phi::DDim::kMaxRank + 1];

// Decode one coordinate per dimension:
//   coordinate[0..2] from x          using reg_dims[0], reg_dims[1]
//   coordinate[3..5] from blockIdx.y using reg_dims[2], reg_dims[3]
//   coordinate[6..8] from blockIdx.z using reg_dims[4], reg_dims[5]
// The last coordinate of each group is a plain quotient (no modulo), so it
// relies on the launch bounds to stay in range.
switch (N) {
case 1:
coordinate[0] = x % reg_dims[0];
break;
case 2:
coordinate[0] = x % reg_dims[0];
coordinate[1] = x / reg_dims[0] % reg_dims[1];
break;
case 3:
coordinate[0] = x % reg_dims[0];
coordinate[1] = x / reg_dims[0] % reg_dims[1];
coordinate[2] = x / (reg_dims[0] * reg_dims[1]);
break;
case 4:
coordinate[0] = x % reg_dims[0];
coordinate[1] = x / reg_dims[0] % reg_dims[1];
coordinate[2] = x / (reg_dims[0] * reg_dims[1]);
coordinate[3] = blockIdx.y % reg_dims[2];
break;
case 5:
coordinate[0] = x % reg_dims[0];
coordinate[1] = x / reg_dims[0] % reg_dims[1];
coordinate[2] = x / (reg_dims[0] * reg_dims[1]);
coordinate[3] = blockIdx.y % reg_dims[2];
coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3];
break;
case 6:
coordinate[0] = x % reg_dims[0];
coordinate[1] = x / reg_dims[0] % reg_dims[1];
coordinate[2] = x / (reg_dims[0] * reg_dims[1]);
coordinate[3] = blockIdx.y % reg_dims[2];
coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3];
coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]);
break;
case 7:
coordinate[0] = x % reg_dims[0];
coordinate[1] = x / reg_dims[0] % reg_dims[1];
coordinate[2] = x / (reg_dims[0] * reg_dims[1]);
coordinate[3] = blockIdx.y % reg_dims[2];
coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3];
coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]);
coordinate[6] = blockIdx.z % reg_dims[4];
break;
case 8:
coordinate[0] = x % reg_dims[0];
coordinate[1] = x / reg_dims[0] % reg_dims[1];
coordinate[2] = x / (reg_dims[0] * reg_dims[1]);
coordinate[3] = blockIdx.y % reg_dims[2];
coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3];
coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]);
coordinate[6] = blockIdx.z % reg_dims[4];
coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5];
break;
case 9:
coordinate[0] = x % reg_dims[0];
coordinate[1] = x / reg_dims[0] % reg_dims[1];
coordinate[2] = x / (reg_dims[0] * reg_dims[1]);
coordinate[3] = blockIdx.y % reg_dims[2];
coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3];
coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]);
coordinate[6] = blockIdx.z % reg_dims[4];
coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5];
coordinate[8] = blockIdx.z / (reg_dims[4] * reg_dims[5]);
break;
}

// Accumulate the source offset as the sum of coordinate * stride;
// coordinate[0] pairs with input_stride[N - 1], hence the reversed index.
#pragma unroll
for (int dim = N - 1; dim >= 0; --dim) {
input_offset += index_tmp % dims[dim] * input_stride[dim];
index_tmp = index_tmp / dims[dim];
input_offset += coordinate[N - 1 - dim] * input_stride[dim];
}

out_data[i] = input_data[input_offset];
out_data[output_offset] = input_data[input_offset];
}
}

Expand Down Expand Up @@ -135,49 +229,214 @@ void ContiguousKernel(const Context& dev_ctx,
input_stride[0] = 1;
}

int64_t block = 512;
int64_t grid = (numel + block - 1) / block;

switch (rank) {
case 1:
ContiguousFunc<T, 1><<<grid, block, 0, dev_ctx.stream()>>>(
input_data, output_data, input_stride, input_dims, numel);
break;
case 2:
ContiguousFunc<T, 2><<<grid, block, 0, dev_ctx.stream()>>>(
input_data, output_data, input_stride, input_dims, numel);
break;
case 3:
ContiguousFunc<T, 3><<<grid, block, 0, dev_ctx.stream()>>>(
input_data, output_data, input_stride, input_dims, numel);
break;
case 4:
ContiguousFunc<T, 4><<<grid, block, 0, dev_ctx.stream()>>>(
input_data, output_data, input_stride, input_dims, numel);
break;
case 5:
ContiguousFunc<T, 5><<<grid, block, 0, dev_ctx.stream()>>>(
input_data, output_data, input_stride, input_dims, numel);
break;
case 6:
ContiguousFunc<T, 6><<<grid, block, 0, dev_ctx.stream()>>>(
input_data, output_data, input_stride, input_dims, numel);
break;
case 7:
ContiguousFunc<T, 7><<<grid, block, 0, dev_ctx.stream()>>>(
input_data, output_data, input_stride, input_dims, numel);
break;
case 8:
ContiguousFunc<T, 8><<<grid, block, 0, dev_ctx.stream()>>>(
input_data, output_data, input_stride, input_dims, numel);
break;
case 9:
ContiguousFunc<T, 9><<<grid, block, 0, dev_ctx.stream()>>>(
input_data, output_data, input_stride, input_dims, numel);
break;
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"The rank of input should be less than 9, but received %d.", rank));
dim3 grid(1, 1, 1), block(1, 1, 1);

int tmp = 1;

for (int i = 0; i < 3 && i < rank; i++) {
tmp *= input_dims[rank - 1 - i];
}

if (rank <= 6 && tmp <= 1024 &&
(input_dims.size() < 3 || input_dims[rank - 3] <= 64)) {
if (rank >= 1) {
block.x = input_dims[rank - 1];
}

if (rank >= 2) {
block.y = input_dims[rank - 2];
}

if (rank >= 3) {
block.z = input_dims[rank - 3];
}

switch (rank) {
case 1:
ContiguousCaseZeroFunc<T, 1><<<grid, block, 0, dev_ctx.stream()>>>(
input_data, output_data, input_stride);
break;
case 2:
ContiguousCaseZeroFunc<T, 2><<<grid, block, 0, dev_ctx.stream()>>>(
input_data, output_data, input_stride);
break;
case 3:
ContiguousCaseZeroFunc<T, 3><<<grid, block, 0, dev_ctx.stream()>>>(
input_data, output_data, input_stride);
break;
case 4:
grid.x = input_dims[rank - 4];
ContiguousCaseZeroFunc<T, 4><<<grid, block, 0, dev_ctx.stream()>>>(
input_data, output_data, input_stride);
break;
case 5:
grid.x = input_dims[rank - 4];
grid.y = input_dims[rank - 5];
ContiguousCaseZeroFunc<T, 5><<<grid, block, 0, dev_ctx.stream()>>>(
input_data, output_data, input_stride);
break;
case 6:
grid.x = input_dims[rank - 4];
grid.y = input_dims[rank - 5];
grid.z = input_dims[rank - 6];
ContiguousCaseZeroFunc<T, 6><<<grid, block, 0, dev_ctx.stream()>>>(
input_data, output_data, input_stride);
break;
}
} else {
phi::Array<int64_t, 6> cur_input_dims;
block.x = 512;
switch (rank) {
case 1:
grid.x = (numel + block.x - 1) / block.x;
cur_input_dims[0] = input_dims[rank - 1];
ContiguousCaseOneFunc<T, 1>
<<<grid, block, 0, dev_ctx.stream()>>>(input_data,
output_data,
input_stride,
cur_input_dims,
input_dims[rank - 1]);
break;
case 2:
grid.x = (numel + block.x - 1) / block.x;
cur_input_dims[0] = input_dims[rank - 1];
cur_input_dims[1] = input_dims[rank - 2];
ContiguousCaseOneFunc<T, 2><<<grid, block, 0, dev_ctx.stream()>>>(
input_data,
output_data,
input_stride,
cur_input_dims,
input_dims[rank - 1] * input_dims[rank - 2]);
break;
case 3:
grid.x = (numel + block.x - 1) / block.x;
cur_input_dims[0] = input_dims[rank - 1];
cur_input_dims[1] = input_dims[rank - 2];
ContiguousCaseOneFunc<T, 3><<<grid, block, 0, dev_ctx.stream()>>>(
input_data,
output_data,
input_stride,
cur_input_dims,
input_dims[rank - 1] * input_dims[rank - 2] * input_dims[rank - 3]);
break;
case 4:
grid.x = (input_dims[rank - 1] * input_dims[rank - 2] *
input_dims[rank - 3] +
block.x - 1) /
block.x;
grid.y = input_dims[rank - 4];
cur_input_dims[0] = input_dims[rank - 1];
cur_input_dims[1] = input_dims[rank - 2];
cur_input_dims[2] = input_dims[rank - 4];
ContiguousCaseOneFunc<T, 4><<<grid, block, 0, dev_ctx.stream()>>>(
input_data,
output_data,
input_stride,
cur_input_dims,
input_dims[rank - 1] * input_dims[rank - 2] * input_dims[rank - 3]);
break;
case 5:
grid.x = (input_dims[rank - 1] * input_dims[rank - 2] *
input_dims[rank - 3] +
block.x - 1) /
block.x;
grid.y = input_dims[rank - 4] * input_dims[rank - 5];
cur_input_dims[0] = input_dims[rank - 1];
cur_input_dims[1] = input_dims[rank - 2];
cur_input_dims[2] = input_dims[rank - 4];
cur_input_dims[3] = input_dims[rank - 5];
ContiguousCaseOneFunc<T, 5><<<grid, block, 0, dev_ctx.stream()>>>(
input_data,
output_data,
input_stride,
cur_input_dims,
input_dims[rank - 1] * input_dims[rank - 2] * input_dims[rank - 3]);
break;
case 6:
grid.x = (input_dims[rank - 1] * input_dims[rank - 2] *
input_dims[rank - 3] +
block.x - 1) /
block.x;
grid.y =
input_dims[rank - 4] * input_dims[rank - 5] * input_dims[rank - 6];
cur_input_dims[0] = input_dims[rank - 1];
cur_input_dims[1] = input_dims[rank - 2];
cur_input_dims[2] = input_dims[rank - 4];
cur_input_dims[3] = input_dims[rank - 5];
ContiguousCaseOneFunc<T, 6><<<grid, block, 0, dev_ctx.stream()>>>(
input_data,
output_data,
input_stride,
cur_input_dims,
input_dims[rank - 1] * input_dims[rank - 2] * input_dims[rank - 3]);
break;
case 7:
grid.x = (input_dims[rank - 1] * input_dims[rank - 2] *
input_dims[rank - 3] +
block.x - 1) /
block.x;
grid.y =
input_dims[rank - 4] * input_dims[rank - 5] * input_dims[rank - 6];
grid.z = input_dims[rank - 7];
cur_input_dims[0] = input_dims[rank - 1];
cur_input_dims[1] = input_dims[rank - 2];
cur_input_dims[2] = input_dims[rank - 4];
cur_input_dims[3] = input_dims[rank - 5];
cur_input_dims[4] = input_dims[rank - 7];
ContiguousCaseOneFunc<T, 7><<<grid, block, 0, dev_ctx.stream()>>>(
input_data,
output_data,
input_stride,
cur_input_dims,
input_dims[rank - 1] * input_dims[rank - 2] * input_dims[rank - 3]);
break;
case 8:
grid.x = (input_dims[rank - 1] * input_dims[rank - 2] *
input_dims[rank - 3] +
block.x - 1) /
block.x;
grid.y =
input_dims[rank - 4] * input_dims[rank - 5] * input_dims[rank - 6];
grid.z = input_dims[rank - 7] * input_dims[rank - 8];
cur_input_dims[0] = input_dims[rank - 1];
cur_input_dims[1] = input_dims[rank - 2];
cur_input_dims[2] = input_dims[rank - 4];
cur_input_dims[3] = input_dims[rank - 5];
cur_input_dims[4] = input_dims[rank - 7];
cur_input_dims[5] = input_dims[rank - 8];
ContiguousCaseOneFunc<T, 8><<<grid, block, 0, dev_ctx.stream()>>>(
input_data,
output_data,
input_stride,
cur_input_dims,
input_dims[rank - 1] * input_dims[rank - 2] * input_dims[rank - 3]);
break;
case 9:
grid.x = (input_dims[rank - 1] * input_dims[rank - 2] *
input_dims[rank - 3] +
block.x - 1) /
block.x;
grid.y =
input_dims[rank - 4] * input_dims[rank - 5] * input_dims[rank - 6];
grid.z =
input_dims[rank - 7] * input_dims[rank - 8] * input_dims[rank - 9];
cur_input_dims[0] = input_dims[rank - 1];
cur_input_dims[1] = input_dims[rank - 2];
cur_input_dims[2] = input_dims[rank - 4];
cur_input_dims[3] = input_dims[rank - 5];
cur_input_dims[4] = input_dims[rank - 7];
cur_input_dims[5] = input_dims[rank - 8];
ContiguousCaseOneFunc<T, 9><<<grid, block, 0, dev_ctx.stream()>>>(
input_data,
output_data,
input_stride,
cur_input_dims,
input_dims[rank - 1] * input_dims[rank - 2] * input_dims[rank - 3]);
break;
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"The rank of input should be less than 9, but received %d.", rank));
}
}
}

Expand Down

0 comments on commit 9fb27cc

Please sign in to comment.