Skip to content

Commit db1f3d4

Browse files
ShawnNew authored and Aurelius84 committed
[MLU] transpose avg_pool2d to NHWC for better performance. (PaddlePaddle#44475)
1 parent cad9755 commit db1f3d4

File tree

1 file changed

+34
-32
lines changed

1 file changed

+34
-32
lines changed

paddle/fluid/operators/pool_op_mlu.cc

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,25 @@ class MLUPoolOpKernel : public framework::OpKernel<T> {
100100
cnnlPoolingMode_t pool_mode =
101101
ToCnnlPoolingMode(pooling_type, exclusive, adaptive);
102102

103+
// transpose NCHW to NHWC since cnnl pool2d has worse performance in that
104+
// layout.
105+
framework::Tensor trans_in_x;
106+
framework::Tensor trans_out;
107+
if (channel_last) {
108+
trans_in_x = *in_x;
109+
trans_out = *out;
110+
} else {
111+
std::vector<int> perm{0, 2, 3, 1};
112+
TransposeFromMLUTensor<T>(
113+
ctx, perm, in_x, &trans_in_x, true /*need_reshape_or_alloc*/);
114+
trans_out = ctx.AllocateTmpTensor<T, MLUDeviceContext>(
115+
{out_dims[0], out_dims[2], out_dims[3], out_dims[1]}, dev_ctx);
116+
}
117+
MLUCnnlTensorDesc trans_in_x_desc(
118+
trans_in_x, CNNL_LAYOUT_NHWC, ToCnnlDataType<T>());
119+
MLUCnnlTensorDesc trans_out_desc(
120+
trans_out, CNNL_LAYOUT_NHWC, ToCnnlDataType<T>());
121+
103122
if (!adaptive) {
104123
MLUCnnlPoolingDesc pool_desc(pool_mode,
105124
CNNL_NOT_PROPAGATE_NAN,
@@ -128,8 +147,8 @@ class MLUPoolOpKernel : public framework::OpKernel<T> {
128147
{static_cast<int64_t>(extra_input_size)}, cpu_ctx);
129148
cnnlInitPoolingExtraInput(handle,
130149
pool_desc.get(),
131-
in_x_desc.get(),
132-
out_desc.get(),
150+
trans_in_x_desc.get(),
151+
trans_out_desc.get(),
133152
GetBasePtr(&extra_host_tensor));
134153
framework::Tensor extra_device_tensor =
135154
ctx.AllocateTmpTensor<int8_t, MLUDeviceContext>(
@@ -151,44 +170,27 @@ class MLUPoolOpKernel : public framework::OpKernel<T> {
151170
out_w,
152171
pool_desc.get(),
153172
nullptr /*alpha*/,
154-
in_x_desc.get(),
155-
GetBasePtr(in_x),
173+
trans_in_x_desc.get(),
174+
GetBasePtr(&trans_in_x),
156175
nullptr /*beta*/,
157176
GetBasePtr(&extra_device_tensor) /*params_shape_ptr*/,
158-
out_desc.get(),
159-
GetBasePtr(out));
177+
trans_out_desc.get(),
178+
GetBasePtr(&trans_out));
160179
} else {
161180
MLUCnnl::PoolingForward(ctx,
162181
pool_mode,
163182
out_h,
164183
out_w,
165184
pool_desc.get(),
166185
nullptr /*alpha*/,
167-
in_x_desc.get(),
168-
GetBasePtr(in_x),
186+
trans_in_x_desc.get(),
187+
GetBasePtr(&trans_in_x),
169188
nullptr /*beta*/,
170189
nullptr /*params_shape_ptr*/,
171-
out_desc.get(),
172-
GetBasePtr(out));
190+
trans_out_desc.get(),
191+
GetBasePtr(&trans_out));
173192
}
174193
} else {
175-
// cnnl Adaptive pooling only support NHWC layout
176-
framework::Tensor trans_in_x;
177-
framework::Tensor trans_out;
178-
if (channel_last) {
179-
trans_in_x = *in_x;
180-
trans_out = *out;
181-
} else {
182-
std::vector<int> perm{0, 2, 3, 1};
183-
TransposeFromMLUTensor<T>(
184-
ctx, perm, in_x, &trans_in_x, true /*need_reshape_or_alloc*/);
185-
trans_out = ctx.AllocateTmpTensor<T, MLUDeviceContext>(
186-
{out_dims[0], out_dims[2], out_dims[3], out_dims[1]}, dev_ctx);
187-
}
188-
MLUCnnlTensorDesc trans_in_x_desc(
189-
trans_in_x, CNNL_LAYOUT_NHWC, ToCnnlDataType<T>());
190-
MLUCnnlTensorDesc trans_out_desc(
191-
trans_out, CNNL_LAYOUT_NHWC, ToCnnlDataType<T>());
192194
MLUCnnl::AdaptivePoolingForward(ctx,
193195
pool_mode,
194196
trans_in_x_desc.get(),
@@ -197,11 +199,11 @@ class MLUPoolOpKernel : public framework::OpKernel<T> {
197199
GetBasePtr(&trans_out),
198200
nullptr,
199201
nullptr);
200-
if (!channel_last) {
201-
std::vector<int> perm{0, 3, 1, 2};
202-
TransposeFromMLUTensor<T>(
203-
ctx, perm, &trans_out, out, false /*need_reshape_or_alloc*/);
204-
}
202+
}
203+
if (!channel_last) {
204+
std::vector<int> perm{0, 3, 1, 2};
205+
TransposeFromMLUTensor<T>(
206+
ctx, perm, &trans_out, out, false /*need_reshape_or_alloc*/);
205207
}
206208
}
207209
};

0 commit comments

Comments (0)