Commit 1eb96ee
Move conv-transpose OPs to phi (#40675)
* Move conv-transpose OPs to phi
* Fix CI errors
* Fix CI errors
1 parent b77e20a commit 1eb96ee

24 files changed: +3294 additions, −2275 deletions
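
The pattern used throughout this commit is the same for every operator: the hand-written InferShape method in fluid is deleted, and REGISTER_OPERATOR is instead wired to a phi InferMeta function through DECLARE_INFER_SHAPE_FUNCTOR / PD_INFER_META (see the conv_transpose_op.cc diff below). For orientation, a minimal sketch of the phi-side counterpart; the exact parameter list of phi::ConvTransposeInferMeta lives in paddle/phi/infermeta/binary.h and is abridged here, so treat the signature as illustrative rather than verbatim:

// Sketch only: shape inference now operates on phi::MetaTensor, so the
// same function serves static-graph InferShape and dygraph/kernel paths.
// Parameter list abridged -- see paddle/phi/infermeta/binary.h.
void ConvTransposeInferMeta(const MetaTensor& x,
                            const MetaTensor& filter,
                            const std::vector<int>& strides,
                            /* paddings, output_padding, output_size, groups,
                               dilations, padding_algorithm, data_format... */
                            MetaTensor* out,
                            MetaConfig config) {
  // Validates dims/attrs and computes the output shape once, then e.g.:
  // out->set_dims(phi::make_ddim(output_shape));
  // out->set_dtype(x.dtype());
}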

paddle/fluid/framework/ir/mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@
 
 USE_OP_ITSELF(batch_norm);
 USE_OP_DEVICE_KERNEL(batch_norm, MKLDNN);
-USE_OP(conv2d_transpose);
+USE_OP_ITSELF(conv2d_transpose);
 USE_OP_DEVICE_KERNEL(conv2d_transpose, MKLDNN);
 USE_OP_ITSELF(elementwise_add);
 USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
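
The USE_OP → USE_OP_ITSELF swap here and in the next file follows from the move: USE_OP declares the operator together with its fluid-registered kernels, while USE_OP_ITSELF declares only the operator, and the fluid kernel registrations no longer exist once the kernels live in phi. A paraphrase of the macro relationship (the exact expansion in paddle/fluid/framework/op_registry.h is an assumption from context, not quoted from this commit):

// Assumed relationship, paraphrased:
// USE_OP(op)        -> pulls in the operator AND its registered fluid kernels
// USE_OP_ITSELF(op) -> pulls in only the operator definition
#define USE_OP(op_type)   \
  USE_OP_ITSELF(op_type); \
  USE_OP_KERNEL(op_type)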

paddle/fluid/inference/tensorrt/convert/test_conv2d_op.cc

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
 
 USE_OP_ITSELF(conv2d);
-USE_OP(conv2d_transpose);
+USE_OP_ITSELF(conv2d_transpose);
 
 namespace paddle {
 namespace inference {

paddle/fluid/operators/conv_transpose_cudnn_op.cu

Lines changed: 0 additions & 1286 deletions
This file was deleted.

paddle/fluid/operators/conv_transpose_op.cc

Lines changed: 45 additions & 231 deletions
@@ -13,13 +13,17 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/conv_transpose_op.h"
-#include <memory>
+
 #include <string>
 #include <vector>
 #include "paddle/fluid/framework/data_layout.h"
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/platform/cudnn_workspace_helper.h"
-
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/backward.h"
+#include "paddle/phi/infermeta/binary.h"
 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #endif
@@ -29,165 +33,6 @@ namespace operators {
 
 using DataLayout = framework::DataLayout;
 
-void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
-  OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "ConvTranspose");
-  OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter", "ConvTranspose");
-  OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output", "ConvTranspose");
-
-  auto in_dims = ctx->GetInputDim("Input");
-  auto filter_dims = ctx->GetInputDim("Filter");
-  std::vector<int> output_size =
-      ctx->Attrs().Get<std::vector<int>>("output_size");
-  std::vector<int> output_padding =
-      ctx->Attrs().Get<std::vector<int>>("output_padding");
-  std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
-  std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
-  std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations");
-  int groups = ctx->Attrs().Get<int>("groups");
-  std::string padding_algorithm =
-      ctx->Attrs().Get<std::string>("padding_algorithm");
-  const std::string data_layout_str =
-      ctx->Attrs().Get<std::string>("data_format");
-  const DataLayout data_layout =
-      ctx->IsRunMKLDNNKernel() ? DataLayout::kNCHW
-                               : framework::StringToDataLayout(data_layout_str);
-
-  PADDLE_ENFORCE_EQ(in_dims.size() == 4 || in_dims.size() == 5, true,
-                    platform::errors::InvalidArgument(
-                        "Input of Op(conv_transpose) should be 4-D or "
-                        "5-D Tensor. But received: %u-D Tensor, "
-                        "the shape of input is [%s]",
-                        in_dims.size(), in_dims));
-  PADDLE_ENFORCE_EQ(
-      in_dims.size(), filter_dims.size(),
-      platform::errors::InvalidArgument(
-          "The input's dimension size and filter's dimension size of "
-          "Op (conv_transpose) should be equal. But received: the shape of "
-          "input is [%s], the dimension size of input is [%d], the shape "
-          "of filter is [%s], the dimension size of filter is [%d]. ",
-          in_dims, in_dims.size(), filter_dims, filter_dims.size()));
-
-  int stride_size = strides.size();
-  for (int i = 0; i < stride_size; ++i) {
-    PADDLE_ENFORCE_GT(
-        strides[i], 0,
-        platform::errors::InvalidArgument(
-            "The stride of Op(Conv) should be larget than 0, but received "
-            "stride is %d.",
-            strides[i]));
-  }
-
-  int in_sub_stride_size = in_dims.size() - stride_size;
-
-  PADDLE_ENFORCE_EQ(
-      in_dims.size() - strides.size(), 2U,
-      platform::errors::InvalidArgument(
-          "The input's dimension size minus Attr(stride)'s size must "
-          "be euqal to 2 for Op(conv_transpose). But received: [%d], the "
-          "input's dimension size is [%d], the shape of input "
-          "is [%s], the Attr(stride)'s size is [%d].",
-          in_sub_stride_size, in_dims.size(), in_dims, strides.size()));
-  if (output_size.size())
-    PADDLE_ENFORCE_EQ(
-        output_size.size(), strides.size(),
-        platform::errors::InvalidArgument(
-            "The Attr(output_size) and Attr(stride) of Op(conv_transpose) "
-            "should be the same."));
-  if (output_padding.size())
-    PADDLE_ENFORCE_EQ(
-        output_padding.size(), strides.size(),
-        platform::errors::InvalidArgument(
-            "The Attr(output_padding) and Attr(stride) of Op(conv_transpose) "
-            "should be the same."));
-
-  const int64_t C =
-      (data_layout != DataLayout::kNHWC ? in_dims[1]
-                                        : in_dims[in_dims.size() - 1]);
-  PADDLE_ENFORCE_EQ(
-      C, filter_dims[0],
-      platform::errors::InvalidArgument(
-          "The number of input channels should be equal to filter channels "
-          "for Op(conv_transpose). But received: the input's channels is "
-          "[%d], the shape of input is [%s], the filter's channels is [%d], "
-          "the shape of filter is [%s]. The data_format is %s."
-          "The error may come from wrong data_format setting.",
-          C, in_dims, filter_dims[0], filter_dims, data_layout_str));
-
-  framework::DDim in_data_dims;
-  if (data_layout != DataLayout::kNHWC) {
-    in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
-  } else {
-    in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
-  }
-  framework::DDim filter_data_dims =
-      phi::slice_ddim(filter_dims, 2, filter_dims.size());
-  std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
-  UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
-                           in_data_dims, strides, ksize);
-
-  std::vector<int64_t> output_shape({in_dims[0]});
-  if (data_layout != DataLayout::kNHWC) {
-    output_shape.push_back(filter_dims[1] * groups);
-  }
-  const int offset = (data_layout != DataLayout::kNHWC ? 2 : 1);
-  for (size_t i = 0; i < strides.size(); ++i) {
-    auto filter_extent = dilations[i] * (filter_dims[i + 2] - 1) + 1;
-    auto infer_shape = (ctx->IsRuntime() || in_dims[i + offset] > 0)
-                           ? (in_dims[i + offset] - 1) * strides[i] -
-                                 paddings[2 * i] - paddings[2 * i + 1] +
-                                 filter_extent
-                           : -1;
-    if (output_size.size()) {
-      if (ctx->IsRuntime()) {
-        PADDLE_ENFORCE_GE(
-            output_size[i], infer_shape,
-            platform::errors::InvalidArgument(
-                "output_size of Op(ConvTransposeOp) should not be "
-                "less than the infered output size. But received output_size = "
-                "[%s], whose dim %d is less than the infered output size [%s]",
-                phi::make_ddim(output_size).to_str(), i, infer_shape));
-        PADDLE_ENFORCE_LT(
-            output_size[i], infer_shape + strides[i],
-            platform::errors::InvalidArgument(
-                "output_size of Op(ConvTransposeOp) should be less "
-                "than infered size + stride. But received output_size = [%s], "
-                "whose dim %d is not less than the infered output size (%d) + "
-                "stride (%d) = %d",
-                phi::make_ddim(output_size).to_str(), i, infer_shape,
-                strides[i], infer_shape + strides[i]));
-      }
-      output_shape.push_back(output_size[i]);
-    } else if (output_padding.size()) {
-      if (ctx->IsRuntime()) {
-        PADDLE_ENFORCE_GE(
-            output_padding[i], 0,
-            platform::errors::InvalidArgument(
-                "output_padding of Op(ConvTransposeOp) should not be "
-                "less than the 0. But received output_padding = "
-                "[%s], whose dim %d is less than 0",
-                phi::make_ddim(output_padding).to_str(), i));
-        PADDLE_ENFORCE_LT(
-            output_padding[i], std::max(strides[i], dilations[i]),
-            platform::errors::InvalidArgument(
-                "output_padding of Op(ConvTransposeOp) should be less "
-                "than either stride or dilation. But received output_size = "
-                "[%s], "
-                "whose dim %d is not less than either stride (%d) or "
-                "dilation (%d)",
-                phi::make_ddim(output_size).to_str(), i, strides[i],
-                dilations[i]));
-      }
-      output_shape.push_back((infer_shape + output_padding[i]));
-    } else {
-      output_shape.push_back(infer_shape);
-    }
-  }
-  if (data_layout == DataLayout::kNHWC) {
-    output_shape.push_back(filter_dims[1] * groups);
-  }
-  ctx->SetOutputDim("Output", phi::make_ddim(output_shape));
-}
-
 framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   framework::LibraryType library_{framework::LibraryType::kPlain};
@@ -217,7 +62,7 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
 }
 
 framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar(
-    const std::string& var_name, const Tensor& tensor,
+    const std::string& var_name, const framework::Tensor& tensor,
     const framework::OpKernelType& expected_kernel_type) const {
 #ifdef PADDLE_WITH_MKLDNN
   // Only input require reshaping, weights and
@@ -493,17 +338,6 @@ The input(X) size and output(Out) size may be different.
 )DOC");
 }
 
-void ConvTransposeOpGrad::InferShape(framework::InferShapeContext* ctx) const {
-  auto in_dims = ctx->GetInputDim("Input");
-  auto filter_dims = ctx->GetInputDim("Filter");
-  if (ctx->HasOutput(framework::GradVarName("Input"))) {
-    ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
-  }
-  if (ctx->HasOutput(framework::GradVarName("Filter"))) {
-    ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
-  }
-}
-
 framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   bool use_cudnn =
@@ -587,24 +421,6 @@ class ConvTransposeDoubleGradMaker : public framework::SingleGradOpMaker<T> {
   }
 };
 
-void ConvTransposeOpDoubleGrad::InferShape(
-    framework::InferShapeContext* ctx) const {
-  auto x_dims = ctx->GetInputDim("Input");
-  auto w_dims = ctx->GetInputDim("Filter");
-  auto do_dims = ctx->GetInputDim("DOutput");
-
-  if (ctx->HasOutput("DDOutput") &&
-      (ctx->HasInput("DDInput") || (ctx->HasInput("DDFilter")))) {
-    ctx->SetOutputDim("DDOutput", do_dims);
-  }
-  if (ctx->HasOutput("DFilter") && ctx->HasInput("DDInput")) {
-    ctx->SetOutputDim("DFilter", w_dims);
-  }
-  if (ctx->HasOutput("DInput") && ctx->HasInput("DDFilter")) {
-    ctx->SetOutputDim("DInput", x_dims);
-  }
-}
-
 framework::OpKernelType ConvTransposeOpDoubleGrad::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   bool use_cudnn =
@@ -635,59 +451,57 @@ framework::OpKernelType ConvTransposeOpDoubleGrad::GetExpectedKernelType(
 namespace ops = paddle::operators;
 
 // conv2d_transpose
+DECLARE_INFER_SHAPE_FUNCTOR(conv2d_transpose, Conv2dTranposeInferShapeFunctor,
+                            PD_INFER_META(phi::ConvTransposeInferMeta));
+DECLARE_INFER_SHAPE_FUNCTOR(conv2d_transpose_grad,
+                            Conv2dTranposeGradInferShapeFunctor,
+                            PD_INFER_META(phi::ConvTransposeGradInferMeta));
+DECLARE_INFER_SHAPE_FUNCTOR(
+    conv2d_transpose_grad_grad, Conv2dTranposeDoubleGradInferShapeFunctor,
+    PD_INFER_META(phi::Conv2dTransposeDoubleGradInferMeta));
+
 REGISTER_OPERATOR(conv2d_transpose, ops::ConvTransposeOp,
                   ops::Conv2DTransposeOpMaker,
                   ops::ConvTransposeGradOpMaker<paddle::framework::OpDesc>,
-                  ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OPERATOR(
-    conv2d_transpose_grad, ops::ConvTransposeOpGrad,
-    ops::ConvTransposeDoubleGradMaker<paddle::framework::OpDesc>,
-    ops::ConvTransposeDoubleGradMaker<paddle::imperative::OpBase>);
-REGISTER_OPERATOR(conv2d_transpose_grad_grad, ops::ConvTransposeOpDoubleGrad);
-
-REGISTER_OP_CPU_KERNEL(
-    conv2d_transpose,
-    ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double>);
-REGISTER_OP_CPU_KERNEL(
-    conv2d_transpose_grad,
-    ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
-                                     double>);
+                  ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>,
+                  Conv2dTranposeInferShapeFunctor);
+REGISTER_OPERATOR(conv2d_transpose_grad, ops::ConvTransposeOpGrad,
+                  ops::ConvTransposeDoubleGradMaker<paddle::framework::OpDesc>,
+                  ops::ConvTransposeDoubleGradMaker<paddle::imperative::OpBase>,
+                  Conv2dTranposeGradInferShapeFunctor);
+REGISTER_OPERATOR(conv2d_transpose_grad_grad, ops::ConvTransposeOpDoubleGrad,
+                  Conv2dTranposeDoubleGradInferShapeFunctor);
 
 // conv3d_transpose
+DECLARE_INFER_SHAPE_FUNCTOR(conv3d_transpose, Conv3dTranposeInferShapeFunctor,
+                            PD_INFER_META(phi::ConvTransposeInferMeta));
+DECLARE_INFER_SHAPE_FUNCTOR(conv3d_transpose_grad,
+                            Conv3dTranposeGradInferShapeFunctor,
+                            PD_INFER_META(phi::ConvTransposeGradInferMeta));
+
 REGISTER_OPERATOR(conv3d_transpose, ops::ConvTransposeOp,
                   ops::Conv3DTransposeOpMaker,
                   ops::ConvTransposeGradOpMaker<paddle::framework::OpDesc>,
-                  ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OPERATOR(conv3d_transpose_grad, ops::ConvTransposeOpGrad);
-
-REGISTER_OP_CPU_KERNEL(
-    conv3d_transpose,
-    ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double>);
-REGISTER_OP_CPU_KERNEL(
-    conv3d_transpose_grad,
-    ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
-                                     double>);
+                  ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>,
+                  Conv3dTranposeInferShapeFunctor);
+REGISTER_OPERATOR(conv3d_transpose_grad, ops::ConvTransposeOpGrad,
+                  Conv3dTranposeGradInferShapeFunctor);
 
 // depthwise conv2d_transpose
+DECLARE_INFER_SHAPE_FUNCTOR(depthwise_conv2d_transpose,
+                            DepthWiseConv2dTranposeInferShapeFunctor,
+                            PD_INFER_META(phi::ConvTransposeInferMeta));
+DECLARE_INFER_SHAPE_FUNCTOR(depthwise_conv2d_transpose_grad,
+                            DepthWiseConv2dTranposeGradInferShapeFunctor,
+                            PD_INFER_META(phi::ConvTransposeGradInferMeta));
+
 REGISTER_OPERATOR(depthwise_conv2d_transpose, ops::ConvTransposeOp,
                   ops::Conv2DTransposeOpMaker,
                   ops::ConvTransposeGradOpMaker<paddle::framework::OpDesc>,
-                  ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OPERATOR(depthwise_conv2d_transpose_grad, ops::ConvTransposeOpGrad);
-
-REGISTER_OP_CPU_KERNEL(
-    depthwise_conv2d_transpose,
-    ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double>);
-REGISTER_OP_CPU_KERNEL(
-    depthwise_conv2d_transpose_grad,
-    ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
-                                     double>);
+                  ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>,
+                  DepthWiseConv2dTranposeInferShapeFunctor);
+REGISTER_OPERATOR(depthwise_conv2d_transpose_grad, ops::ConvTransposeOpGrad,
+                  DepthWiseConv2dTranposeGradInferShapeFunctor);
 
 REGISTER_OP_VERSION(conv_transpose)
     .AddCheckpoint(
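
The REGISTER_OP_CPU_KERNEL blocks deleted above are not lost: their replacements are phi kernel registrations in the new paddle/phi/kernels files, which are among the 24 changed files but outside this excerpt. A minimal sketch of that registration, with the kernel function name assumed rather than taken from the commit:

// Sketch of the phi-side replacement for REGISTER_OP_CPU_KERNEL.
// phi::Conv2dTransposeKernel is an assumed name; the macro shape
// (name, backend, layout, function, dtypes) is the standard
// PD_REGISTER_KERNEL form.
PD_REGISTER_KERNEL(conv2d_transpose,
                   CPU,
                   ALL_LAYOUT,
                   phi::Conv2dTransposeKernel,
                   float,
                   double) {}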
