From 9fb27cc5a3e546ab75ae3e8415758291bcb2d11f Mon Sep 17 00:00:00 2001 From: Winters Montagne <118546135+WintersMontagne10335@users.noreply.github.com> Date: Mon, 16 Oct 2023 15:04:44 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90PaddlePaddle=20Hackathon=205=20No.48?= =?UTF-8?q?=E3=80=91ContiguousKernel=E3=80=81StridedCopyKernel=E7=AE=97?= =?UTF-8?q?=E5=AD=90CPU=E3=80=81GPU=E6=80=A7=E8=83=BD=E4=BC=98=E5=8C=96=20?= =?UTF-8?q?(#57835)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * speed up ContiguousKernel * fix bugs * fix bugs * test origin code * fix bugs --- paddle/phi/kernels/gpu/contiguous_kernel.cu | 365 +++++++++++++++++--- 1 file changed, 312 insertions(+), 53 deletions(-) diff --git a/paddle/phi/kernels/gpu/contiguous_kernel.cu b/paddle/phi/kernels/gpu/contiguous_kernel.cu index b8dee10e31cdeb..357e104afb01c8 100644 --- a/paddle/phi/kernels/gpu/contiguous_kernel.cu +++ b/paddle/phi/kernels/gpu/contiguous_kernel.cu @@ -20,26 +20,120 @@ limitations under the License. */ #include "paddle/phi/kernels/transpose_kernel.h" namespace phi { +template +__global__ void ContiguousCaseZeroFunc( + const T* input_data, + T* out_data, + phi::Array input_stride) { + int64_t input_offset = 0; + int64_t output_offset = (blockIdx.z * gridDim.y * gridDim.x + + blockIdx.y * gridDim.x + blockIdx.x) * + blockDim.z * blockDim.y * blockDim.x + + threadIdx.z * blockDim.y * blockDim.x + + threadIdx.y * blockDim.x + threadIdx.x; + float coordinate[6] = {threadIdx.x, + threadIdx.y, + threadIdx.z, + blockIdx.x, + blockIdx.y, + blockIdx.z}; + +#pragma unroll + for (int dim = N - 1; dim >= 0; --dim) { + input_offset += coordinate[N - 1 - dim] * input_stride[dim]; + } + + out_data[output_offset] = input_data[input_offset]; +} template -__global__ void ContiguousFunc( +__global__ void ContiguousCaseOneFunc( const T* input_data, T* out_data, phi::Array input_stride, - phi::Array dims, - const int64_t numel) { - int64_t gid = blockIdx.x * blockDim.x + threadIdx.x; -#pragma unroll - for (int64_t i = gid; i < numel; i += blockDim.x * gridDim.x) { + phi::Array dims, + const int64_t x_max) { + int64_t x = blockIdx.x * blockDim.x + threadIdx.x; + if (x < x_max) { int64_t input_offset = 0; - int64_t index_tmp = i; + int64_t output_offset = (blockIdx.z * gridDim.y + blockIdx.y) * x_max + x; + + int64_t reg_dims[6] = { + dims[0], dims[1], dims[2], dims[3], dims[4], dims[5]}; + int64_t coordinate[phi::DDim::kMaxRank + 1]; + + switch (N) { + case 1: + coordinate[0] = x % reg_dims[0]; + break; + case 2: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + break; + case 3: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + break; + case 4: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + break; + case 5: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + break; + case 6: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + break; + case 7: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + coordinate[6] = blockIdx.z % reg_dims[4]; + break; + case 8: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + coordinate[6] = blockIdx.z % reg_dims[4]; + coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5]; + break; + case 9: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + coordinate[6] = blockIdx.z % reg_dims[4]; + coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5]; + coordinate[8] = blockIdx.z / (reg_dims[4] * reg_dims[5]); + break; + } + #pragma unroll for (int dim = N - 1; dim >= 0; --dim) { - input_offset += index_tmp % dims[dim] * input_stride[dim]; - index_tmp = index_tmp / dims[dim]; + input_offset += coordinate[N - 1 - dim] * input_stride[dim]; } - out_data[i] = input_data[input_offset]; + out_data[output_offset] = input_data[input_offset]; } } @@ -135,49 +229,214 @@ void ContiguousKernel(const Context& dev_ctx, input_stride[0] = 1; } - int64_t block = 512; - int64_t grid = (numel + block - 1) / block; - - switch (rank) { - case 1: - ContiguousFunc<<>>( - input_data, output_data, input_stride, input_dims, numel); - break; - case 2: - ContiguousFunc<<>>( - input_data, output_data, input_stride, input_dims, numel); - break; - case 3: - ContiguousFunc<<>>( - input_data, output_data, input_stride, input_dims, numel); - break; - case 4: - ContiguousFunc<<>>( - input_data, output_data, input_stride, input_dims, numel); - break; - case 5: - ContiguousFunc<<>>( - input_data, output_data, input_stride, input_dims, numel); - break; - case 6: - ContiguousFunc<<>>( - input_data, output_data, input_stride, input_dims, numel); - break; - case 7: - ContiguousFunc<<>>( - input_data, output_data, input_stride, input_dims, numel); - break; - case 8: - ContiguousFunc<<>>( - input_data, output_data, input_stride, input_dims, numel); - break; - case 9: - ContiguousFunc<<>>( - input_data, output_data, input_stride, input_dims, numel); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of input should be less than 9, but received %d.", rank)); + dim3 grid(1, 1, 1), block(1, 1, 1); + + int tmp = 1; + + for (int i = 0; i < 3 && i < rank; i++) { + tmp *= input_dims[rank - 1 - i]; + } + + if (rank <= 6 && tmp <= 1024 && + (input_dims.size() < 3 || input_dims[rank - 3] <= 64)) { + if (rank >= 1) { + block.x = input_dims[rank - 1]; + } + + if (rank >= 2) { + block.y = input_dims[rank - 2]; + } + + if (rank >= 3) { + block.z = input_dims[rank - 3]; + } + + switch (rank) { + case 1: + ContiguousCaseZeroFunc<<>>( + input_data, output_data, input_stride); + break; + case 2: + ContiguousCaseZeroFunc<<>>( + input_data, output_data, input_stride); + break; + case 3: + ContiguousCaseZeroFunc<<>>( + input_data, output_data, input_stride); + break; + case 4: + grid.x = input_dims[rank - 4]; + ContiguousCaseZeroFunc<<>>( + input_data, output_data, input_stride); + break; + case 5: + grid.x = input_dims[rank - 4]; + grid.y = input_dims[rank - 5]; + ContiguousCaseZeroFunc<<>>( + input_data, output_data, input_stride); + break; + case 6: + grid.x = input_dims[rank - 4]; + grid.y = input_dims[rank - 5]; + grid.z = input_dims[rank - 6]; + ContiguousCaseZeroFunc<<>>( + input_data, output_data, input_stride); + break; + } + } else { + phi::Array cur_input_dims; + block.x = 512; + switch (rank) { + case 1: + grid.x = (numel + block.x - 1) / block.x; + cur_input_dims[0] = input_dims[rank - 1]; + ContiguousCaseOneFunc + <<>>(input_data, + output_data, + input_stride, + cur_input_dims, + input_dims[rank - 1]); + break; + case 2: + grid.x = (numel + block.x - 1) / block.x; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + ContiguousCaseOneFunc<<>>( + input_data, + output_data, + input_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2]); + break; + case 3: + grid.x = (numel + block.x - 1) / block.x; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + ContiguousCaseOneFunc<<>>( + input_data, + output_data, + input_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2] * input_dims[rank - 3]); + break; + case 4: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + ContiguousCaseOneFunc<<>>( + input_data, + output_data, + input_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2] * input_dims[rank - 3]); + break; + case 5: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + ContiguousCaseOneFunc<<>>( + input_data, + output_data, + input_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2] * input_dims[rank - 3]); + break; + case 6: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = + input_dims[rank - 4] * input_dims[rank - 5] * input_dims[rank - 6]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + ContiguousCaseOneFunc<<>>( + input_data, + output_data, + input_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2] * input_dims[rank - 3]); + break; + case 7: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = + input_dims[rank - 4] * input_dims[rank - 5] * input_dims[rank - 6]; + grid.z = input_dims[rank - 7]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + cur_input_dims[4] = input_dims[rank - 7]; + ContiguousCaseOneFunc<<>>( + input_data, + output_data, + input_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2] * input_dims[rank - 3]); + break; + case 8: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = + input_dims[rank - 4] * input_dims[rank - 5] * input_dims[rank - 6]; + grid.z = input_dims[rank - 7] * input_dims[rank - 8]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + cur_input_dims[4] = input_dims[rank - 7]; + cur_input_dims[5] = input_dims[rank - 8]; + ContiguousCaseOneFunc<<>>( + input_data, + output_data, + input_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2] * input_dims[rank - 3]); + break; + case 9: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = + input_dims[rank - 4] * input_dims[rank - 5] * input_dims[rank - 6]; + grid.z = + input_dims[rank - 7] * input_dims[rank - 8] * input_dims[rank - 9]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + cur_input_dims[4] = input_dims[rank - 7]; + cur_input_dims[5] = input_dims[rank - 8]; + ContiguousCaseOneFunc<<>>( + input_data, + output_data, + input_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2] * input_dims[rank - 3]); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of input should be less than 9, but received %d.", rank)); + } } }