
Commit 07bf09a

[PHI] Fix adaptive pool2d kernel for big tensor
1 parent 9b4329c commit 07bf09a

File tree

3 files changed, +55 -232 lines changed
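The diff below switches the adaptive pool2d path from int to int64_t indexing, drops the kernel parameters the adaptive kernel never used (nthreads, ksize, strides, paddings, divmods), and removes the operator() overloads that lack a data_format argument. The motivation is that 32-bit indexing silently wraps once a tensor holds more than 2^31 - 1 elements. A minimal sketch of the failure mode, with a hypothetical NCHW shape chosen so the element count hits exactly 2^32:

#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical shape: 4 x 64 x 4096 x 4096 = 2^32 elements.
  int64_t batch = 4, channels = 64, height = 4096, width = 4096;
  int64_t nthreads = batch * channels * height * width;  // 4294967296
  int32_t narrowed = static_cast<int32_t>(nthreads);     // wraps to 0
  std::cout << nthreads << " -> " << narrowed << std::endl;
  return 0;
}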

paddle/phi/kernels/funcs/pooling.cc

Lines changed: 0 additions & 77 deletions
@@ -29,83 +29,6 @@ namespace phi::funcs {
 template <typename PoolProcess, typename T>
 class Pool2dFunctor<CPUContext, PoolProcess, T> {
  public:
-  void operator()(const CPUContext& context,
-                  const DenseTensor& input,
-                  const std::vector<int>& ksize,
-                  const std::vector<int>& strides,
-                  const std::vector<int>& paddings,
-                  bool exclusive,
-                  bool adaptive,
-                  DenseTensor* output,
-                  PoolProcess pool_process) {
-    const int batch_size = static_cast<int>(input.dims()[0]);
-    const int input_height = static_cast<int>(input.dims()[2]);
-    const int input_width = static_cast<int>(input.dims()[3]);
-    const int output_channels = static_cast<int>(output->dims()[1]);
-    const int output_height = static_cast<int>(output->dims()[2]);
-    const int output_width = static_cast<int>(output->dims()[3]);
-    const int ksize_height = ksize[0];
-    const int ksize_width = ksize[1];
-    const int stride_height = strides[0];
-    const int stride_width = strides[1];
-    const int padding_height = paddings[0];
-    const int padding_width = paddings[1];
-
-    const int input_stride = input_height * input_width;
-    const int output_stride = output_height * output_width;
-
-    const T* input_data = input.data<T>();
-    T* output_data = context.template Alloc<T>(output);
-
-    int hstart = 0, hend = 1;
-    int wstart = 0, wend = 1;
-    for (int i = 0; i < batch_size; i++) {
-      for (int c = 0; c < output_channels; ++c) {
-        for (int ph = 0; ph < output_height; ++ph) {
-          if (adaptive) {
-            hstart = AdaptStartIndex(ph, input_height, output_height);
-            hend = AdaptEndIndex(ph, input_height, output_height);
-          }
-          for (int pw = 0; pw < output_width; ++pw) {
-            int pool_size = 1;
-            if (adaptive) {
-              wstart = AdaptStartIndex(pw, input_width, output_width);
-              wend = AdaptEndIndex(pw, input_width, output_width);
-            } else {
-              hstart = ph * stride_height - padding_height;
-              wstart = pw * stride_width - padding_width;
-              hend = std::min(hstart + ksize_height,
-                              input_height + padding_height);
-              wend =
-                  std::min(wstart + ksize_width, input_width + padding_width);
-              pool_size = (hend - hstart) * (wend - wstart);
-
-              wstart = std::max(wstart, 0);
-              hstart = std::max(hstart, 0);
-              hend = std::min(hend, input_height);
-              wend = std::min(wend, input_width);
-            }
-
-            T ele = pool_process.initial();
-            for (int h = hstart; h < hend; ++h) {
-              for (int w = wstart; w < wend; ++w) {
-                pool_process.compute(input_data[h * input_width + w], &ele);
-              }
-            }
-            if (exclusive || adaptive) {
-              pool_size = (hend - hstart) * (wend - wstart);
-            }
-
-            pool_process.finalize(static_cast<T>(pool_size), &ele);
-            output_data[ph * output_width + pw] = ele;
-          }
-        }
-        input_data += input_stride;
-        output_data += output_stride;
-      }
-    }
-  }
-
   void operator()(const CPUContext& context,
                   const DenseTensor& input,
                   const std::vector<int>& ksize,
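Only this data_format-less CPU overload is deleted; the overload that takes data_format (context continues above) already covers it. For reference, the exclusive flag in the deleted loop recomputes pool_size from the clipped window, while the inclusive path divides by the padded window size. A standalone illustration with made-up values:

#include <algorithm>
#include <cstdio>

int main() {
  // One 3x3 avg-pool window hanging over the top-left corner of a 4x4
  // input (padding 1): the clipped window is only 2x2.
  float input[4][4] = {{1, 2, 3, 4}, {5, 6, 7, 8},
                       {9, 10, 11, 12}, {13, 14, 15, 16}};
  int hstart = -1, wstart = -1, hend = 2, wend = 2;
  int padded_size = (hend - hstart) * (wend - wstart);  // 9, counts padding
  hstart = std::max(hstart, 0);
  wstart = std::max(wstart, 0);
  float sum = 0;
  for (int h = hstart; h < hend; ++h)
    for (int w = wstart; w < wend; ++w) sum += input[h][w];
  int clipped_size = (hend - hstart) * (wend - wstart);  // 4, exclusive mode
  printf("inclusive avg = %f, exclusive avg = %f\n",
         sum / padded_size, sum / clipped_size);  // 1.555556 vs 3.500000
  return 0;
}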

paddle/phi/kernels/funcs/pooling.cu

Lines changed: 41 additions & 144 deletions
@@ -186,59 +186,52 @@ __global__ void KernelPool2D(const int nthreads,
 }
 
 template <typename PoolProcess, typename T>
-__global__ void AdaptiveKernelPool2D(const int nthreads,
-                                     const T* input_data,
-                                     const int channels,
-                                     const int input_height,
-                                     const int input_width,
-                                     const int output_height,
-                                     const int output_width,
-                                     const int ksize_height,
-                                     const int ksize_width,
-                                     const int stride_height,
-                                     const int stride_width,
-                                     const int padding_height,
-                                     const int padding_width,
-                                     FastDivModForPooling divmods,
+__global__ void AdaptiveKernelPool2D(const T* input_data,
+                                     const int64_t channels,
+                                     const int64_t input_height,
+                                     const int64_t input_width,
+                                     const int64_t output_height,
+                                     const int64_t output_width,
                                      PoolProcess pool_process,
                                      bool exclusive,
                                      T* output_data,
                                      bool channel_last = false) {
-  const int n_offset = blockIdx.y;
-  const int c_offset = blockIdx.x * blockDim.y + threadIdx.y;
+  const int64_t n_offset = blockIdx.y;
+  const int64_t c_offset = blockIdx.x * blockDim.y + threadIdx.y;
   if (c_offset >= channels) {
     return;
   }
-  int hstart, hend, wstart, wend;
-  int input_offset =
+  int64_t hstart, hend, wstart, wend;
+  int64_t input_offset =
       channel_last
           ? n_offset * input_height * input_width * channels
           : (n_offset * channels + c_offset) * input_height * input_width;
-  int output_offset =
+  int64_t output_offset =
       channel_last
           ? n_offset * output_height * output_width * channels
          : (n_offset * channels + c_offset) * output_height * output_width;
-  for (int hw_offset = threadIdx.x; hw_offset < output_height * output_width;
+  for (int64_t hw_offset = threadIdx.x;
+       hw_offset < output_height * output_width;
        hw_offset += blockDim.x) {
-    int w_offset = hw_offset % output_width;
-    int h_offset = hw_offset / output_width;
+    int64_t w_offset = hw_offset % output_width;
+    int64_t h_offset = hw_offset / output_width;
     hstart = AdaptStartIndex(h_offset, input_height, output_height);
     hend = AdaptEndIndex(h_offset, input_height, output_height);
     wstart = AdaptStartIndex(w_offset, input_width, output_width);
     wend = AdaptEndIndex(w_offset, input_width, output_width);
 
     T ele = pool_process.initial();
-    for (int h = hstart; h < hend; ++h) {
-      for (int w = wstart; w < wend; ++w) {
-        auto input_idx = channel_last
-                             ? (h * input_width + w) * channels + c_offset
-                             : h * input_width + w;
+    for (int64_t h = hstart; h < hend; ++h) {
+      for (int64_t w = wstart; w < wend; ++w) {
+        int64_t input_idx = channel_last
+                                ? (h * input_width + w) * channels + c_offset
+                                : h * input_width + w;
         pool_process.compute(input_data[input_offset + input_idx], &ele);
       }
     }
-    int pool_size = (hend - hstart) * (wend - wstart);
+    int64_t pool_size = (hend - hstart) * (wend - wstart);
     pool_process.finalize(static_cast<T>(pool_size), &ele);
-    int output_idx =
+    int64_t output_idx =
         channel_last
             ? (h_offset * output_width + w_offset) * channels + c_offset
             : h_offset * output_width + w_offset;
@@ -480,20 +473,12 @@ void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
     dim3 grid(
         std::max((output_channels + blocks - 1) / blocks, 1), batch_size, 1);
     AdaptiveKernelPool2D<PoolProcess, T>
-        <<<grid, threads, 0, stream>>>(nthreads,
-                                       input,
+        <<<grid, threads, 0, stream>>>(input,
                                        input_channels,
                                        input_height,
                                       input_width,
                                        output_height,
                                        output_width,
-                                       ksize_height,
-                                       ksize_width,
-                                       stride_height,
-                                       stride_width,
-                                       padding_height,
-                                       padding_width,
-                                       pool_divmods,
                                        pool_compute,
                                        exclusive,
                                        output);
@@ -537,90 +522,6 @@ void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
 template <typename PoolProcess, typename T>
 class Pool2dFunctor<phi::GPUContext, PoolProcess, T> {
  public:
-  void operator()(const phi::GPUContext& context,
-                  const DenseTensor& input,
-                  const std::vector<int>& ksize,
-                  const std::vector<int>& strides,
-                  const std::vector<int>& paddings,
-                  bool exclusive,
-                  bool adaptive,
-                  DenseTensor* output,
-                  PoolProcess pool_process) {
-    const int batch_size = input.dims()[0];
-    const int input_channels = input.dims()[1];
-    const int input_height = input.dims()[2];
-    const int input_width = input.dims()[3];
-    const int output_channels = output->dims()[1];
-    const int output_height = output->dims()[2];
-    const int output_width = output->dims()[3];
-    const int ksize_height = ksize[0];
-    const int ksize_width = ksize[1];
-    const int stride_height = strides[0];
-    const int stride_width = strides[1];
-    const int padding_height = paddings[0];
-    const int padding_width = paddings[1];
-
-    const T* input_data = input.data<T>();
-    T* output_data = context.template Alloc<T>(output);
-
-    int nthreads = batch_size * output_channels * output_height * output_width;
-    auto pool_divmods =
-        FastDivModForPooling(input_channels, output_width, output_height);
-    if (adaptive) {
-      int max_threads = 512;
-      int thread_num = std::min(
-          phi::funcs::details::GetLastPow2(output_height * output_width),
-          max_threads);
-      int blocks = std::min(max_threads / thread_num, output_channels);
-      dim3 threads(thread_num, blocks, 1);
-      dim3 grid(
-          std::max((output_channels + blocks - 1) / blocks, 1), batch_size, 1);
-      AdaptiveKernelPool2D<PoolProcess, T>
-          <<<grid, threads, 0, context.stream()>>>(nthreads,
-                                                   input_data,
-                                                   input_channels,
-                                                   input_height,
-                                                   input_width,
-                                                   output_height,
-                                                   output_width,
-                                                   ksize_height,
-                                                   ksize_width,
-                                                   stride_height,
-                                                   stride_width,
-                                                   padding_height,
-                                                   padding_width,
-                                                   pool_divmods,
-                                                   pool_process,
-                                                   exclusive,
-                                                   output_data);
-    } else {
-      int thread_num = 1024;
-#ifdef WITH_NV_JETSON
-      backends::gpu::ChangeThreadNum(context, &thread_num);
-#endif
-      int blocks = (nthreads + thread_num - 1) / thread_num;
-      dim3 threads(thread_num, 1);
-      dim3 grid(blocks, 1);
-      KernelPool2D<PoolProcess, T>
-          <<<grid, threads, 0, context.stream()>>>(nthreads,
-                                                   input_data,
-                                                   input_channels,
-                                                   input_height,
-                                                   input_width,
-                                                   output_height,
-                                                   output_width,
-                                                   ksize_height,
-                                                   ksize_width,
-                                                   stride_height,
-                                                   stride_width,
-                                                   padding_height,
-                                                   padding_width,
-                                                   pool_divmods,
-                                                   pool_process,
-                                                   exclusive,
-                                                   output_data);
-    }
-  }
   void operator()(const phi::GPUContext& context,
                   const DenseTensor& input,
                   const std::vector<int>& ksize,
@@ -632,17 +533,20 @@ class Pool2dFunctor<phi::GPUContext, PoolProcess, T> {
                   DenseTensor* output,
                   PoolProcess pool_process) {
     bool channel_last = (data_format == "NHWC");
-    const int batch_size = input.dims()[0];
+    const int64_t batch_size = input.dims()[0];
 
-    const int input_channels = channel_last ? input.dims()[3] : input.dims()[1];
-    const int input_height = channel_last ? input.dims()[1] : input.dims()[2];
-    const int input_width = channel_last ? input.dims()[2] : input.dims()[3];
+    const int64_t input_channels =
+        channel_last ? input.dims()[3] : input.dims()[1];
+    const int64_t input_height =
+        channel_last ? input.dims()[1] : input.dims()[2];
+    const int64_t input_width =
+        channel_last ? input.dims()[2] : input.dims()[3];
 
-    const int output_channels =
+    const int64_t output_channels =
         channel_last ? output->dims()[3] : output->dims()[1];
-    const int output_height =
+    const int64_t output_height =
         channel_last ? output->dims()[1] : output->dims()[2];
-    const int output_width =
+    const int64_t output_width =
         channel_last ? output->dims()[2] : output->dims()[3];
 
     const int ksize_height = ksize[0];
@@ -657,33 +561,26 @@ class Pool2dFunctor<phi::GPUContext, PoolProcess, T> {
     const T* input_data = input.data<T>();
     T* output_data = context.template Alloc<T>(output);
 
-    int nthreads = batch_size * output_channels * output_height * output_width;
+    int64_t nthreads =
+        batch_size * output_channels * output_height * output_width;
     auto pool_divmods =
        FastDivModForPooling(input_channels, output_width, output_height);
     if (adaptive) {
-      int max_threads = 512;
-      int thread_num = std::min(
-          phi::funcs::details::GetLastPow2(output_height * output_width),
+      int64_t max_threads = 512;
+      int64_t thread_num = std::min(
+          phi::funcs::details::GetInt64LastPow2(output_height * output_width),
          max_threads);
-      int blocks = std::min(max_threads / thread_num, output_channels);
+      int64_t blocks = std::min(max_threads / thread_num, output_channels);
       dim3 threads(thread_num, blocks, 1);
       dim3 grid(
-          std::max((output_channels + blocks - 1) / blocks, 1), batch_size, 1);
+          std::max((output_channels + blocks - 1) / blocks, 1l), batch_size, 1);
       AdaptiveKernelPool2D<PoolProcess, T>
-          <<<grid, threads, 0, context.stream()>>>(nthreads,
-                                                   input_data,
+          <<<grid, threads, 0, context.stream()>>>(input_data,
                                                    input_channels,
                                                    input_height,
                                                    input_width,
                                                    output_height,
                                                    output_width,
-                                                   ksize_height,
-                                                   ksize_width,
-                                                   stride_height,
-                                                   stride_width,
-                                                   padding_height,
-                                                   padding_width,
-                                                   pool_divmods,
                                                    pool_process,
                                                    exclusive,
                                                    output_data,
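AdaptStartIndex and AdaptEndIndex now receive 64-bit offsets, so intermediate products such as h_offset * input_height no longer wrap. A sketch of the floor/ceil window partitioning these helpers implement (integer form for illustration; the real helpers may round through floating point):

#include <cstdint>
#include <cstdio>

// Output cell p of out_size covers input range [start, end).
int64_t AdaptStart(int64_t p, int64_t in_size, int64_t out_size) {
  return p * in_size / out_size;  // floor(p * in / out)
}
int64_t AdaptEnd(int64_t p, int64_t in_size, int64_t out_size) {
  return ((p + 1) * in_size + out_size - 1) / out_size;  // ceil((p+1)*in/out)
}

int main() {
  // Pool a length-10 axis down to 4 cells: windows [0,3) [2,5) [5,8) [7,10)
  // cover the whole input and may overlap; with int64_t the products
  // p * in_size stay exact even for very large axes.
  for (int64_t p = 0; p < 4; ++p) {
    printf("cell %lld -> [%lld, %lld)\n", (long long)p,
           (long long)AdaptStart(p, 10, 4), (long long)AdaptEnd(p, 10, 4));
  }
  return 0;
}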

paddle/phi/kernels/funcs/pooling.h

Lines changed: 14 additions & 11 deletions
@@ -30,6 +30,20 @@ limitations under the License. */
 namespace phi {
 namespace funcs {
 
+namespace details {
+static inline int64_t GetInt64LastPow2(int64_t x) {
+  if (x <= 0) return 0;
+  uint64_t ux = x;
+  ux |= (ux >> 1);
+  ux |= (ux >> 2);
+  ux |= (ux >> 4);
+  ux |= (ux >> 8);
+  ux |= (ux >> 16);
+  ux |= (ux >> 32);
+  return static_cast<int64_t>(ux - (ux >> 1));
+}
+}  // namespace details
+
 /*
  * \brief Extracting simple operations from pooling.
  * Both MaxPool and AvgPool need "initial", "compute" and "finalize"
@@ -211,17 +225,6 @@ class Pool2dDirectCUDAFunctor {
 template <typename Context, typename PoolProcess, typename T>
 class Pool2dFunctor {
  public:
-  void operator()(const Context& context,
-                  const DenseTensor& input,
-                  const std::vector<int>& ksize,
-                  const std::vector<int>& strides,
-                  const std::vector<int>& paddings,
-                  bool exclusive,
-                  bool adaptive,
-                  DenseTensor* output,
-                  PoolProcess pool_compute);
-
-  // overload operator() to support argument data_format
   void operator()(const Context& context,
                   const DenseTensor& input,
                   const std::vector<int>& ksize,
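GetInt64LastPow2 is the classic bit-smearing trick: the shift-OR chain copies the highest set bit of x into every lower position, so ux - (ux >> 1) leaves exactly that bit, i.e. the largest power of two <= x. It replaces the 32-bit GetLastPow2 in the launch configuration above, where output_height * output_width can exceed INT_MAX. A quick standalone check (helper copied here for illustration):

#include <cassert>
#include <cstdint>

static inline int64_t GetInt64LastPow2(int64_t x) {
  if (x <= 0) return 0;
  uint64_t ux = x;
  ux |= (ux >> 1);   // after these ORs, every bit below the
  ux |= (ux >> 2);   // highest set bit of x is also set...
  ux |= (ux >> 4);
  ux |= (ux >> 8);
  ux |= (ux >> 16);
  ux |= (ux >> 32);
  return static_cast<int64_t>(ux - (ux >> 1));  // ...so this isolates it
}

int main() {
  assert(GetInt64LastPow2(1) == 1);
  assert(GetInt64LastPow2(7) == 4);
  assert(GetInt64LastPow2(1000) == 512);  // e.g. thread_num for a 40x25 output
  assert(GetInt64LastPow2(int64_t{1} << 40) == (int64_t{1} << 40));
  return 0;
}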
