@@ -13,13 +13,17 @@ See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
15
15
#include " paddle/fluid/operators/conv_transpose_op.h"
16
- # include < memory >
16
+
17
17
#include < string>
18
18
#include < vector>
19
19
#include " paddle/fluid/framework/data_layout.h"
20
+ #include " paddle/fluid/framework/infershape_utils.h"
21
+ #include " paddle/fluid/framework/op_registry.h"
20
22
#include " paddle/fluid/framework/op_version_registry.h"
21
23
#include " paddle/fluid/platform/cudnn_workspace_helper.h"
22
-
24
+ #include " paddle/phi/core/infermeta_utils.h"
25
+ #include " paddle/phi/infermeta/backward.h"
26
+ #include " paddle/phi/infermeta/binary.h"
23
27
#ifdef PADDLE_WITH_MKLDNN
24
28
#include " paddle/fluid/platform/mkldnn_helper.h"
25
29
#endif
@@ -29,165 +33,6 @@ namespace operators {
29
33
30
34
using DataLayout = framework::DataLayout;
31
35
32
- void ConvTransposeOp::InferShape (framework::InferShapeContext* ctx) const {
33
- OP_INOUT_CHECK (ctx->HasInput (" Input" ), " Input" , " Input" , " ConvTranspose" );
34
- OP_INOUT_CHECK (ctx->HasInput (" Filter" ), " Input" , " Filter" , " ConvTranspose" );
35
- OP_INOUT_CHECK (ctx->HasOutput (" Output" ), " Output" , " Output" , " ConvTranspose" );
36
-
37
- auto in_dims = ctx->GetInputDim (" Input" );
38
- auto filter_dims = ctx->GetInputDim (" Filter" );
39
- std::vector<int > output_size =
40
- ctx->Attrs ().Get <std::vector<int >>(" output_size" );
41
- std::vector<int > output_padding =
42
- ctx->Attrs ().Get <std::vector<int >>(" output_padding" );
43
- std::vector<int > strides = ctx->Attrs ().Get <std::vector<int >>(" strides" );
44
- std::vector<int > paddings = ctx->Attrs ().Get <std::vector<int >>(" paddings" );
45
- std::vector<int > dilations = ctx->Attrs ().Get <std::vector<int >>(" dilations" );
46
- int groups = ctx->Attrs ().Get <int >(" groups" );
47
- std::string padding_algorithm =
48
- ctx->Attrs ().Get <std::string>(" padding_algorithm" );
49
- const std::string data_layout_str =
50
- ctx->Attrs ().Get <std::string>(" data_format" );
51
- const DataLayout data_layout =
52
- ctx->IsRunMKLDNNKernel () ? DataLayout::kNCHW
53
- : framework::StringToDataLayout (data_layout_str);
54
-
55
- PADDLE_ENFORCE_EQ (in_dims.size () == 4 || in_dims.size () == 5 , true ,
56
- platform::errors::InvalidArgument (
57
- " Input of Op(conv_transpose) should be 4-D or "
58
- " 5-D Tensor. But received: %u-D Tensor, "
59
- " the shape of input is [%s]" ,
60
- in_dims.size (), in_dims));
61
- PADDLE_ENFORCE_EQ (
62
- in_dims.size (), filter_dims.size (),
63
- platform::errors::InvalidArgument (
64
- " The input's dimension size and filter's dimension size of "
65
- " Op (conv_transpose) should be equal. But received: the shape of "
66
- " input is [%s], the dimension size of input is [%d], the shape "
67
- " of filter is [%s], the dimension size of filter is [%d]. " ,
68
- in_dims, in_dims.size (), filter_dims, filter_dims.size ()));
69
-
70
- int stride_size = strides.size ();
71
- for (int i = 0 ; i < stride_size; ++i) {
72
- PADDLE_ENFORCE_GT (
73
- strides[i], 0 ,
74
- platform::errors::InvalidArgument (
75
- " The stride of Op(Conv) should be larget than 0, but received "
76
- " stride is %d." ,
77
- strides[i]));
78
- }
79
-
80
- int in_sub_stride_size = in_dims.size () - stride_size;
81
-
82
- PADDLE_ENFORCE_EQ (
83
- in_dims.size () - strides.size (), 2U ,
84
- platform::errors::InvalidArgument (
85
- " The input's dimension size minus Attr(stride)'s size must "
86
- " be euqal to 2 for Op(conv_transpose). But received: [%d], the "
87
- " input's dimension size is [%d], the shape of input "
88
- " is [%s], the Attr(stride)'s size is [%d]." ,
89
- in_sub_stride_size, in_dims.size (), in_dims, strides.size ()));
90
- if (output_size.size ())
91
- PADDLE_ENFORCE_EQ (
92
- output_size.size (), strides.size (),
93
- platform::errors::InvalidArgument (
94
- " The Attr(output_size) and Attr(stride) of Op(conv_transpose) "
95
- " should be the same." ));
96
- if (output_padding.size ())
97
- PADDLE_ENFORCE_EQ (
98
- output_padding.size (), strides.size (),
99
- platform::errors::InvalidArgument (
100
- " The Attr(output_padding) and Attr(stride) of Op(conv_transpose) "
101
- " should be the same." ));
102
-
103
- const int64_t C =
104
- (data_layout != DataLayout::kNHWC ? in_dims[1 ]
105
- : in_dims[in_dims.size () - 1 ]);
106
- PADDLE_ENFORCE_EQ (
107
- C, filter_dims[0 ],
108
- platform::errors::InvalidArgument (
109
- " The number of input channels should be equal to filter channels "
110
- " for Op(conv_transpose). But received: the input's channels is "
111
- " [%d], the shape of input is [%s], the filter's channels is [%d], "
112
- " the shape of filter is [%s]. The data_format is %s."
113
- " The error may come from wrong data_format setting." ,
114
- C, in_dims, filter_dims[0 ], filter_dims, data_layout_str));
115
-
116
- framework::DDim in_data_dims;
117
- if (data_layout != DataLayout::kNHWC ) {
118
- in_data_dims = phi::slice_ddim (in_dims, 2 , in_dims.size ());
119
- } else {
120
- in_data_dims = phi::slice_ddim (in_dims, 1 , in_dims.size () - 1 );
121
- }
122
- framework::DDim filter_data_dims =
123
- phi::slice_ddim (filter_dims, 2 , filter_dims.size ());
124
- std::vector<int > ksize = phi::vectorize<int >(filter_data_dims);
125
- UpdatePaddingAndDilation (&paddings, &dilations, padding_algorithm,
126
- in_data_dims, strides, ksize);
127
-
128
- std::vector<int64_t > output_shape ({in_dims[0 ]});
129
- if (data_layout != DataLayout::kNHWC ) {
130
- output_shape.push_back (filter_dims[1 ] * groups);
131
- }
132
- const int offset = (data_layout != DataLayout::kNHWC ? 2 : 1 );
133
- for (size_t i = 0 ; i < strides.size (); ++i) {
134
- auto filter_extent = dilations[i] * (filter_dims[i + 2 ] - 1 ) + 1 ;
135
- auto infer_shape = (ctx->IsRuntime () || in_dims[i + offset] > 0 )
136
- ? (in_dims[i + offset] - 1 ) * strides[i] -
137
- paddings[2 * i] - paddings[2 * i + 1 ] +
138
- filter_extent
139
- : -1 ;
140
- if (output_size.size ()) {
141
- if (ctx->IsRuntime ()) {
142
- PADDLE_ENFORCE_GE (
143
- output_size[i], infer_shape,
144
- platform::errors::InvalidArgument (
145
- " output_size of Op(ConvTransposeOp) should not be "
146
- " less than the infered output size. But received output_size = "
147
- " [%s], whose dim %d is less than the infered output size [%s]" ,
148
- phi::make_ddim (output_size).to_str (), i, infer_shape));
149
- PADDLE_ENFORCE_LT (
150
- output_size[i], infer_shape + strides[i],
151
- platform::errors::InvalidArgument (
152
- " output_size of Op(ConvTransposeOp) should be less "
153
- " than infered size + stride. But received output_size = [%s], "
154
- " whose dim %d is not less than the infered output size (%d) + "
155
- " stride (%d) = %d" ,
156
- phi::make_ddim (output_size).to_str (), i, infer_shape,
157
- strides[i], infer_shape + strides[i]));
158
- }
159
- output_shape.push_back (output_size[i]);
160
- } else if (output_padding.size ()) {
161
- if (ctx->IsRuntime ()) {
162
- PADDLE_ENFORCE_GE (
163
- output_padding[i], 0 ,
164
- platform::errors::InvalidArgument (
165
- " output_padding of Op(ConvTransposeOp) should not be "
166
- " less than the 0. But received output_padding = "
167
- " [%s], whose dim %d is less than 0" ,
168
- phi::make_ddim (output_padding).to_str (), i));
169
- PADDLE_ENFORCE_LT (
170
- output_padding[i], std::max (strides[i], dilations[i]),
171
- platform::errors::InvalidArgument (
172
- " output_padding of Op(ConvTransposeOp) should be less "
173
- " than either stride or dilation. But received output_size = "
174
- " [%s], "
175
- " whose dim %d is not less than either stride (%d) or "
176
- " dilation (%d)" ,
177
- phi::make_ddim (output_size).to_str (), i, strides[i],
178
- dilations[i]));
179
- }
180
- output_shape.push_back ((infer_shape + output_padding[i]));
181
- } else {
182
- output_shape.push_back (infer_shape);
183
- }
184
- }
185
- if (data_layout == DataLayout::kNHWC ) {
186
- output_shape.push_back (filter_dims[1 ] * groups);
187
- }
188
- ctx->SetOutputDim (" Output" , phi::make_ddim (output_shape));
189
- }
190
-
191
36
framework::OpKernelType ConvTransposeOp::GetExpectedKernelType (
192
37
const framework::ExecutionContext& ctx) const {
193
38
framework::LibraryType library_{framework::LibraryType::kPlain };
@@ -217,7 +62,7 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
217
62
}
218
63
219
64
framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar (
220
- const std::string& var_name, const Tensor& tensor,
65
+ const std::string& var_name, const framework:: Tensor& tensor,
221
66
const framework::OpKernelType& expected_kernel_type) const {
222
67
#ifdef PADDLE_WITH_MKLDNN
223
68
// Only input require reshaping, weights and
@@ -493,17 +338,6 @@ The input(X) size and output(Out) size may be different.
493
338
)DOC" );
494
339
}
495
340
496
- void ConvTransposeOpGrad::InferShape (framework::InferShapeContext* ctx) const {
497
- auto in_dims = ctx->GetInputDim (" Input" );
498
- auto filter_dims = ctx->GetInputDim (" Filter" );
499
- if (ctx->HasOutput (framework::GradVarName (" Input" ))) {
500
- ctx->SetOutputDim (framework::GradVarName (" Input" ), in_dims);
501
- }
502
- if (ctx->HasOutput (framework::GradVarName (" Filter" ))) {
503
- ctx->SetOutputDim (framework::GradVarName (" Filter" ), filter_dims);
504
- }
505
- }
506
-
507
341
framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType (
508
342
const framework::ExecutionContext& ctx) const {
509
343
bool use_cudnn =
@@ -587,24 +421,6 @@ class ConvTransposeDoubleGradMaker : public framework::SingleGradOpMaker<T> {
587
421
}
588
422
};
589
423
590
- void ConvTransposeOpDoubleGrad::InferShape (
591
- framework::InferShapeContext* ctx) const {
592
- auto x_dims = ctx->GetInputDim (" Input" );
593
- auto w_dims = ctx->GetInputDim (" Filter" );
594
- auto do_dims = ctx->GetInputDim (" DOutput" );
595
-
596
- if (ctx->HasOutput (" DDOutput" ) &&
597
- (ctx->HasInput (" DDInput" ) || (ctx->HasInput (" DDFilter" )))) {
598
- ctx->SetOutputDim (" DDOutput" , do_dims);
599
- }
600
- if (ctx->HasOutput (" DFilter" ) && ctx->HasInput (" DDInput" )) {
601
- ctx->SetOutputDim (" DFilter" , w_dims);
602
- }
603
- if (ctx->HasOutput (" DInput" ) && ctx->HasInput (" DDFilter" )) {
604
- ctx->SetOutputDim (" DInput" , x_dims);
605
- }
606
- }
607
-
608
424
framework::OpKernelType ConvTransposeOpDoubleGrad::GetExpectedKernelType (
609
425
const framework::ExecutionContext& ctx) const {
610
426
bool use_cudnn =
@@ -635,59 +451,57 @@ framework::OpKernelType ConvTransposeOpDoubleGrad::GetExpectedKernelType(
635
451
namespace ops = paddle::operators;
636
452
637
453
// conv2d_transpose
454
+ DECLARE_INFER_SHAPE_FUNCTOR (conv2d_transpose, Conv2dTranposeInferShapeFunctor,
455
+ PD_INFER_META (phi::ConvTransposeInferMeta));
456
+ DECLARE_INFER_SHAPE_FUNCTOR (conv2d_transpose_grad,
457
+ Conv2dTranposeGradInferShapeFunctor,
458
+ PD_INFER_META (phi::ConvTransposeGradInferMeta));
459
+ DECLARE_INFER_SHAPE_FUNCTOR (
460
+ conv2d_transpose_grad_grad, Conv2dTranposeDoubleGradInferShapeFunctor,
461
+ PD_INFER_META (phi::Conv2dTransposeDoubleGradInferMeta));
462
+
638
463
REGISTER_OPERATOR (conv2d_transpose, ops::ConvTransposeOp,
639
464
ops::Conv2DTransposeOpMaker,
640
465
ops::ConvTransposeGradOpMaker<paddle::framework::OpDesc>,
641
- ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>);
642
- REGISTER_OPERATOR (
643
- conv2d_transpose_grad, ops::ConvTransposeOpGrad,
644
- ops::ConvTransposeDoubleGradMaker<paddle::framework::OpDesc>,
645
- ops::ConvTransposeDoubleGradMaker<paddle::imperative::OpBase>);
646
- REGISTER_OPERATOR (conv2d_transpose_grad_grad, ops::ConvTransposeOpDoubleGrad);
647
-
648
- REGISTER_OP_CPU_KERNEL (
649
- conv2d_transpose,
650
- ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float >,
651
- ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double >);
652
- REGISTER_OP_CPU_KERNEL (
653
- conv2d_transpose_grad,
654
- ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float >,
655
- ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
656
- double >);
466
+ ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>,
467
+ Conv2dTranposeInferShapeFunctor);
468
+ REGISTER_OPERATOR (conv2d_transpose_grad, ops::ConvTransposeOpGrad,
469
+ ops::ConvTransposeDoubleGradMaker<paddle::framework::OpDesc>,
470
+ ops::ConvTransposeDoubleGradMaker<paddle::imperative::OpBase>,
471
+ Conv2dTranposeGradInferShapeFunctor);
472
+ REGISTER_OPERATOR (conv2d_transpose_grad_grad, ops::ConvTransposeOpDoubleGrad,
473
+ Conv2dTranposeDoubleGradInferShapeFunctor);
657
474
658
475
// conv3d_transpose
476
+ DECLARE_INFER_SHAPE_FUNCTOR (conv3d_transpose, Conv3dTranposeInferShapeFunctor,
477
+ PD_INFER_META (phi::ConvTransposeInferMeta));
478
+ DECLARE_INFER_SHAPE_FUNCTOR (conv3d_transpose_grad,
479
+ Conv3dTranposeGradInferShapeFunctor,
480
+ PD_INFER_META (phi::ConvTransposeGradInferMeta));
481
+
659
482
REGISTER_OPERATOR (conv3d_transpose, ops::ConvTransposeOp,
660
483
ops::Conv3DTransposeOpMaker,
661
484
ops::ConvTransposeGradOpMaker<paddle::framework::OpDesc>,
662
- ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>);
663
- REGISTER_OPERATOR (conv3d_transpose_grad, ops::ConvTransposeOpGrad);
664
-
665
- REGISTER_OP_CPU_KERNEL (
666
- conv3d_transpose,
667
- ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float >,
668
- ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double >);
669
- REGISTER_OP_CPU_KERNEL (
670
- conv3d_transpose_grad,
671
- ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float >,
672
- ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
673
- double >);
485
+ ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>,
486
+ Conv3dTranposeInferShapeFunctor);
487
+ REGISTER_OPERATOR (conv3d_transpose_grad, ops::ConvTransposeOpGrad,
488
+ Conv3dTranposeGradInferShapeFunctor);
674
489
675
490
// depthwise conv2d_transpose
491
+ DECLARE_INFER_SHAPE_FUNCTOR (depthwise_conv2d_transpose,
492
+ DepthWiseConv2dTranposeInferShapeFunctor,
493
+ PD_INFER_META (phi::ConvTransposeInferMeta));
494
+ DECLARE_INFER_SHAPE_FUNCTOR (depthwise_conv2d_transpose_grad,
495
+ DepthWiseConv2dTranposeGradInferShapeFunctor,
496
+ PD_INFER_META (phi::ConvTransposeGradInferMeta));
497
+
676
498
REGISTER_OPERATOR (depthwise_conv2d_transpose, ops::ConvTransposeOp,
677
499
ops::Conv2DTransposeOpMaker,
678
500
ops::ConvTransposeGradOpMaker<paddle::framework::OpDesc>,
679
- ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>);
680
- REGISTER_OPERATOR (depthwise_conv2d_transpose_grad, ops::ConvTransposeOpGrad);
681
-
682
- REGISTER_OP_CPU_KERNEL (
683
- depthwise_conv2d_transpose,
684
- ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float >,
685
- ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double >);
686
- REGISTER_OP_CPU_KERNEL (
687
- depthwise_conv2d_transpose_grad,
688
- ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float >,
689
- ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
690
- double >);
501
+ ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>,
502
+ DepthWiseConv2dTranposeInferShapeFunctor);
503
+ REGISTER_OPERATOR (depthwise_conv2d_transpose_grad, ops::ConvTransposeOpGrad,
504
+ DepthWiseConv2dTranposeGradInferShapeFunctor);
691
505
692
506
REGISTER_OP_VERSION (conv_transpose)
693
507
.AddCheckpoint(
0 commit comments