Skip to content

Commit 0f1d3af

Browse files
author
chengduo
authored
Merge pull request #4461 from chengduoZH/Add_maxpool_withIdx_only
Add max pool op (with index)
2 parents 8e2cc75 + 36da825 commit 0f1d3af

File tree

8 files changed

+1369
-14
lines changed

8 files changed

+1369
-14
lines changed

paddle/operators/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,20 @@ function(op_library TARGET)
5555
set(pybind_flag 1)
5656
endif()
5757

58+
# pool_op contains several operators
5859
if ("${TARGET}" STREQUAL "pool_op")
5960
set(pybind_flag 1)
6061
# It's enough to just add one operator to pybind
6162
file(APPEND ${pybind_file} "USE_OP(pool2d);\n")
6263
endif()
6364

65+
# pool_with_index_op contains several operators
66+
if ("${TARGET}" STREQUAL "pool_with_index_op")
67+
set(pybind_flag 1)
68+
# It's enough to just add one operator to pybind
69+
file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n")
70+
endif()
71+
6472
# activation_op contains several operators
6573
if ("${TARGET}" STREQUAL "activation_op")
6674
set(pybind_flag 1)

paddle/operators/math/pooling.cc

Lines changed: 279 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ namespace paddle {
1818
namespace operators {
1919
namespace math {
2020

21+
/*
22+
* All tensors are in NCHW format.
23+
* Ksize, strides, paddings are two elements. These two elements represent
24+
* height and width, respectively.
25+
*/
2126
template <typename PoolProcess, typename T>
2227
class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> {
2328
public:
@@ -73,6 +78,11 @@ class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> {
7378
}
7479
};
7580

81+
/*
82+
* All tensors are in NCHW format.
83+
* Ksize, strides, paddings are two elements. These two elements represent height
84+
* and width, respectively.
85+
*/
7686
template <typename PoolProcess, class T>
7787
class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> {
7888
public:
@@ -135,6 +145,11 @@ class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> {
135145
}
136146
};
137147

148+
/*
149+
* All tensors are in NCHW format.
150+
* Ksize, strides, paddings are two elements. These two elements represent
151+
* height and width, respectively.
152+
*/
138153
template <class T>
139154
class MaxPool2dGradFunctor<platform::CPUPlace, T> {
140155
public:
@@ -197,7 +212,7 @@ class MaxPool2dGradFunctor<platform::CPUPlace, T> {
197212
};
198213

199214
template class MaxPool2dGradFunctor<platform::CPUPlace, float>;
200-
// template class MaxPool2dGradFunctor<platform::CPUPlace, double>;
215+
template class MaxPool2dGradFunctor<platform::CPUPlace, double>;
201216

202217
template class Pool2dFunctor<platform::CPUPlace,
203218
paddle::operators::math::MaxPool<float>, float>;
@@ -216,6 +231,11 @@ template class Pool2dGradFunctor<
216231
template class Pool2dGradFunctor<
217232
platform::CPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>;
218233

234+
/*
235+
* All tensors are in NCDHW format.
236+
* Ksize, strides, paddings are three elements. These three elements represent
237+
* depth, height and width, respectively.
238+
*/
219239
template <typename PoolProcess, class T>
220240
class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
221241
public:
@@ -286,6 +306,11 @@ class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
286306
}
287307
};
288308

309+
/*
310+
* All tensors are in NCDHW format.
311+
* Ksize, strides, paddings are three elements. These three elements represent
312+
* depth, height and width, respectively.
313+
*/
289314
template <typename PoolProcess, class T>
290315
class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> {
291316
public:
@@ -364,6 +389,11 @@ class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> {
364389
}
365390
};
366391

392+
/*
393+
* All tensors are in NCDHW format.
394+
* Ksize, strides, paddings are three elements. These three elements represent
395+
* depth, height and width, respectively.
396+
*/
367397
template <class T>
368398
class MaxPool3dGradFunctor<platform::CPUPlace, T> {
369399
public:
@@ -440,7 +470,7 @@ class MaxPool3dGradFunctor<platform::CPUPlace, T> {
440470
};
441471

442472
template class MaxPool3dGradFunctor<platform::CPUPlace, float>;
443-
// template class MaxPool3dGradFunctor<platform::CPUPlace, double>;
473+
template class MaxPool3dGradFunctor<platform::CPUPlace, double>;
444474

445475
template class Pool3dFunctor<platform::CPUPlace,
446476
paddle::operators::math::MaxPool<float>, float>;
@@ -458,6 +488,253 @@ template class Pool3dGradFunctor<
458488
platform::CPUPlace, paddle::operators::math::MaxPoolGrad<double>, double>;
459489
template class Pool3dGradFunctor<
460490
platform::CPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>;
491+
492+
/*
493+
* All tensors are in NCHW format.
494+
* Ksize, strides, paddings are two elements. These two elements represent
495+
* height and width, respectively.
496+
*/
497+
template <typename T>
498+
class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
499+
public:
500+
void operator()(const platform::DeviceContext& context,
501+
const framework::Tensor& input, framework::Tensor& output,
502+
framework::Tensor& mask, std::vector<int>& ksize,
503+
std::vector<int>& strides, std::vector<int>& paddings) {
504+
const int batch_size = input.dims()[0];
505+
const int input_height = input.dims()[2];
506+
const int input_width = input.dims()[3];
507+
const int output_channels = output.dims()[1];
508+
const int output_height = output.dims()[2];
509+
const int output_width = output.dims()[3];
510+
const int ksize_height = ksize[0];
511+
const int ksize_width = ksize[1];
512+
const int stride_height = strides[0];
513+
const int stride_width = strides[1];
514+
const int padding_height = paddings[0];
515+
const int padding_width = paddings[1];
516+
const int input_stride = input_height * input_width;
517+
const int output_stride = output_height * output_width;
518+
519+
const T* input_data = input.data<T>();
520+
T* output_data = output.mutable_data<T>(context.GetPlace());
521+
T* mask_data = mask.mutable_data<T>(context.GetPlace());
522+
523+
for (int i = 0; i < batch_size; i++) {
524+
for (int c = 0; c < output_channels; ++c) {
525+
for (int ph = 0; ph < output_height; ++ph) {
526+
int hstart = ph * stride_height - padding_height;
527+
int hend = std::min(hstart + ksize_height, input_height);
528+
hstart = std::max(hstart, 0);
529+
for (int pw = 0; pw < output_width; ++pw) {
530+
int wstart = pw * stride_width - padding_width;
531+
int wend = std::min(wstart + ksize_width, input_width);
532+
wstart = std::max(wstart, 0);
533+
534+
T ele = static_cast<T>(-FLT_MAX);
535+
int index = -1;
536+
for (int h = hstart; h < hend; ++h) {
537+
for (int w = wstart; w < wend; ++w) {
538+
if (ele < input_data[h * input_width + w]) {
539+
ele = input_data[h * input_width + w];
540+
index = h * input_width + w;
541+
}
542+
}
543+
}
544+
output_data[ph * output_width + pw] = ele;
545+
mask_data[ph * output_width + pw] = index;
546+
}
547+
}
548+
// offset
549+
input_data += input_stride;
550+
output_data += output_stride;
551+
mask_data += output_stride;
552+
}
553+
}
554+
}
555+
};
556+
557+
/*
558+
* All tensors are in NCHW format.
559+
* Ksize, strides, paddings are two elements. These two elements represent
560+
* height and width, respectively.
561+
*/
562+
template <typename T>
563+
class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
564+
public:
565+
void operator()(const platform::DeviceContext& context,
566+
framework::Tensor& input_grad,
567+
const framework::Tensor& output_grad,
568+
const framework::Tensor& mask, std::vector<int>& ksize,
569+
std::vector<int>& strides, std::vector<int>& paddings) {
570+
const int batch_size = input_grad.dims()[0];
571+
const int input_height = input_grad.dims()[2];
572+
const int input_width = input_grad.dims()[3];
573+
const int output_channels = output_grad.dims()[1];
574+
const int output_height = output_grad.dims()[2];
575+
const int output_width = output_grad.dims()[3];
576+
const int input_stride = input_height * input_width;
577+
const int output_stride = output_height * output_width;
578+
579+
const T* mask_data = mask.data<T>();
580+
const T* output_grad_data = output_grad.data<T>();
581+
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
582+
583+
for (int n = 0; n < batch_size; ++n) {
584+
for (int c = 0; c < output_channels; ++c) {
585+
for (int ph = 0; ph < output_height; ++ph) {
586+
for (int pw = 0; pw < output_width; ++pw) {
587+
const int output_idx = ph * output_width + pw;
588+
const int input_idx = static_cast<int>(mask_data[output_idx]);
589+
input_grad_data[input_idx] += output_grad_data[output_idx];
590+
}
591+
}
592+
// offset
593+
input_grad_data += input_stride;
594+
output_grad_data += output_stride;
595+
mask_data += output_stride;
596+
}
597+
}
598+
}
599+
};
600+
601+
template class MaxPool2dWithIndexFunctor<platform::CPUPlace, float>;
602+
template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, float>;
603+
template class MaxPool2dWithIndexFunctor<platform::CPUPlace, double>;
604+
template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, double>;
605+
606+
/*
607+
* All tensors are in NCDHW format.
608+
* Ksize, strides, paddings are three elements. These three elements represent
609+
* depth, height and width, respectively.
610+
*/
611+
template <typename T>
612+
class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
613+
public:
614+
void operator()(const platform::DeviceContext& context,
615+
const framework::Tensor& input, framework::Tensor& output,
616+
framework::Tensor& mask, std::vector<int>& ksize,
617+
std::vector<int>& strides, std::vector<int>& paddings) {
618+
const int batch_size = input.dims()[0];
619+
const int input_depth = input.dims()[2];
620+
const int input_height = input.dims()[3];
621+
const int input_width = input.dims()[4];
622+
const int output_channels = output.dims()[1];
623+
const int output_depth = output.dims()[2];
624+
const int output_height = output.dims()[3];
625+
const int output_width = output.dims()[4];
626+
const int ksize_depth = ksize[0];
627+
const int ksize_height = ksize[1];
628+
const int ksize_width = ksize[2];
629+
const int stride_depth = strides[0];
630+
const int stride_height = strides[1];
631+
const int stride_width = strides[2];
632+
const int padding_depth = paddings[0];
633+
const int padding_height = paddings[1];
634+
const int padding_width = paddings[2];
635+
const int input_stride = input_depth * input_height * input_width;
636+
const int output_stride = output_depth * output_height * output_width;
637+
638+
const T* input_data = input.data<T>();
639+
T* output_data = output.mutable_data<T>(context.GetPlace());
640+
T* mask_data = mask.mutable_data<T>(context.GetPlace());
641+
642+
for (int i = 0; i < batch_size; i++) {
643+
for (int c = 0; c < output_channels; ++c) {
644+
for (int pd = 0; pd < output_depth; ++pd) {
645+
int dstart = pd * stride_depth - padding_depth;
646+
int dend = std::min(dstart + ksize_depth, input_depth);
647+
dstart = std::max(dstart, 0);
648+
for (int ph = 0; ph < output_height; ++ph) {
649+
int hstart = ph * stride_height - padding_height;
650+
int hend = std::min(hstart + ksize_height, input_height);
651+
hstart = std::max(hstart, 0);
652+
for (int pw = 0; pw < output_width; ++pw) {
653+
int wstart = pw * stride_width - padding_width;
654+
int wend = std::min(wstart + ksize_width, input_width);
655+
wstart = std::max(wstart, 0);
656+
657+
int output_idx = (pd * output_height + ph) * output_width + pw;
658+
T ele = static_cast<T>(-FLT_MAX);
659+
int index = -1;
660+
for (int d = dstart; d < dend; ++d) {
661+
for (int h = hstart; h < hend; ++h) {
662+
for (int w = wstart; w < wend; ++w) {
663+
int input_idx = (d * input_height + h) * input_width + w;
664+
if (ele < input_data[input_idx]) {
665+
index = input_idx;
666+
ele = input_data[input_idx];
667+
}
668+
}
669+
}
670+
}
671+
output_data[output_idx] = ele;
672+
mask_data[output_idx] = index;
673+
}
674+
}
675+
}
676+
// offset
677+
input_data += input_stride;
678+
output_data += output_stride;
679+
mask_data += output_stride;
680+
}
681+
}
682+
}
683+
};
684+
685+
/*
686+
* All tensors are in NCDHW format.
687+
* Ksize, strides, paddings are three elements. These three elements represent
688+
* depth, height and width, respectively.
689+
*/
690+
template <typename T>
691+
class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
692+
public:
693+
void operator()(const platform::DeviceContext& context,
694+
framework::Tensor& input_grad,
695+
const framework::Tensor& output_grad,
696+
const framework::Tensor& mask, std::vector<int>& ksize,
697+
std::vector<int>& strides, std::vector<int>& paddings) {
698+
const int batch_size = input_grad.dims()[0];
699+
const int input_depth = input_grad.dims()[2];
700+
const int input_height = input_grad.dims()[3];
701+
const int input_width = input_grad.dims()[4];
702+
const int output_channels = output_grad.dims()[1];
703+
const int output_depth = output_grad.dims()[2];
704+
const int output_height = output_grad.dims()[3];
705+
const int output_width = output_grad.dims()[4];
706+
const int input_stride = input_depth * input_height * input_width;
707+
const int output_stride = output_depth * output_height * output_width;
708+
709+
const T* mask_data = mask.data<T>();
710+
const T* output_grad_data = output_grad.data<T>();
711+
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
712+
713+
for (int n = 0; n < batch_size; ++n) {
714+
for (int c = 0; c < output_channels; ++c) {
715+
for (int pd = 0; pd < output_depth; ++pd) {
716+
for (int ph = 0; ph < output_height; ++ph) {
717+
for (int pw = 0; pw < output_width; ++pw) {
718+
const int output_idx =
719+
(pd * output_height + ph) * output_width + pw;
720+
const int input_idx = static_cast<int>(mask_data[output_idx]);
721+
input_grad_data[input_idx] += output_grad_data[output_idx];
722+
}
723+
}
724+
}
725+
// offset
726+
input_grad_data += input_stride;
727+
output_grad_data += output_stride;
728+
mask_data += output_stride;
729+
}
730+
}
731+
}
732+
};
733+
734+
template class MaxPool3dWithIndexFunctor<platform::CPUPlace, float>;
735+
template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, float>;
736+
template class MaxPool3dWithIndexFunctor<platform::CPUPlace, double>;
737+
template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, double>;
461738
} // namespace math
462739
} // namespace operators
463740
} // namespace paddle

0 commit comments

Comments
 (0)