Skip to content

Grid_sampler optimization #39751

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 33 commits into from
Feb 28, 2022
Merged
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
8f532b0
Merge pull request #1 from PaddlePaddle/develop
AshburnLee Sep 8, 2020
5b5804d
Merge pull request #2 from PaddlePaddle/develop
AshburnLee Sep 17, 2020
cee2470
Merge pull request #3 from PaddlePaddle/develop
AshburnLee Sep 30, 2020
5be3a45
Merge pull request #4 from PaddlePaddle/develop
AshburnLee Oct 13, 2020
a1d92b7
Merge pull request #5 from PaddlePaddle/develop
AshburnLee Oct 20, 2020
e674a5d
Merge pull request #6 from PaddlePaddle/develop
AshburnLee Nov 15, 2020
855d00b
Merge pull request #7 from PaddlePaddle/develop
AshburnLee Nov 18, 2020
7cb2c97
Merge pull request #8 from PaddlePaddle/develop
AshburnLee Mar 31, 2021
db9fc91
Merge pull request #9 from PaddlePaddle/develop
AshburnLee Apr 7, 2021
c7b68c8
Merge branch 'develop' of https://github.com/PaddlePaddle/paddle into…
AshburnLee Apr 26, 2021
0fd630e
Merge branch 'PaddlePaddle:develop' into develop
AshburnLee Aug 16, 2021
4bbb33b
Merge branch 'PaddlePaddle:develop' into develop
AshburnLee Sep 28, 2021
30a1a89
Merge branch 'PaddlePaddle:develop' into develop
AshburnLee Nov 22, 2021
ce3deec
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Dec 21, 2021
925eb06
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Jan 6, 2022
7fcf902
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Jan 26, 2022
956bd69
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Jan 28, 2022
5f5fb9e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Feb 10, 2022
02ad020
init grid_sampler with mode=bilinear
AshburnLee Feb 20, 2022
7d0baac
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Feb 20, 2022
b9f7af8
solve error
AshburnLee Feb 21, 2022
91e7467
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Feb 21, 2022
bf3ef1a
rm fill constant
AshburnLee Feb 21, 2022
9f0889f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Feb 21, 2022
78fed2b
rm head
AshburnLee Feb 21, 2022
207564d
change block size
AshburnLee Feb 21, 2022
0fc7705
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Feb 21, 2022
755c48b
change block size
AshburnLee Feb 22, 2022
cbbb3cd
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Feb 22, 2022
5a431f0
optimize
AshburnLee Feb 23, 2022
68c43f7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Feb 23, 2022
973cad0
apply existing config
AshburnLee Feb 25, 2022
625f725
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Feb 25, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 8 additions & 15 deletions paddle/fluid/operators/grid_sampler_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/operators/grid_sampler_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"

namespace paddle {
Expand Down Expand Up @@ -292,15 +293,12 @@ class GridSampleOpCUDAKernel : public framework::OpKernel<T> {
auto* output_data = output->mutable_data<T>(ctx.GetPlace());
VLOG(3) << "out dims: " << output->dims()[0] << "; " << output->dims()[1]
<< "; " << output->dims()[2] << "; " << output->dims()[3];
phi::funcs::SetConstant<paddle::platform::CUDADeviceContext, T>()(
dev_ctx, output, static_cast<T>(0));
int count = static_cast<int>(n * out_h * out_w);
auto cu_stream = dev_ctx.stream();
int block_size = 512;
int grid_size = (count + block_size - 1) / block_size;
VLOG(3) << "cuda launch - grid dims: " << grid_size << "; block dims"
<< block_size;
grid_sample_cuda_kernel<T><<<grid_size, block_size, 0, cu_stream>>>(
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(dev_ctx, count);
grid_sample_cuda_kernel<
T><<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>(
count, n, c, out_h, out_w, in_h, in_w, input->data<T>(),
grid->data<T>(), output_data, mode, padding_mode, align_corners);
}
Expand Down Expand Up @@ -467,19 +465,14 @@ class GridSampleGradOpCUDAKernel : public framework::OpKernel<T> {
if (ctx.HasOutput(framework::GradVarName("Grid"))) {
auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
grid_grad_data = grid_grad->mutable_data<T>(ctx.GetPlace());
phi::funcs::SetConstant<paddle::platform::CUDADeviceContext, T>()(
ctx.template device_context<paddle::platform::CUDADeviceContext>(),
grid_grad, static_cast<T>(0));
}

int count = static_cast<int>(n * out_h * out_w);
auto cu_stream = dev_ctx.stream();
int block_size = 512;
int grid_size = (count + block_size - 1) / block_size;
VLOG(3) << "cuda launch grad kernel - grid dims: " << grid_size
<< "; block dims" << block_size << "; count: " << count;
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(dev_ctx, count);
grid_sampler_cuda_backward_kernel<
T><<<grid_size, block_size, 0, cu_stream>>>(
T><<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>(
count, output_grad->data<T>(), input->data<T>(), grid->data<T>(), n, c,
out_h, out_w, in_h, in_w, input_grad->data<T>(), grid_grad_data, mode,
padding_mode, align_corners);
Expand Down