@@ -30,76 +30,6 @@ class Pad3dOp : public framework::OperatorWithKernel {
30
30
public:
31
31
using framework::OperatorWithKernel::OperatorWithKernel;
32
32
33
- void InferShape (framework::InferShapeContext* ctx) const override {
34
- OP_INOUT_CHECK (ctx->HasInput (" X" ), " Input" , " X" , " Pad3d" );
35
- OP_INOUT_CHECK (ctx->HasOutput (" Out" ), " Output" , " Out" , " Pad3d" );
36
-
37
- auto x_dim = ctx->GetInputDim (" X" );
38
- PADDLE_ENFORCE_EQ (x_dim.size (), 5 ,
39
- platform::errors::InvalidArgument (
40
- " The size of Input(X)'s dimension should be equal to "
41
- " 5, but received %d. " ,
42
- x_dim.size ()));
43
-
44
- std::vector<int64_t > out_dims (x_dim.size ());
45
- auto data_format = ctx->Attrs ().Get <std::string>(" data_format" );
46
- out_dims[0 ] = x_dim[0 ];
47
- if (ctx->HasInput (" Paddings" )) {
48
- auto paddings_dim = ctx->GetInputDim (" Paddings" );
49
- PADDLE_ENFORCE_EQ (paddings_dim.size (), 1 ,
50
- platform::errors::InvalidArgument (
51
- " Size of Input(Paddings)'s dimension should be "
52
- " equal to 1, but received %d." ,
53
- paddings_dim.size ()));
54
- if (ctx->IsRuntime ()) {
55
- PADDLE_ENFORCE_EQ (paddings_dim[0 ], 6 ,
56
- platform::errors::InvalidArgument (
57
- " Shape of Input(Paddings) should be equal to "
58
- " [6], but received [%d]." ,
59
- paddings_dim[0 ]));
60
- }
61
- out_dims[1 ] = x_dim[1 ];
62
- out_dims[2 ] = x_dim[2 ];
63
- out_dims[3 ] = x_dim[3 ];
64
- } else {
65
- auto paddings = ctx->Attrs ().Get <std::vector<int >>(" paddings" );
66
- PADDLE_ENFORCE_EQ (
67
- paddings.size (), 6 ,
68
- platform::errors::InvalidArgument (
69
- " Size of paddings should be equal to 4, but received %d." ,
70
- static_cast <int >(paddings.size ())));
71
- if (data_format == " NCDHW" ) {
72
- out_dims[1 ] = x_dim[1 ]; // channel
73
- out_dims[2 ] = ((!ctx->IsRuntime ()) && (x_dim[2 ] < 0 ))
74
- ? x_dim[2 ]
75
- : (x_dim[2 ] + paddings[4 ] + paddings[5 ]); // depth
76
-
77
- out_dims[3 ] = ((!ctx->IsRuntime ()) && (x_dim[3 ] < 0 ))
78
- ? x_dim[3 ]
79
- : (x_dim[3 ] + paddings[2 ] + paddings[3 ]); // height
80
-
81
- out_dims[4 ] = ((!ctx->IsRuntime ()) && (x_dim[4 ] < 0 ))
82
- ? x_dim[4 ]
83
- : (x_dim[4 ] + paddings[0 ] + paddings[1 ]); // width
84
- } else { // NDHWC
85
- out_dims[4 ] = x_dim[4 ]; // channel
86
-
87
- out_dims[1 ] = ((!ctx->IsRuntime ()) && (x_dim[1 ] < 0 ))
88
- ? x_dim[1 ]
89
- : (x_dim[1 ] + paddings[4 ] + paddings[5 ]); // depth
90
- out_dims[2 ] = ((!ctx->IsRuntime ()) && (x_dim[2 ] < 0 ))
91
- ? x_dim[2 ]
92
- : (x_dim[2 ] + paddings[2 ] + paddings[3 ]); // height
93
- out_dims[3 ] = ((!ctx->IsRuntime ()) && (x_dim[3 ] < 0 ))
94
- ? x_dim[3 ]
95
- : (x_dim[3 ] + paddings[0 ] + paddings[1 ]); // width
96
- }
97
- }
98
-
99
- ctx->SetOutputDim (" Out" , phi::make_ddim (out_dims));
100
- ctx->ShareLoD (" X" , /* ->*/ " Out" );
101
- }
102
-
103
33
protected:
104
34
framework::OpKernelType GetExpectedKernelType (
105
35
const framework::ExecutionContext& ctx) const override {
0 commit comments