@@ -18,6 +18,11 @@ namespace paddle {
 namespace operators {
 namespace math {
 
+/*
+ * All tensors are in NCHW format.
+ * Ksize, strides, paddings are two elements. These two elements represent
+ * height and width, respectively.
+ */
 template <typename PoolProcess, typename T>
 class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> {
  public:
@@ -73,6 +78,11 @@ class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> {
   }
 };
 
+/*
+ * All tensors are in NCHW format.
+ * Ksize, strides, paddings are two elements. These two elements represent
+ * height and width, respectively.
+ */
 template <typename PoolProcess, class T>
 class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> {
  public:
@@ -135,6 +145,11 @@ class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> {
   }
 };
 
+/*
+ * All tensors are in NCHW format.
+ * Ksize, strides, paddings are two elements. These two elements represent
+ * height and width, respectively.
+ */
 template <class T>
 class MaxPool2dGradFunctor<platform::CPUPlace, T> {
  public:
@@ -197,7 +212,7 @@ class MaxPool2dGradFunctor<platform::CPUPlace, T> {
 };
 
 template class MaxPool2dGradFunctor<platform::CPUPlace, float>;
-// template class MaxPool2dGradFunctor<platform::CPUPlace, double>;
+template class MaxPool2dGradFunctor<platform::CPUPlace, double>;
 
 template class Pool2dFunctor<platform::CPUPlace,
                              paddle::operators::math::MaxPool<float>, float>;
@@ -216,6 +231,11 @@ template class Pool2dGradFunctor<
 template class Pool2dGradFunctor<
     platform::CPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>;
 
+/*
+ * All tensors are in NCDHW format.
+ * Ksize, strides, paddings are three elements. These three elements represent
+ * depth, height and width, respectively.
+ */
 template <typename PoolProcess, class T>
 class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
  public:
@@ -286,6 +306,11 @@ class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
   }
 };
 
+/*
+ * All tensors are in NCDHW format.
+ * Ksize, strides, paddings are three elements. These three elements represent
+ * depth, height and width, respectively.
+ */
 template <typename PoolProcess, class T>
 class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> {
  public:
@@ -364,6 +389,11 @@ class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> {
   }
 };
 
+/*
+ * All tensors are in NCDHW format.
+ * Ksize, strides, paddings are three elements. These three elements represent
+ * depth, height and width, respectively.
+ */
 template <class T>
 class MaxPool3dGradFunctor<platform::CPUPlace, T> {
  public:
@@ -440,7 +470,7 @@ class MaxPool3dGradFunctor<platform::CPUPlace, T> {
 };
 
 template class MaxPool3dGradFunctor<platform::CPUPlace, float>;
-// template class MaxPool3dGradFunctor<platform::CPUPlace, double>;
+template class MaxPool3dGradFunctor<platform::CPUPlace, double>;
 
 template class Pool3dFunctor<platform::CPUPlace,
                              paddle::operators::math::MaxPool<float>, float>;
@@ -458,6 +488,253 @@ template class Pool3dGradFunctor<
     platform::CPUPlace, paddle::operators::math::MaxPoolGrad<double>, double>;
 template class Pool3dGradFunctor<
     platform::CPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>;
+
+/*
+ * All tensors are in NCHW format.
+ * Ksize, strides, paddings are two elements. These two elements represent
+ * height and width, respectively.
+ */
+template <typename T>
+class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor& output,
+                  framework::Tensor& mask, std::vector<int>& ksize,
+                  std::vector<int>& strides, std::vector<int>& paddings) {
+    const int batch_size = input.dims()[0];
+    const int input_height = input.dims()[2];
+    const int input_width = input.dims()[3];
+    const int output_channels = output.dims()[1];
+    const int output_height = output.dims()[2];
+    const int output_width = output.dims()[3];
+    const int ksize_height = ksize[0];
+    const int ksize_width = ksize[1];
+    const int stride_height = strides[0];
+    const int stride_width = strides[1];
+    const int padding_height = paddings[0];
+    const int padding_width = paddings[1];
+    const int input_stride = input_height * input_width;
+    const int output_stride = output_height * output_width;
+
+    const T* input_data = input.data<T>();
+    T* output_data = output.mutable_data<T>(context.GetPlace());
+    T* mask_data = mask.mutable_data<T>(context.GetPlace());
+
+    for (int i = 0; i < batch_size; i++) {
+      for (int c = 0; c < output_channels; ++c) {
+        for (int ph = 0; ph < output_height; ++ph) {
+          int hstart = ph * stride_height - padding_height;
+          int hend = std::min(hstart + ksize_height, input_height);
+          hstart = std::max(hstart, 0);
+          for (int pw = 0; pw < output_width; ++pw) {
+            int wstart = pw * stride_width - padding_width;
+            int wend = std::min(wstart + ksize_width, input_width);
+            wstart = std::max(wstart, 0);
+
+            T ele = static_cast<T>(-FLT_MAX);
+            int index = -1;
+            for (int h = hstart; h < hend; ++h) {
+              for (int w = wstart; w < wend; ++w) {
+                if (ele < input_data[h * input_width + w]) {
+                  ele = input_data[h * input_width + w];
+                  index = h * input_width + w;
+                }
+              }
+            }
+            output_data[ph * output_width + pw] = ele;
+            mask_data[ph * output_width + pw] = index;
+          }
+        }
+        // offset
+        input_data += input_stride;
+        output_data += output_stride;
+        mask_data += output_stride;
+      }
+    }
+  }
+};
+
+/*
+ * All tensors are in NCHW format.
+ * Ksize, strides, paddings are two elements. These two elements represent
+ * height and width, respectively.
+ */
+template <typename T>
+class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  framework::Tensor& input_grad,
+                  const framework::Tensor& output_grad,
+                  const framework::Tensor& mask, std::vector<int>& ksize,
+                  std::vector<int>& strides, std::vector<int>& paddings) {
+    const int batch_size = input_grad.dims()[0];
+    const int input_height = input_grad.dims()[2];
+    const int input_width = input_grad.dims()[3];
+    const int output_channels = output_grad.dims()[1];
+    const int output_height = output_grad.dims()[2];
+    const int output_width = output_grad.dims()[3];
+    const int input_stride = input_height * input_width;
+    const int output_stride = output_height * output_width;
+
+    const T* mask_data = mask.data<T>();
+    const T* output_grad_data = output_grad.data<T>();
+    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
+
+    for (int n = 0; n < batch_size; ++n) {
+      for (int c = 0; c < output_channels; ++c) {
+        for (int ph = 0; ph < output_height; ++ph) {
+          for (int pw = 0; pw < output_width; ++pw) {
+            const int output_idx = ph * output_width + pw;
+            const int input_idx = static_cast<int>(mask_data[output_idx]);
+            input_grad_data[input_idx] += output_grad_data[output_idx];
+          }
+        }
+        // offset
+        input_grad_data += input_stride;
+        output_grad_data += output_stride;
+        mask_data += output_stride;
+      }
+    }
+  }
+};
+
+template class MaxPool2dWithIndexFunctor<platform::CPUPlace, float>;
+template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, float>;
+template class MaxPool2dWithIndexFunctor<platform::CPUPlace, double>;
+template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, double>;
+
+/*
+ * All tensors are in NCDHW format.
+ * Ksize, strides, paddings are three elements. These three elements represent
+ * depth, height and width, respectively.
+ */
+template <typename T>
+class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor& output,
+                  framework::Tensor& mask, std::vector<int>& ksize,
+                  std::vector<int>& strides, std::vector<int>& paddings) {
+    const int batch_size = input.dims()[0];
+    const int input_depth = input.dims()[2];
+    const int input_height = input.dims()[3];
+    const int input_width = input.dims()[4];
+    const int output_channels = output.dims()[1];
+    const int output_depth = output.dims()[2];
+    const int output_height = output.dims()[3];
+    const int output_width = output.dims()[4];
+    const int ksize_depth = ksize[0];
+    const int ksize_height = ksize[1];
+    const int ksize_width = ksize[2];
+    const int stride_depth = strides[0];
+    const int stride_height = strides[1];
+    const int stride_width = strides[2];
+    const int padding_depth = paddings[0];
+    const int padding_height = paddings[1];
+    const int padding_width = paddings[2];
+    const int input_stride = input_depth * input_height * input_width;
+    const int output_stride = output_depth * output_height * output_width;
+
+    const T* input_data = input.data<T>();
+    T* output_data = output.mutable_data<T>(context.GetPlace());
+    T* mask_data = mask.mutable_data<T>(context.GetPlace());
+
+    for (int i = 0; i < batch_size; i++) {
+      for (int c = 0; c < output_channels; ++c) {
+        for (int pd = 0; pd < output_depth; ++pd) {
+          int dstart = pd * stride_depth - padding_depth;
+          int dend = std::min(dstart + ksize_depth, input_depth);
+          dstart = std::max(dstart, 0);
+          for (int ph = 0; ph < output_height; ++ph) {
+            int hstart = ph * stride_height - padding_height;
+            int hend = std::min(hstart + ksize_height, input_height);
+            hstart = std::max(hstart, 0);
+            for (int pw = 0; pw < output_width; ++pw) {
+              int wstart = pw * stride_width - padding_width;
+              int wend = std::min(wstart + ksize_width, input_width);
+              wstart = std::max(wstart, 0);
+
+              int output_idx = (pd * output_height + ph) * output_width + pw;
+              T ele = static_cast<T>(-FLT_MAX);
+              int index = -1;
+              for (int d = dstart; d < dend; ++d) {
+                for (int h = hstart; h < hend; ++h) {
+                  for (int w = wstart; w < wend; ++w) {
+                    int input_idx = (d * input_height + h) * input_width + w;
+                    if (ele < input_data[input_idx]) {
+                      index = input_idx;
+                      ele = input_data[input_idx];
+                    }
+                  }
+                }
+              }
+              output_data[output_idx] = ele;
+              mask_data[output_idx] = index;
+            }
+          }
+        }
+        // offset
+        input_data += input_stride;
+        output_data += output_stride;
+        mask_data += output_stride;
+      }
+    }
+  }
+};
+
+/*
+ * All tensors are in NCDHW format.
+ * Ksize, strides, paddings are three elements. These three elements represent
+ * depth, height and width, respectively.
+ */
+template <typename T>
+class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  framework::Tensor& input_grad,
+                  const framework::Tensor& output_grad,
+                  const framework::Tensor& mask, std::vector<int>& ksize,
+                  std::vector<int>& strides, std::vector<int>& paddings) {
+    const int batch_size = input_grad.dims()[0];
+    const int input_depth = input_grad.dims()[2];
+    const int input_height = input_grad.dims()[3];
+    const int input_width = input_grad.dims()[4];
+    const int output_channels = output_grad.dims()[1];
+    const int output_depth = output_grad.dims()[2];
+    const int output_height = output_grad.dims()[3];
+    const int output_width = output_grad.dims()[4];
+    const int input_stride = input_depth * input_height * input_width;
+    const int output_stride = output_depth * output_height * output_width;
+
+    const T* mask_data = mask.data<T>();
+    const T* output_grad_data = output_grad.data<T>();
+    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
+
+    for (int n = 0; n < batch_size; ++n) {
+      for (int c = 0; c < output_channels; ++c) {
+        for (int pd = 0; pd < output_depth; ++pd) {
+          for (int ph = 0; ph < output_height; ++ph) {
+            for (int pw = 0; pw < output_width; ++pw) {
+              const int output_idx =
+                  (pd * output_height + ph) * output_width + pw;
+              const int input_idx = static_cast<int>(mask_data[output_idx]);
+              input_grad_data[input_idx] += output_grad_data[output_idx];
+            }
+          }
+        }
+        // offset
+        input_grad_data += input_stride;
+        output_grad_data += output_stride;
+        mask_data += output_stride;
+      }
+    }
+  }
+};
+
+template class MaxPool3dWithIndexFunctor<platform::CPUPlace, float>;
+template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, float>;
+template class MaxPool3dWithIndexFunctor<platform::CPUPlace, double>;
+template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, double>;
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
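
For orientation, a minimal usage sketch of the new MaxPool2dWithIndexFunctor added above, not part of the diff. The call signature, NCHW layout, and the float instantiation follow the code in this change; the helper name RunMaxPool2dWithIndex, the header paths, the concrete ksize/strides/paddings values, and the assumption that the caller has already resized output and mask to the pooled shape are illustrative only.

// Hypothetical usage sketch (assumed header paths and helper name).
#include <vector>

#include "paddle/framework/tensor.h"
#include "paddle/operators/math/pooling.h"
#include "paddle/platform/device_context.h"

// Runs 2-D max pooling with an argmax mask on CPU. input, *output, and *mask
// are NCHW tensors; *output and *mask must already have the pooled shape.
void RunMaxPool2dWithIndex(const paddle::platform::DeviceContext& ctx,
                           const paddle::framework::Tensor& input,
                           paddle::framework::Tensor* output,
                           paddle::framework::Tensor* mask) {
  std::vector<int> ksize = {2, 2};     // window: height, width
  std::vector<int> strides = {2, 2};   // step: height, width
  std::vector<int> paddings = {0, 0};  // zero padding: height, width

  paddle::operators::math::MaxPool2dWithIndexFunctor<
      paddle::platform::CPUPlace, float>
      pool_forward;
  // Writes pooled values into *output and flat argmax offsets into *mask.
  pool_forward(ctx, input, *output, *mask, ksize, strides, paddings);
}

The mask stores the within-channel flat offset of each maximum, which is exactly what the corresponding MaxPool2dWithIndexGradFunctor reads back (via static_cast<int>) to scatter output gradients into input_grad.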