Commit e370b16

feng_shuai authored and AnnaTrainingG committed

merge CMakeList.txt manual (PaddlePaddle#35378)

* merge CMakeList.txt manual
* add platform for changethreadnum
* repair some bugs according to make error
* do nothing just flush CI
* forget change thread num
* add inplace_atol param for check_output_with_place
* Windows
* std::min and std::max should be changed because of Windows

1 parent 58aa3eb commit e370b16
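The thread of the change: every CUDA launch site that hardcoded 1024 threads per block now routes the count through `platform::ChangeThreadNum`, which lowers it on Jetson boards (compute capability 5.3 for the Nano, 6.2 for the TX2) where a 1024-thread launch can fail. A minimal sketch of the pattern as it recurs in the diffs below; `kernel` and `n` are placeholders for the concrete kernels and work sizes, not Paddle identifiers:

```cpp
// Sketch only: `kernel` and `n` stand in for the real kernels and element
// counts below. ChangeThreadNum is the helper this commit adds to
// paddle/fluid/platform/gpu_launch_config.h.
int num_thread = 1024;                              // default threads per block
#ifdef WITH_NV_JETSON
platform::ChangeThreadNum(context, &num_thread);    // 512 on Nano/TX2
#endif
// Ceil-division so that every one of the n elements gets a thread.
int blocks = (n + num_thread - 1) / num_thread;
dim3 threads(num_thread, 1);
dim3 grid(blocks, 1);
kernel<T><<<grid, threads, 0, context.stream()>>>(/* ... */);
```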

File tree

9 files changed: +148 −45 lines


paddle/fluid/operators/math/im2col.cu

Lines changed: 13 additions & 4 deletions
```diff
@@ -16,6 +16,7 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/operators/math/im2col.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/gpu_launch_config.h"
 
 namespace paddle {
 namespace operators {
@@ -104,10 +105,14 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
     int col_width = col->dims()[4];
 
     int num_outputs = im_channels * col_height * col_width;
-    int blocks = (num_outputs + 1024 - 1) / 1024;
+    int num_thread = 1024;
+#ifdef WITH_NV_JETSON
+    platform::ChangeThreadNum(context, &num_thread);
+#endif
+    int blocks = (num_outputs + num_thread - 1) / num_thread;
     int block_x = 512;
     int block_y = (blocks + 512 - 1) / 512;
-    dim3 threads(1024, 1);
+    dim3 threads(num_thread, 1);
     dim3 grid(block_x, block_y);
     im2col<T><<<grid, threads, 0, context.stream()>>>(
         im.data<T>(), num_outputs, im_height, im_width, dilation[0],
@@ -228,10 +233,14 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
 
     size_t num_kernels = im_channels * im_height * im_width;
 
-    size_t blocks = (num_kernels + 1024 - 1) / 1024;
+    int num_thread = 1024;
+#ifdef WITH_NV_JETSON
+    platform::ChangeThreadNum(context, &num_thread);
+#endif
+    size_t blocks = (num_kernels + num_thread - 1) / num_thread;
     size_t block_x = 512;
     size_t block_y = (blocks + 512 - 1) / 512;
-    dim3 threads(1024, 1);
+    dim3 threads(num_thread, 1);
     dim3 grid(block_x, block_y);
 
     // To avoid involving atomic operations, we will launch one kernel per
```
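For concreteness, the 2D-grid arithmetic above worked with illustrative numbers (not taken from the commit): with `num_thread` clamped to 512, a million outputs need 1954 blocks, which the code folds into a 512-wide grid four rows deep; the kernel's usual `index < num_outputs` guard discards the surplus threads.

```cpp
// Illustrative numbers only; mirrors the launch math in the diff above.
#include <cstdio>

int main() {
  const int num_outputs = 1000000;
  const int num_thread = 512;  // the Jetson-clamped value
  const int blocks = (num_outputs + num_thread - 1) / num_thread;  // 1954
  const int block_x = 512;
  const int block_y = (blocks + 512 - 1) / 512;  // 4
  // grid(512, 4) provides 2048 blocks >= the 1954 required.
  std::printf("blocks=%d grid=(%d,%d)\n", blocks, block_x, block_y);
  return 0;
}
```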

paddle/fluid/operators/math/pooling.cu

Lines changed: 46 additions & 17 deletions
```diff
@@ -17,6 +17,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/math/pooling.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/gpu_launch_config.h"
 
 namespace paddle {
 namespace operators {
@@ -254,8 +255,13 @@ void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
   const int padding_width = paddings[1];
 
   int nthreads = batch_size * output_channels * output_height * output_width;
-  int blocks = (nthreads + 1024 - 1) / 1024;
-  dim3 threads(1024, 1);
+  int thread_num = 1024;
+#ifdef WITH_NV_JETSON
+  // platform::ChangeThreadNum(context, &thread_num);
+  thread_num = 512;
+#endif
+  int blocks = (nthreads + thread_num - 1) / thread_num;
+  dim3 threads(thread_num, 1);
   dim3 grid(blocks, 1);
 
   KernelPool2D<PoolProcess, T><<<grid, threads, 0, stream>>>(
@@ -298,10 +304,13 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
     T* output_data = output->mutable_data<T>(context.GetPlace());
 
     int nthreads = batch_size * output_channels * output_height * output_width;
-    int blocks = (nthreads + 1024 - 1) / 1024;
-    dim3 threads(1024, 1);
+    int thread_num = 1024;
+#ifdef WITH_NV_JETSON
+    platform::ChangeThreadNum(context, &thread_num);
+#endif
+    int blocks = (nthreads + thread_num - 1) / thread_num;
+    dim3 threads(thread_num, 1);
     dim3 grid(blocks, 1);
-
     KernelPool2D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
         nthreads, input_data, input_channels, input_height, input_width,
         output_height, output_width, ksize_height, ksize_width, stride_height,
@@ -341,10 +350,13 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
     T* output_data = output->mutable_data<T>(context.GetPlace());
 
     int nthreads = batch_size * output_channels * output_height * output_width;
-    int blocks = (nthreads + 1024 - 1) / 1024;
-    dim3 threads(1024, 1);
+    int thread_num = 1024;
+#ifdef WITH_NV_JETSON
+    platform::ChangeThreadNum(context, &thread_num);
+#endif
+    int blocks = (nthreads + thread_num - 1) / thread_num;
+    dim3 threads(thread_num, 1);
     dim3 grid(blocks, 1);
-
     KernelPool2D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
         nthreads, input_data, input_channels, input_height, input_width,
         output_height, output_width, ksize_height, ksize_width, stride_height,
@@ -911,8 +923,12 @@ class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 
     int nthreads = batch_size * output_channels * output_depth * output_height *
                    output_width;
-    int blocks = (nthreads + 1024 - 1) / 1024;
-    dim3 threads(1024, 1);
+    int thread_num = 1024;
+#ifdef WITH_NV_JETSON
+    platform::ChangeThreadNum(context, &thread_num);
+#endif
+    int blocks = (nthreads + thread_num - 1) / thread_num;
+    dim3 threads(thread_num, 1);
     dim3 grid(blocks, 1);
 
     KernelPool3D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
@@ -962,8 +978,12 @@ class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 
     int nthreads = batch_size * output_channels * output_depth * output_height *
                    output_width;
-    int blocks = (nthreads + 1024 - 1) / 1024;
-    dim3 threads(1024, 1);
+    int thread_num = 1024;
+#ifdef WITH_NV_JETSON
+    platform::ChangeThreadNum(context, &thread_num);
+#endif
+    int blocks = (nthreads + thread_num - 1) / thread_num;
+    dim3 threads(thread_num, 1);
     dim3 grid(blocks, 1);
 
     KernelPool3D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
@@ -1377,10 +1397,14 @@ class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
     T2* mask_data = mask->mutable_data<T2>(context.GetPlace());
 
     int nthreads = batch_size * output_channels * output_height * output_width;
-    int blocks = (nthreads + 1024 - 1) / 1024;
-    dim3 threads(1024, 1);
-    dim3 grid(blocks, 1);
+    int thread_num = 1024;
+#ifdef WITH_NV_JETSON
+    platform::ChangeThreadNum(context, &thread_num);
+#endif
 
+    int blocks = (nthreads + thread_num - 1) / thread_num;
+    dim3 threads(thread_num, 1);
+    dim3 grid(blocks, 1);
     KernelMaxPool2dWithIdx<T1, T2><<<grid, threads, 0, context.stream()>>>(
         nthreads, input_data, input_channels, input_height, input_width,
         output_height, output_width, ksize_height, ksize_width, stride_height,
@@ -1613,8 +1637,13 @@ class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
 
     int nthreads = batch_size * output_channels * output_depth * output_height *
                    output_width;
-    int blocks = (nthreads + 1024 - 1) / 1024;
-    dim3 threads(1024, 1);
+    int thread_num = 1024;
+#ifdef WITH_NV_JETSON
+    platform::ChangeThreadNum(context, &thread_num);
+#endif
+
+    int blocks = (nthreads + thread_num - 1) / thread_num;
+    dim3 threads(thread_num, 1);
     dim3 grid(blocks, 1);
 
     KernelMaxPool3DWithIdx<T1, T2><<<grid, threads, 0, context.stream()>>>(
```
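One launch site above is handled differently: `Pool2dDirectCUDAFunctor` hardcodes `thread_num = 512` on Jetson, with the `ChangeThreadNum` call left commented out. The likely reason, inferred from the code rather than stated in the commit, is that this functor is handed a raw stream instead of a `CUDADeviceContext` (its launch uses `stream`, not `context.stream()`), so there is no context through which to query the compute capability, and the conservative Jetson value is applied unconditionally:

```cpp
// Abridged contrast, not the full Paddle signature: ChangeThreadNum needs a
// CUDADeviceContext to ask for the compute capability, which this
// stream-only entry point does not receive.
void LaunchDirect(cudaStream_t stream, int nthreads) {
  int thread_num = 1024;
#ifdef WITH_NV_JETSON
  thread_num = 512;  // no context available, so clamp unconditionally
#endif
  // ... launch with <<<grid, threads, 0, stream>>> ...
}
```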

paddle/fluid/operators/math/vol2col.cu

Lines changed: 16 additions & 4 deletions
```diff
@@ -16,6 +16,7 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/operators/math/vol2col.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/gpu_launch_config.h"
 
 namespace paddle {
 namespace operators {
@@ -152,8 +153,14 @@ class Vol2ColFunctor<platform::CUDADeviceContext, T> {
     int num_outputs =
         input_channels * output_depth * output_height * output_width;
 
-    const int threads = 1024;
-    const int blocks = (num_outputs + 1024 - 1) / 1024;
+    int max_threads = 1024;
+#ifdef WITH_NV_JETSON
+    platform::ChangeThreadNum(context, &max_threads);
+#endif
+
+    const int threads = max_threads;
+    const int blocks = (num_outputs + max_threads - 1) / max_threads;
+
     vol2col<T><<<blocks, threads, 0, context.stream()>>>(
         num_outputs, vol.data<T>(), input_depth, input_height, input_width,
         dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
@@ -313,8 +320,13 @@ class Col2VolFunctor<platform::CUDADeviceContext, T> {
 
     int num_kernels = input_channels * input_depth * input_height * input_width;
 
-    const int threads = 1024;
-    const int blocks = (num_kernels + 1024 - 1) / 1024;
+    int max_threads = 1024;
+#ifdef WITH_NV_JETSON
+    platform::ChangeThreadNum(context, &max_threads);
+#endif
+
+    const int threads = max_threads;
+    const int blocks = (num_kernels + max_threads - 1) / max_threads;
 
     col2vol<T><<<blocks, threads, 0, context.stream()>>>(
         num_kernels, col.data<T>(), input_depth, input_height, input_width,
```

paddle/fluid/operators/roi_align_op.cu

Lines changed: 4 additions & 1 deletion
```diff
@@ -16,6 +16,7 @@ limitations under the License. */
 #include "paddle/fluid/memory/memory.h"
 #include "paddle/fluid/operators/roi_align_op.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/gpu_launch_config.h"
 
 namespace paddle {
 namespace operators {
@@ -261,7 +262,9 @@ class GPUROIAlignOpKernel : public framework::OpKernel<T> {
     int output_size = out->numel();
     int blocks = NumBlocks(output_size);
     int threads = kNumCUDAThreads;
-
+#ifdef WITH_NV_JETSON
+    platform::ChangeThreadNum(ctx.cuda_device_context(), &threads, 256);
+#endif
     Tensor roi_batch_id_list;
     roi_batch_id_list.Resize({rois_num});
     auto cplace = platform::CPUPlace();
```
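Here the third argument overrides `ChangeThreadNum`'s default alternative of 512, so on Nano/TX2 the ROI Align launch drops to 256 threads per block. A sketch of the effect, assuming the helper's definition added in `gpu_launch_config.h` below; 512 is a stand-in for `kNumCUDAThreads`:

```cpp
int threads = 512;  // stand-in for kNumCUDAThreads
platform::ChangeThreadNum(ctx.cuda_device_context(), &threads, 256);
// threads == 256 on compute capability 53 (Nano) or 62 (TX2);
// unchanged on every other device.
```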

paddle/fluid/platform/for_range.h

Lines changed: 6 additions & 0 deletions
```diff
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 #include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/gpu_launch_config.h"
 
 namespace paddle {
 namespace platform {
@@ -65,6 +66,11 @@ struct ForRange<CUDADeviceContext> {
 #ifdef __HIPCC__
     // HIP will throw core dump when threads > 256
     constexpr int num_threads = 256;
+#elif WITH_NV_JETSON
+    // JETSON_NANO will throw core dump when threads > 128
+    int num_thread = 256;
+    platform::ChangeThreadNum(dev_ctx_, &num_thread, 128);
+    const int num_threads = num_thread;
 #else
     constexpr int num_threads = 1024;
 #endif
```
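Note why the new branch gives up `constexpr`: the Jetson thread count now depends on a runtime device query, so it is materialized as a `const int` instead. For orientation, a hedged sketch of how this struct is driven; `AddOne` is a hypothetical functor, but the `for_range(functor)` call shape matches the struct's `operator()`:

```cpp
// Hypothetical functor: ForRange invokes it once per index in [0, limit).
struct AddOne {
  float* data_;
  HOSTDEVICE void operator()(size_t i) const { data_[i] += 1.0f; }
};

// dev_ctx: platform::CUDADeviceContext; data: device pointer; limit: numel.
platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx, limit);
for_range(AddOne{data});
```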

paddle/fluid/platform/gpu_launch_config.h

Lines changed: 31 additions & 10 deletions
```diff
@@ -23,6 +23,7 @@
 #else
 #include <hip/hip_runtime.h>
 #endif
+
 #include <stddef.h>
 #include <algorithm>
 #include <string>
@@ -33,6 +34,18 @@ namespace platform {
 
 inline int DivUp(int a, int b) { return (a + b - 1) / b; }
 
+#ifdef WITH_NV_JETSON
+// The number of threads cannot be assigned 1024 in some cases when the device
+// is nano or tx2.
+inline void ChangeThreadNum(const platform::CUDADeviceContext& context,
+                            int* num_thread, int alternative_num_thread = 512) {
+  if (context.GetComputeCapability() == 53 ||
+      context.GetComputeCapability() == 62) {
+    *num_thread = alternative_num_thread;
+  }
+}
+#endif
+
 struct GpuLaunchConfig {
   dim3 theory_thread_count = dim3(1, 1, 1);
   dim3 thread_per_block = dim3(1, 1, 1);
@@ -61,15 +74,22 @@ inline GpuLaunchConfig GetGpuLaunchConfig1D(
 
   // Compute physical threads we need, should small than max sm threads
   const int physical_thread_count =
-      std::min(max_physical_threads, theory_thread_count);
+      (std::min)(max_physical_threads, theory_thread_count);
+
+  // Get compute_capability
+  const int capability = context.GetComputeCapability();
+
+#ifdef WITH_NV_JETSON
+  if (capability == 53 || capability == 62) {
+    max_threads = 512;
+  }
+#endif
 
   // Need get from device
   const int thread_per_block =
-      std::min(max_threads, context.GetMaxThreadsPerBlock());
+      (std::min)(max_threads, context.GetMaxThreadsPerBlock());
   const int block_count =
-      std::min(DivUp(physical_thread_count, thread_per_block), sm);
-  // Get compute_capability
-  const int capability = context.GetComputeCapability();
+      (std::min)(DivUp(physical_thread_count, thread_per_block), sm);
 
   GpuLaunchConfig config;
   config.theory_thread_count.x = theory_thread_count;
@@ -91,19 +111,20 @@ inline GpuLaunchConfig GetGpuLaunchConfig2D(
                         y_dim));
 
   const int kThreadsPerBlock = 256;
-  int block_cols = std::min(x_dim, kThreadsPerBlock);
-  int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
+  int block_cols = (std::min)(x_dim, kThreadsPerBlock);
+  int block_rows = (std::max)(kThreadsPerBlock / block_cols, 1);
 
   int max_physical_threads = context.GetMaxPhysicalThreadCount();
-  const int max_blocks = std::max(max_physical_threads / kThreadsPerBlock, 1);
+  const int max_blocks = (std::max)(max_physical_threads / kThreadsPerBlock, 1);
 
   GpuLaunchConfig config;
   // Noticed, block size is not align to 32, if needed do it yourself.
   config.theory_thread_count = dim3(x_dim, y_dim, 1);
   config.thread_per_block = dim3(block_cols, block_rows, 1);
 
-  int grid_x = std::min(DivUp(x_dim, block_cols), max_blocks);
-  int grid_y = std::min(max_blocks / grid_x, std::max(y_dim / block_rows, 1));
+  int grid_x = (std::min)(DivUp(x_dim, block_cols), max_blocks);
+  int grid_y =
+      (std::min)(max_blocks / grid_x, (std::max)(y_dim / block_rows, 1));
 
   config.block_per_grid = dim3(grid_x, grid_y, 1);
   return config;
```
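The `(std::min)`/`(std::max)` rewrites are the "Windows" item from the commit message: unless `NOMINMAX` is defined, `<windows.h>` declares function-style `min` and `max` macros, and `std::min(a, b)` gets mangled by the preprocessor. A function-style macro expands only when its name is immediately followed by `(`, so parenthesizing the name suppresses the expansion. A minimal illustration:

```cpp
#include <algorithm>
// #include <windows.h>  // on MSVC, defines min/max macros without NOMINMAX

int Smaller(int a, int b) {
  // std::min(a, b) would be macro-expanded and fail to compile here;
  // the extra parentheses keep the token after `min` from being `(`.
  return (std::min)(a, b);
}
```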
