@@ -13,13 +13,17 @@ See the License for the specific language governing permissions and
1313limitations under the License. */
1414
1515#include " paddle/fluid/operators/conv_transpose_op.h"
16- # include < memory >
16+
1717#include < string>
1818#include < vector>
1919#include " paddle/fluid/framework/data_layout.h"
20+ #include " paddle/fluid/framework/infershape_utils.h"
21+ #include " paddle/fluid/framework/op_registry.h"
2022#include " paddle/fluid/framework/op_version_registry.h"
2123#include " paddle/fluid/platform/cudnn_workspace_helper.h"
22-
24+ #include " paddle/phi/core/infermeta_utils.h"
25+ #include " paddle/phi/infermeta/backward.h"
26+ #include " paddle/phi/infermeta/binary.h"
2327#ifdef PADDLE_WITH_MKLDNN
2428#include " paddle/fluid/platform/mkldnn_helper.h"
2529#endif
@@ -29,165 +33,6 @@ namespace operators {
2933
3034using DataLayout = framework::DataLayout;
3135
32- void ConvTransposeOp::InferShape (framework::InferShapeContext* ctx) const {
33- OP_INOUT_CHECK (ctx->HasInput (" Input" ), " Input" , " Input" , " ConvTranspose" );
34- OP_INOUT_CHECK (ctx->HasInput (" Filter" ), " Input" , " Filter" , " ConvTranspose" );
35- OP_INOUT_CHECK (ctx->HasOutput (" Output" ), " Output" , " Output" , " ConvTranspose" );
36-
37- auto in_dims = ctx->GetInputDim (" Input" );
38- auto filter_dims = ctx->GetInputDim (" Filter" );
39- std::vector<int > output_size =
40- ctx->Attrs ().Get <std::vector<int >>(" output_size" );
41- std::vector<int > output_padding =
42- ctx->Attrs ().Get <std::vector<int >>(" output_padding" );
43- std::vector<int > strides = ctx->Attrs ().Get <std::vector<int >>(" strides" );
44- std::vector<int > paddings = ctx->Attrs ().Get <std::vector<int >>(" paddings" );
45- std::vector<int > dilations = ctx->Attrs ().Get <std::vector<int >>(" dilations" );
46- int groups = ctx->Attrs ().Get <int >(" groups" );
47- std::string padding_algorithm =
48- ctx->Attrs ().Get <std::string>(" padding_algorithm" );
49- const std::string data_layout_str =
50- ctx->Attrs ().Get <std::string>(" data_format" );
51- const DataLayout data_layout =
52- ctx->IsRunMKLDNNKernel () ? DataLayout::kNCHW
53- : framework::StringToDataLayout (data_layout_str);
54-
55- PADDLE_ENFORCE_EQ (in_dims.size () == 4 || in_dims.size () == 5 , true ,
56- platform::errors::InvalidArgument (
57- " Input of Op(conv_transpose) should be 4-D or "
58- " 5-D Tensor. But received: %u-D Tensor, "
59- " the shape of input is [%s]" ,
60- in_dims.size (), in_dims));
61- PADDLE_ENFORCE_EQ (
62- in_dims.size (), filter_dims.size (),
63- platform::errors::InvalidArgument (
64- " The input's dimension size and filter's dimension size of "
65- " Op (conv_transpose) should be equal. But received: the shape of "
66- " input is [%s], the dimension size of input is [%d], the shape "
67- " of filter is [%s], the dimension size of filter is [%d]. " ,
68- in_dims, in_dims.size (), filter_dims, filter_dims.size ()));
69-
70- int stride_size = strides.size ();
71- for (int i = 0 ; i < stride_size; ++i) {
72- PADDLE_ENFORCE_GT (
73- strides[i], 0 ,
74- platform::errors::InvalidArgument (
75- " The stride of Op(Conv) should be larget than 0, but received "
76- " stride is %d." ,
77- strides[i]));
78- }
79-
80- int in_sub_stride_size = in_dims.size () - stride_size;
81-
82- PADDLE_ENFORCE_EQ (
83- in_dims.size () - strides.size (), 2U ,
84- platform::errors::InvalidArgument (
85- " The input's dimension size minus Attr(stride)'s size must "
86- " be euqal to 2 for Op(conv_transpose). But received: [%d], the "
87- " input's dimension size is [%d], the shape of input "
88- " is [%s], the Attr(stride)'s size is [%d]." ,
89- in_sub_stride_size, in_dims.size (), in_dims, strides.size ()));
90- if (output_size.size ())
91- PADDLE_ENFORCE_EQ (
92- output_size.size (), strides.size (),
93- platform::errors::InvalidArgument (
94- " The Attr(output_size) and Attr(stride) of Op(conv_transpose) "
95- " should be the same." ));
96- if (output_padding.size ())
97- PADDLE_ENFORCE_EQ (
98- output_padding.size (), strides.size (),
99- platform::errors::InvalidArgument (
100- " The Attr(output_padding) and Attr(stride) of Op(conv_transpose) "
101- " should be the same." ));
102-
103- const int64_t C =
104- (data_layout != DataLayout::kNHWC ? in_dims[1 ]
105- : in_dims[in_dims.size () - 1 ]);
106- PADDLE_ENFORCE_EQ (
107- C, filter_dims[0 ],
108- platform::errors::InvalidArgument (
109- " The number of input channels should be equal to filter channels "
110- " for Op(conv_transpose). But received: the input's channels is "
111- " [%d], the shape of input is [%s], the filter's channels is [%d], "
112- " the shape of filter is [%s]. The data_format is %s."
113- " The error may come from wrong data_format setting." ,
114- C, in_dims, filter_dims[0 ], filter_dims, data_layout_str));
115-
116- framework::DDim in_data_dims;
117- if (data_layout != DataLayout::kNHWC ) {
118- in_data_dims = phi::slice_ddim (in_dims, 2 , in_dims.size ());
119- } else {
120- in_data_dims = phi::slice_ddim (in_dims, 1 , in_dims.size () - 1 );
121- }
122- framework::DDim filter_data_dims =
123- phi::slice_ddim (filter_dims, 2 , filter_dims.size ());
124- std::vector<int > ksize = phi::vectorize<int >(filter_data_dims);
125- UpdatePaddingAndDilation (&paddings, &dilations, padding_algorithm,
126- in_data_dims, strides, ksize);
127-
128- std::vector<int64_t > output_shape ({in_dims[0 ]});
129- if (data_layout != DataLayout::kNHWC ) {
130- output_shape.push_back (filter_dims[1 ] * groups);
131- }
132- const int offset = (data_layout != DataLayout::kNHWC ? 2 : 1 );
133- for (size_t i = 0 ; i < strides.size (); ++i) {
134- auto filter_extent = dilations[i] * (filter_dims[i + 2 ] - 1 ) + 1 ;
135- auto infer_shape = (ctx->IsRuntime () || in_dims[i + offset] > 0 )
136- ? (in_dims[i + offset] - 1 ) * strides[i] -
137- paddings[2 * i] - paddings[2 * i + 1 ] +
138- filter_extent
139- : -1 ;
140- if (output_size.size ()) {
141- if (ctx->IsRuntime ()) {
142- PADDLE_ENFORCE_GE (
143- output_size[i], infer_shape,
144- platform::errors::InvalidArgument (
145- " output_size of Op(ConvTransposeOp) should not be "
146- " less than the infered output size. But received output_size = "
147- " [%s], whose dim %d is less than the infered output size [%s]" ,
148- phi::make_ddim (output_size).to_str (), i, infer_shape));
149- PADDLE_ENFORCE_LT (
150- output_size[i], infer_shape + strides[i],
151- platform::errors::InvalidArgument (
152- " output_size of Op(ConvTransposeOp) should be less "
153- " than infered size + stride. But received output_size = [%s], "
154- " whose dim %d is not less than the infered output size (%d) + "
155- " stride (%d) = %d" ,
156- phi::make_ddim (output_size).to_str (), i, infer_shape,
157- strides[i], infer_shape + strides[i]));
158- }
159- output_shape.push_back (output_size[i]);
160- } else if (output_padding.size ()) {
161- if (ctx->IsRuntime ()) {
162- PADDLE_ENFORCE_GE (
163- output_padding[i], 0 ,
164- platform::errors::InvalidArgument (
165- " output_padding of Op(ConvTransposeOp) should not be "
166- " less than the 0. But received output_padding = "
167- " [%s], whose dim %d is less than 0" ,
168- phi::make_ddim (output_padding).to_str (), i));
169- PADDLE_ENFORCE_LT (
170- output_padding[i], std::max (strides[i], dilations[i]),
171- platform::errors::InvalidArgument (
172- " output_padding of Op(ConvTransposeOp) should be less "
173- " than either stride or dilation. But received output_size = "
174- " [%s], "
175- " whose dim %d is not less than either stride (%d) or "
176- " dilation (%d)" ,
177- phi::make_ddim (output_size).to_str (), i, strides[i],
178- dilations[i]));
179- }
180- output_shape.push_back ((infer_shape + output_padding[i]));
181- } else {
182- output_shape.push_back (infer_shape);
183- }
184- }
185- if (data_layout == DataLayout::kNHWC ) {
186- output_shape.push_back (filter_dims[1 ] * groups);
187- }
188- ctx->SetOutputDim (" Output" , phi::make_ddim (output_shape));
189- }
190-
19136framework::OpKernelType ConvTransposeOp::GetExpectedKernelType (
19237 const framework::ExecutionContext& ctx) const {
19338 framework::LibraryType library_{framework::LibraryType::kPlain };
@@ -217,7 +62,7 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
21762}
21863
21964framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar (
220- const std::string& var_name, const Tensor& tensor,
65+ const std::string& var_name, const framework:: Tensor& tensor,
22166 const framework::OpKernelType& expected_kernel_type) const {
22267#ifdef PADDLE_WITH_MKLDNN
22368 // Only input require reshaping, weights and
@@ -493,17 +338,6 @@ The input(X) size and output(Out) size may be different.
493338)DOC" );
494339}
495340
496- void ConvTransposeOpGrad::InferShape (framework::InferShapeContext* ctx) const {
497- auto in_dims = ctx->GetInputDim (" Input" );
498- auto filter_dims = ctx->GetInputDim (" Filter" );
499- if (ctx->HasOutput (framework::GradVarName (" Input" ))) {
500- ctx->SetOutputDim (framework::GradVarName (" Input" ), in_dims);
501- }
502- if (ctx->HasOutput (framework::GradVarName (" Filter" ))) {
503- ctx->SetOutputDim (framework::GradVarName (" Filter" ), filter_dims);
504- }
505- }
506-
507341framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType (
508342 const framework::ExecutionContext& ctx) const {
509343 bool use_cudnn =
@@ -587,24 +421,6 @@ class ConvTransposeDoubleGradMaker : public framework::SingleGradOpMaker<T> {
587421 }
588422};
589423
590- void ConvTransposeOpDoubleGrad::InferShape (
591- framework::InferShapeContext* ctx) const {
592- auto x_dims = ctx->GetInputDim (" Input" );
593- auto w_dims = ctx->GetInputDim (" Filter" );
594- auto do_dims = ctx->GetInputDim (" DOutput" );
595-
596- if (ctx->HasOutput (" DDOutput" ) &&
597- (ctx->HasInput (" DDInput" ) || (ctx->HasInput (" DDFilter" )))) {
598- ctx->SetOutputDim (" DDOutput" , do_dims);
599- }
600- if (ctx->HasOutput (" DFilter" ) && ctx->HasInput (" DDInput" )) {
601- ctx->SetOutputDim (" DFilter" , w_dims);
602- }
603- if (ctx->HasOutput (" DInput" ) && ctx->HasInput (" DDFilter" )) {
604- ctx->SetOutputDim (" DInput" , x_dims);
605- }
606- }
607-
608424framework::OpKernelType ConvTransposeOpDoubleGrad::GetExpectedKernelType (
609425 const framework::ExecutionContext& ctx) const {
610426 bool use_cudnn =
@@ -635,59 +451,57 @@ framework::OpKernelType ConvTransposeOpDoubleGrad::GetExpectedKernelType(
635451namespace ops = paddle::operators;
636452
637453// conv2d_transpose
454+ DECLARE_INFER_SHAPE_FUNCTOR (conv2d_transpose, Conv2dTranposeInferShapeFunctor,
455+ PD_INFER_META (phi::ConvTransposeInferMeta));
456+ DECLARE_INFER_SHAPE_FUNCTOR (conv2d_transpose_grad,
457+ Conv2dTranposeGradInferShapeFunctor,
458+ PD_INFER_META (phi::ConvTransposeGradInferMeta));
459+ DECLARE_INFER_SHAPE_FUNCTOR (
460+ conv2d_transpose_grad_grad, Conv2dTranposeDoubleGradInferShapeFunctor,
461+ PD_INFER_META (phi::Conv2dTransposeDoubleGradInferMeta));
462+
638463REGISTER_OPERATOR (conv2d_transpose, ops::ConvTransposeOp,
639464 ops::Conv2DTransposeOpMaker,
640465 ops::ConvTransposeGradOpMaker<paddle::framework::OpDesc>,
641- ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>);
642- REGISTER_OPERATOR (
643- conv2d_transpose_grad, ops::ConvTransposeOpGrad,
644- ops::ConvTransposeDoubleGradMaker<paddle::framework::OpDesc>,
645- ops::ConvTransposeDoubleGradMaker<paddle::imperative::OpBase>);
646- REGISTER_OPERATOR (conv2d_transpose_grad_grad, ops::ConvTransposeOpDoubleGrad);
647-
648- REGISTER_OP_CPU_KERNEL (
649- conv2d_transpose,
650- ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float >,
651- ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double >);
652- REGISTER_OP_CPU_KERNEL (
653- conv2d_transpose_grad,
654- ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float >,
655- ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
656- double >);
466+ ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>,
467+ Conv2dTranposeInferShapeFunctor);
468+ REGISTER_OPERATOR (conv2d_transpose_grad, ops::ConvTransposeOpGrad,
469+ ops::ConvTransposeDoubleGradMaker<paddle::framework::OpDesc>,
470+ ops::ConvTransposeDoubleGradMaker<paddle::imperative::OpBase>,
471+ Conv2dTranposeGradInferShapeFunctor);
472+ REGISTER_OPERATOR (conv2d_transpose_grad_grad, ops::ConvTransposeOpDoubleGrad,
473+ Conv2dTranposeDoubleGradInferShapeFunctor);
657474
658475// conv3d_transpose
476+ DECLARE_INFER_SHAPE_FUNCTOR (conv3d_transpose, Conv3dTranposeInferShapeFunctor,
477+ PD_INFER_META (phi::ConvTransposeInferMeta));
478+ DECLARE_INFER_SHAPE_FUNCTOR (conv3d_transpose_grad,
479+ Conv3dTranposeGradInferShapeFunctor,
480+ PD_INFER_META (phi::ConvTransposeGradInferMeta));
481+
659482REGISTER_OPERATOR (conv3d_transpose, ops::ConvTransposeOp,
660483 ops::Conv3DTransposeOpMaker,
661484 ops::ConvTransposeGradOpMaker<paddle::framework::OpDesc>,
662- ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>);
663- REGISTER_OPERATOR (conv3d_transpose_grad, ops::ConvTransposeOpGrad);
664-
665- REGISTER_OP_CPU_KERNEL (
666- conv3d_transpose,
667- ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float >,
668- ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double >);
669- REGISTER_OP_CPU_KERNEL (
670- conv3d_transpose_grad,
671- ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float >,
672- ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
673- double >);
485+ ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>,
486+ Conv3dTranposeInferShapeFunctor);
487+ REGISTER_OPERATOR (conv3d_transpose_grad, ops::ConvTransposeOpGrad,
488+ Conv3dTranposeGradInferShapeFunctor);
674489
675490// depthwise conv2d_transpose
491+ DECLARE_INFER_SHAPE_FUNCTOR (depthwise_conv2d_transpose,
492+ DepthWiseConv2dTranposeInferShapeFunctor,
493+ PD_INFER_META (phi::ConvTransposeInferMeta));
494+ DECLARE_INFER_SHAPE_FUNCTOR (depthwise_conv2d_transpose_grad,
495+ DepthWiseConv2dTranposeGradInferShapeFunctor,
496+ PD_INFER_META (phi::ConvTransposeGradInferMeta));
497+
676498REGISTER_OPERATOR (depthwise_conv2d_transpose, ops::ConvTransposeOp,
677499 ops::Conv2DTransposeOpMaker,
678500 ops::ConvTransposeGradOpMaker<paddle::framework::OpDesc>,
679- ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>);
680- REGISTER_OPERATOR (depthwise_conv2d_transpose_grad, ops::ConvTransposeOpGrad);
681-
682- REGISTER_OP_CPU_KERNEL (
683- depthwise_conv2d_transpose,
684- ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float >,
685- ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double >);
686- REGISTER_OP_CPU_KERNEL (
687- depthwise_conv2d_transpose_grad,
688- ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float >,
689- ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
690- double >);
501+ ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>,
502+ DepthWiseConv2dTranposeInferShapeFunctor);
503+ REGISTER_OPERATOR (depthwise_conv2d_transpose_grad, ops::ConvTransposeOpGrad,
504+ DepthWiseConv2dTranposeGradInferShapeFunctor);
691505
692506REGISTER_OP_VERSION (conv_transpose)
693507 .AddCheckpoint(
0 commit comments