Skip to content

Commit c243a39

Browse files
remove raw pad3d infershape function
1 parent a209149 commit c243a39

File tree

1 file changed

+0
-70
lines changed

1 file changed

+0
-70
lines changed

paddle/fluid/operators/pad3d_op.cc

Lines changed: 0 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -30,76 +30,6 @@ class Pad3dOp : public framework::OperatorWithKernel {
3030
public:
3131
using framework::OperatorWithKernel::OperatorWithKernel;
3232

33-
void InferShape(framework::InferShapeContext* ctx) const override {
34-
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Pad3d");
35-
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Pad3d");
36-
37-
auto x_dim = ctx->GetInputDim("X");
38-
PADDLE_ENFORCE_EQ(x_dim.size(), 5,
39-
platform::errors::InvalidArgument(
40-
"The size of Input(X)'s dimension should be equal to "
41-
"5, but received %d. ",
42-
x_dim.size()));
43-
44-
std::vector<int64_t> out_dims(x_dim.size());
45-
auto data_format = ctx->Attrs().Get<std::string>("data_format");
46-
out_dims[0] = x_dim[0];
47-
if (ctx->HasInput("Paddings")) {
48-
auto paddings_dim = ctx->GetInputDim("Paddings");
49-
PADDLE_ENFORCE_EQ(paddings_dim.size(), 1,
50-
platform::errors::InvalidArgument(
51-
"Size of Input(Paddings)'s dimension should be "
52-
"equal to 1, but received %d.",
53-
paddings_dim.size()));
54-
if (ctx->IsRuntime()) {
55-
PADDLE_ENFORCE_EQ(paddings_dim[0], 6,
56-
platform::errors::InvalidArgument(
57-
"Shape of Input(Paddings) should be equal to "
58-
"[6], but received [%d].",
59-
paddings_dim[0]));
60-
}
61-
out_dims[1] = x_dim[1];
62-
out_dims[2] = x_dim[2];
63-
out_dims[3] = x_dim[3];
64-
} else {
65-
auto paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
66-
PADDLE_ENFORCE_EQ(
67-
paddings.size(), 6,
68-
platform::errors::InvalidArgument(
69-
"Size of paddings should be equal to 4, but received %d.",
70-
static_cast<int>(paddings.size())));
71-
if (data_format == "NCDHW") {
72-
out_dims[1] = x_dim[1]; // channel
73-
out_dims[2] = ((!ctx->IsRuntime()) && (x_dim[2] < 0))
74-
? x_dim[2]
75-
: (x_dim[2] + paddings[4] + paddings[5]); // depth
76-
77-
out_dims[3] = ((!ctx->IsRuntime()) && (x_dim[3] < 0))
78-
? x_dim[3]
79-
: (x_dim[3] + paddings[2] + paddings[3]); // height
80-
81-
out_dims[4] = ((!ctx->IsRuntime()) && (x_dim[4] < 0))
82-
? x_dim[4]
83-
: (x_dim[4] + paddings[0] + paddings[1]); // width
84-
} else { // NDHWC
85-
out_dims[4] = x_dim[4]; // channel
86-
87-
out_dims[1] = ((!ctx->IsRuntime()) && (x_dim[1] < 0))
88-
? x_dim[1]
89-
: (x_dim[1] + paddings[4] + paddings[5]); // depth
90-
out_dims[2] = ((!ctx->IsRuntime()) && (x_dim[2] < 0))
91-
? x_dim[2]
92-
: (x_dim[2] + paddings[2] + paddings[3]); // height
93-
out_dims[3] = ((!ctx->IsRuntime()) && (x_dim[3] < 0))
94-
? x_dim[3]
95-
: (x_dim[3] + paddings[0] + paddings[1]); // width
96-
}
97-
}
98-
99-
ctx->SetOutputDim("Out", phi::make_ddim(out_dims));
100-
ctx->ShareLoD("X", /*->*/ "Out");
101-
}
102-
10333
protected:
10434
framework::OpKernelType GetExpectedKernelType(
10535
const framework::ExecutionContext& ctx) const override {

0 commit comments

Comments
 (0)