@@ -17,6 +17,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/math/pooling.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/gpu_launch_config.h"
 
 namespace paddle {
 namespace operators {
@@ -254,8 +255,13 @@ void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
   const int padding_width = paddings[1];
 
   int nthreads = batch_size * output_channels * output_height * output_width;
-  int blocks = (nthreads + 1024 - 1) / 1024;
-  dim3 threads(1024, 1);
+  int thread_num = 1024;
+#ifdef WITH_NV_JETSON
+  // platform::ChangeThreadNum(context, &thread_num);
+  thread_num = 512;
+#endif
+  int blocks = (nthreads + thread_num - 1) / thread_num;
+  dim3 threads(thread_num, 1);
   dim3 grid(blocks, 1);
 
   KernelPool2D<PoolProcess, T><<<grid, threads, 0, stream>>>(
@@ -298,10 +304,13 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
     T* output_data = output->mutable_data<T>(context.GetPlace());
 
     int nthreads = batch_size * output_channels * output_height * output_width;
-    int blocks = (nthreads + 1024 - 1) / 1024;
-    dim3 threads(1024, 1);
+    int thread_num = 1024;
+#ifdef WITH_NV_JETSON
+    platform::ChangeThreadNum(context, &thread_num);
+#endif
+    int blocks = (nthreads + thread_num - 1) / thread_num;
+    dim3 threads(thread_num, 1);
     dim3 grid(blocks, 1);
-
     KernelPool2D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
         nthreads, input_data, input_channels, input_height, input_width,
         output_height, output_width, ksize_height, ksize_width, stride_height,
@@ -341,10 +350,13 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
     T* output_data = output->mutable_data<T>(context.GetPlace());
 
     int nthreads = batch_size * output_channels * output_height * output_width;
-    int blocks = (nthreads + 1024 - 1) / 1024;
-    dim3 threads(1024, 1);
+    int thread_num = 1024;
+#ifdef WITH_NV_JETSON
+    platform::ChangeThreadNum(context, &thread_num);
+#endif
+    int blocks = (nthreads + thread_num - 1) / thread_num;
+    dim3 threads(thread_num, 1);
     dim3 grid(blocks, 1);
-
     KernelPool2D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
         nthreads, input_data, input_channels, input_height, input_width,
         output_height, output_width, ksize_height, ksize_width, stride_height,
@@ -911,8 +923,12 @@ class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 
     int nthreads = batch_size * output_channels * output_depth * output_height *
                    output_width;
-    int blocks = (nthreads + 1024 - 1) / 1024;
-    dim3 threads(1024, 1);
+    int thread_num = 1024;
+#ifdef WITH_NV_JETSON
+    platform::ChangeThreadNum(context, &thread_num);
+#endif
+    int blocks = (nthreads + thread_num - 1) / thread_num;
+    dim3 threads(thread_num, 1);
     dim3 grid(blocks, 1);
 
     KernelPool3D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
@@ -962,8 +978,12 @@ class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 
     int nthreads = batch_size * output_channels * output_depth * output_height *
                    output_width;
-    int blocks = (nthreads + 1024 - 1) / 1024;
-    dim3 threads(1024, 1);
+    int thread_num = 1024;
+#ifdef WITH_NV_JETSON
+    platform::ChangeThreadNum(context, &thread_num);
+#endif
+    int blocks = (nthreads + thread_num - 1) / thread_num;
+    dim3 threads(thread_num, 1);
     dim3 grid(blocks, 1);
 
     KernelPool3D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
@@ -1377,10 +1397,14 @@ class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
     T2* mask_data = mask->mutable_data<T2>(context.GetPlace());
 
     int nthreads = batch_size * output_channels * output_height * output_width;
-    int blocks = (nthreads + 1024 - 1) / 1024;
-    dim3 threads(1024, 1);
-    dim3 grid(blocks, 1);
+    int thread_num = 1024;
+#ifdef WITH_NV_JETSON
+    platform::ChangeThreadNum(context, &thread_num);
+#endif
 
+    int blocks = (nthreads + thread_num - 1) / thread_num;
+    dim3 threads(thread_num, 1);
+    dim3 grid(blocks, 1);
     KernelMaxPool2dWithIdx<T1, T2><<<grid, threads, 0, context.stream()>>>(
         nthreads, input_data, input_channels, input_height, input_width,
         output_height, output_width, ksize_height, ksize_width, stride_height,
@@ -1613,8 +1637,13 @@ class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
 
     int nthreads = batch_size * output_channels * output_depth * output_height *
                    output_width;
-    int blocks = (nthreads + 1024 - 1) / 1024;
-    dim3 threads(1024, 1);
+    int thread_num = 1024;
+#ifdef WITH_NV_JETSON
+    platform::ChangeThreadNum(context, &thread_num);
+#endif
+
+    int blocks = (nthreads + thread_num - 1) / thread_num;
+    dim3 threads(thread_num, 1);
     dim3 grid(blocks, 1);
 
     KernelMaxPool3DWithIdx<T1, T2><<<grid, threads, 0, context.stream()>>>(
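Every hunk above applies the same launch-configuration pattern: start from a default of 1024 threads per block, shrink it on Jetson builds, and derive the block count by ceiling division so that `blocks * thread_num >= nthreads`. Below is a minimal, self-contained sketch of that pattern. The `CapThreadNumForJetson` helper and the `Fill` kernel are hypothetical stand-ins for illustration only; the real patch delegates the capping to `platform::ChangeThreadNum` from `gpu_launch_config.h` (and hard-codes 512 in `Pool2dDirectCUDAFunctor`, which only receives a raw stream rather than a device context).

```cuda
#include <cuda_runtime.h>

// Hypothetical stand-in for platform::ChangeThreadNum: lower the per-block
// thread count on Jetson builds (assumption: 512, matching the hard-coded
// value used in Pool2dDirectCUDAFunctor).
static void CapThreadNumForJetson(int* thread_num) {
#ifdef WITH_NV_JETSON
  *thread_num = 512;
#endif
}

// Toy kernel launched with one bounds-checked thread per output element.
__global__ void Fill(int nthreads, float* out, float value) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < nthreads) {
    out[idx] = value;
  }
}

void LaunchFill(int nthreads, float* out, float value, cudaStream_t stream) {
  int thread_num = 1024;  // default block size
  CapThreadNumForJetson(&thread_num);  // optionally shrink on Jetson
  // Ceiling division: enough blocks to cover all nthreads elements.
  int blocks = (nthreads + thread_num - 1) / thread_num;
  dim3 threads(thread_num, 1);
  dim3 grid(blocks, 1);
  Fill<<<grid, threads, 0, stream>>>(nthreads, out, value);
}
```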