@@ -498,8 +498,8 @@ template class Pool3dGradFunctor<
498498 * Ksize, strides, paddings are two elements. These two elements represent
499499 * height and width, respectively.
500500 */
501- template <typename T >
502- class MaxPool2dWithIndexFunctor <platform::CPUPlace, T > {
501+ template <typename T1, typename T2 >
502+ class MaxPool2dWithIndexFunctor <platform::CPUPlace, T1, T2 > {
503503 public:
504504 void operator ()(const platform::DeviceContext& context,
505505 const framework::Tensor& input, std::vector<int >& ksize,
@@ -520,9 +520,9 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
520520 const int input_stride = input_height * input_width;
521521 const int output_stride = output_height * output_width;
522522
523- const T * input_data = input.data <T >();
524- T * output_data = output->mutable_data <T >(context.GetPlace ());
525- T * mask_data = mask->mutable_data <T >(context.GetPlace ());
523+ const T1 * input_data = input.data <T1 >();
524+ T1 * output_data = output->mutable_data <T1 >(context.GetPlace ());
525+ T2 * mask_data = mask->mutable_data <T2 >(context.GetPlace ());
526526
527527 for (int i = 0 ; i < batch_size; i++) {
528528 for (int c = 0 ; c < output_channels; ++c) {
@@ -535,7 +535,7 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
535535 int wend = std::min (wstart + ksize_width, input_width);
536536 wstart = std::max (wstart, 0 );
537537
538- T ele = static_cast <T >(-FLT_MAX);
538+ T1 ele = static_cast <T1 >(-FLT_MAX);
539539 int index = -1 ;
540540 for (int h = hstart; h < hend; ++h) {
541541 for (int w = wstart; w < wend; ++w) {
@@ -563,8 +563,8 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
563563 * Ksize, strides, paddings are two elements. These two elements represent
564564 * height and width, respectively.
565565 */
566- template <typename T >
567- class MaxPool2dWithIndexGradFunctor <platform::CPUPlace, T > {
566+ template <typename T1, typename T2 >
567+ class MaxPool2dWithIndexGradFunctor <platform::CPUPlace, T1, T2 > {
568568 public:
569569 void operator ()(const platform::DeviceContext& context,
570570 const framework::Tensor& output_grad,
@@ -580,9 +580,9 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
580580 const int input_stride = input_height * input_width;
581581 const int output_stride = output_height * output_width;
582582
583- const T * mask_data = mask.data <T >();
584- const T * output_grad_data = output_grad.data <T >();
585- T * input_grad_data = input_grad->mutable_data <T >(context.GetPlace ());
583+ const T2 * mask_data = mask.data <T2 >();
584+ const T1 * output_grad_data = output_grad.data <T1 >();
585+ T1 * input_grad_data = input_grad->mutable_data <T1 >(context.GetPlace ());
586586
587587 for (int n = 0 ; n < batch_size; ++n) {
588588 for (int c = 0 ; c < output_channels; ++c) {
@@ -602,18 +602,18 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
602602 }
603603};
604604
605- template class MaxPool2dWithIndexFunctor <platform::CPUPlace, float >;
606- template class MaxPool2dWithIndexGradFunctor <platform::CPUPlace, float >;
607- template class MaxPool2dWithIndexFunctor <platform::CPUPlace, double >;
608- template class MaxPool2dWithIndexGradFunctor <platform::CPUPlace, double >;
605+ template class MaxPool2dWithIndexFunctor <platform::CPUPlace, float , int >;
606+ template class MaxPool2dWithIndexGradFunctor <platform::CPUPlace, float , int >;
607+ template class MaxPool2dWithIndexFunctor <platform::CPUPlace, double , int >;
608+ template class MaxPool2dWithIndexGradFunctor <platform::CPUPlace, double , int >;
609609
610610/*
611611 * All tensors are in NCDHW format.
612612 * Ksize, strides, paddings are three elements. These three elements represent
613613 * depth, height and width, respectively.
614614 */
615- template <typename T >
616- class MaxPool3dWithIndexFunctor <platform::CPUPlace, T > {
615+ template <typename T1, typename T2 >
616+ class MaxPool3dWithIndexFunctor <platform::CPUPlace, T1, T2 > {
617617 public:
618618 void operator ()(const platform::DeviceContext& context,
619619 const framework::Tensor& input, std::vector<int >& ksize,
@@ -639,9 +639,9 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
639639 const int input_stride = input_depth * input_height * input_width;
640640 const int output_stride = output_depth * output_height * output_width;
641641
642- const T * input_data = input.data <T >();
643- T * output_data = output->mutable_data <T >(context.GetPlace ());
644- T * mask_data = mask->mutable_data <T >(context.GetPlace ());
642+ const T1 * input_data = input.data <T1 >();
643+ T1 * output_data = output->mutable_data <T1 >(context.GetPlace ());
644+ T2 * mask_data = mask->mutable_data <T2 >(context.GetPlace ());
645645
646646 for (int i = 0 ; i < batch_size; i++) {
647647 for (int c = 0 ; c < output_channels; ++c) {
@@ -659,7 +659,7 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
659659 wstart = std::max (wstart, 0 );
660660
661661 int output_idx = (pd * output_height + ph) * output_width + pw;
662- T ele = static_cast <T >(-FLT_MAX);
662+ T1 ele = static_cast <T1 >(-FLT_MAX);
663663 int index = -1 ;
664664 for (int d = dstart; d < dend; ++d) {
665665 for (int h = hstart; h < hend; ++h) {
@@ -691,8 +691,8 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
691691 * Ksize, strides, paddings are three elements. These three elements represent
692692 * depth, height and width, respectively.
693693 */
694- template <typename T >
695- class MaxPool3dWithIndexGradFunctor <platform::CPUPlace, T > {
694+ template <typename T1, typename T2 >
695+ class MaxPool3dWithIndexGradFunctor <platform::CPUPlace, T1, T2 > {
696696 public:
697697 void operator ()(const platform::DeviceContext& context,
698698 const framework::Tensor& output_grad,
@@ -710,9 +710,9 @@ class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
710710 const int input_stride = input_depth * input_height * input_width;
711711 const int output_stride = output_depth * output_height * output_width;
712712
713- const T * mask_data = mask.data <T >();
714- const T * output_grad_data = output_grad.data <T >();
715- T * input_grad_data = input_grad->mutable_data <T >(context.GetPlace ());
713+ const T2 * mask_data = mask.data <T2 >();
714+ const T1 * output_grad_data = output_grad.data <T1 >();
715+ T1 * input_grad_data = input_grad->mutable_data <T1 >(context.GetPlace ());
716716
717717 for (int n = 0 ; n < batch_size; ++n) {
718718 for (int c = 0 ; c < output_channels; ++c) {
@@ -735,10 +735,10 @@ class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
735735 }
736736};
737737
738- template class MaxPool3dWithIndexFunctor <platform::CPUPlace, float >;
739- template class MaxPool3dWithIndexGradFunctor <platform::CPUPlace, float >;
740- template class MaxPool3dWithIndexFunctor <platform::CPUPlace, double >;
741- template class MaxPool3dWithIndexGradFunctor <platform::CPUPlace, double >;
738+ template class MaxPool3dWithIndexFunctor <platform::CPUPlace, float , int >;
739+ template class MaxPool3dWithIndexGradFunctor <platform::CPUPlace, float , int >;
740+ template class MaxPool3dWithIndexFunctor <platform::CPUPlace, double , int >;
741+ template class MaxPool3dWithIndexGradFunctor <platform::CPUPlace, double , int >;
742742} // namespace math
743743} // namespace operators
744744} // namespace paddle
0 commit comments