@@ -100,6 +100,25 @@ class MLUPoolOpKernel : public framework::OpKernel<T> {
100100 cnnlPoolingMode_t pool_mode =
101101 ToCnnlPoolingMode (pooling_type, exclusive, adaptive);
102102
103+ // transpose NCHW to NHWC since cnnl pool2d has worse performance in that
104+ // layout.
105+ framework::Tensor trans_in_x;
106+ framework::Tensor trans_out;
107+ if (channel_last) {
108+ trans_in_x = *in_x;
109+ trans_out = *out;
110+ } else {
111+ std::vector<int > perm{0 , 2 , 3 , 1 };
112+ TransposeFromMLUTensor<T>(
113+ ctx, perm, in_x, &trans_in_x, true /* need_reshape_or_alloc*/ );
114+ trans_out = ctx.AllocateTmpTensor <T, MLUDeviceContext>(
115+ {out_dims[0 ], out_dims[2 ], out_dims[3 ], out_dims[1 ]}, dev_ctx);
116+ }
117+ MLUCnnlTensorDesc trans_in_x_desc (
118+ trans_in_x, CNNL_LAYOUT_NHWC, ToCnnlDataType<T>());
119+ MLUCnnlTensorDesc trans_out_desc (
120+ trans_out, CNNL_LAYOUT_NHWC, ToCnnlDataType<T>());
121+
103122 if (!adaptive) {
104123 MLUCnnlPoolingDesc pool_desc (pool_mode,
105124 CNNL_NOT_PROPAGATE_NAN,
@@ -128,8 +147,8 @@ class MLUPoolOpKernel : public framework::OpKernel<T> {
128147 {static_cast <int64_t >(extra_input_size)}, cpu_ctx);
129148 cnnlInitPoolingExtraInput (handle,
130149 pool_desc.get (),
131- in_x_desc .get (),
132- out_desc .get (),
150+ trans_in_x_desc .get (),
151+ trans_out_desc .get (),
133152 GetBasePtr (&extra_host_tensor));
134153 framework::Tensor extra_device_tensor =
135154 ctx.AllocateTmpTensor <int8_t , MLUDeviceContext>(
@@ -151,44 +170,27 @@ class MLUPoolOpKernel : public framework::OpKernel<T> {
151170 out_w,
152171 pool_desc.get (),
153172 nullptr /* alpha*/ ,
154- in_x_desc .get (),
155- GetBasePtr (in_x ),
173+ trans_in_x_desc .get (),
174+ GetBasePtr (&trans_in_x ),
156175 nullptr /* beta*/ ,
157176 GetBasePtr (&extra_device_tensor) /* params_shape_ptr*/ ,
158- out_desc .get (),
159- GetBasePtr (out ));
177+ trans_out_desc .get (),
178+ GetBasePtr (&trans_out ));
160179 } else {
161180 MLUCnnl::PoolingForward (ctx,
162181 pool_mode,
163182 out_h,
164183 out_w,
165184 pool_desc.get (),
166185 nullptr /* alpha*/ ,
167- in_x_desc .get (),
168- GetBasePtr (in_x ),
186+ trans_in_x_desc .get (),
187+ GetBasePtr (&trans_in_x ),
169188 nullptr /* beta*/ ,
170189 nullptr /* params_shape_ptr*/ ,
171- out_desc .get (),
172- GetBasePtr (out ));
190+ trans_out_desc .get (),
191+ GetBasePtr (&trans_out ));
173192 }
174193 } else {
175- // cnnl Adaptive pooling only support NHWC layout
176- framework::Tensor trans_in_x;
177- framework::Tensor trans_out;
178- if (channel_last) {
179- trans_in_x = *in_x;
180- trans_out = *out;
181- } else {
182- std::vector<int > perm{0 , 2 , 3 , 1 };
183- TransposeFromMLUTensor<T>(
184- ctx, perm, in_x, &trans_in_x, true /* need_reshape_or_alloc*/ );
185- trans_out = ctx.AllocateTmpTensor <T, MLUDeviceContext>(
186- {out_dims[0 ], out_dims[2 ], out_dims[3 ], out_dims[1 ]}, dev_ctx);
187- }
188- MLUCnnlTensorDesc trans_in_x_desc (
189- trans_in_x, CNNL_LAYOUT_NHWC, ToCnnlDataType<T>());
190- MLUCnnlTensorDesc trans_out_desc (
191- trans_out, CNNL_LAYOUT_NHWC, ToCnnlDataType<T>());
192194 MLUCnnl::AdaptivePoolingForward (ctx,
193195 pool_mode,
194196 trans_in_x_desc.get (),
@@ -197,11 +199,11 @@ class MLUPoolOpKernel : public framework::OpKernel<T> {
197199 GetBasePtr (&trans_out),
198200 nullptr ,
199201 nullptr );
200- if (!channel_last) {
201- std::vector< int > perm{ 0 , 3 , 1 , 2 };
202- TransposeFromMLUTensor<T>(
203- ctx, perm, &trans_out, out, false /* need_reshape_or_alloc */ );
204- }
202+ }
203+ if (!channel_last) {
204+ std::vector< int > perm{ 0 , 3 , 1 , 2 };
205+ TransposeFromMLUTensor<T>(
206+ ctx, perm, &trans_out, out, false /* need_reshape_or_alloc */ );
205207 }
206208 }
207209};
0 commit comments