Skip to content

Commit 87b903a

Browse files
authored
[phi]migrate increment addmm multinomial cholesky InferShapes to phi (#39913)
* [phi]migrate increment addmm multinomial cholesky InferShapes to phi * set_dtype and mod MultinomialFunctor
1 parent 37cb6f3 commit 87b903a

16 files changed

+383
-275
lines changed

paddle/fluid/operators/addmm_op.cc

Lines changed: 7 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ limitations under the License. */
1616
#include <string>
1717
#include <unordered_map>
1818
#include <vector>
19+
#include "paddle/fluid/framework/infershape_utils.h"
1920
#include "paddle/fluid/framework/op_registry.h"
21+
#include "paddle/phi/core/infermeta_utils.h"
22+
#include "paddle/phi/infermeta/ternary.h"
2023
#ifdef PADDLE_WITH_MKLDNN
2124
#include "paddle/fluid/platform/mkldnn_helper.h"
2225
#endif
@@ -33,85 +36,6 @@ class AddMMOp : public framework::OperatorWithKernel {
3336
public:
3437
using framework::OperatorWithKernel::OperatorWithKernel;
3538

36-
void InferShape(framework::InferShapeContext* ctx) const override {
37-
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
38-
platform::errors::NotFound(
39-
"Input(Input) of AddMMOp should not be null."));
40-
PADDLE_ENFORCE_EQ(
41-
ctx->HasInput("X"), true,
42-
platform::errors::NotFound("Input(X) of AddMMOp should not be null."));
43-
PADDLE_ENFORCE_EQ(
44-
ctx->HasInput("Y"), true,
45-
platform::errors::NotFound("Input(Y) of AddMMOp should not be null."));
46-
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
47-
platform::errors::NotFound(
48-
"Output(Out) of AddMMOp should not be null."));
49-
50-
auto input_dims = ctx->GetInputDim("Input");
51-
auto x_dims = ctx->GetInputDim("X");
52-
auto y_dims = ctx->GetInputDim("Y");
53-
54-
auto ndim_input = input_dims.size();
55-
auto ndim_x = x_dims.size();
56-
auto ndim_y = y_dims.size();
57-
58-
float alpha = ctx->Attrs().Get<float>("Alpha");
59-
float beta = ctx->Attrs().Get<float>("Beta");
60-
61-
VLOG(3) << "addmm operator input.shape=" << input_dims
62-
<< " x.shape=" << x_dims << " y.shape=" << y_dims
63-
<< " beta=" << beta << " alpha=" << alpha
64-
<< " ndim_input=" << ndim_input << " ndim_x=" << ndim_x
65-
<< " ndim_y=" << ndim_y;
66-
67-
PADDLE_ENFORCE_NE(phi::product(input_dims), 0,
68-
platform::errors::PreconditionNotMet(
69-
"The Input variable Input(%s) has not "
70-
"been initialized. You may need to confirm "
71-
"if you put exe.run(startup_program) "
72-
"after optimizer.minimize function.",
73-
ctx->Inputs("Input").front()));
74-
75-
PADDLE_ENFORCE_NE(phi::product(x_dims), 0,
76-
platform::errors::PreconditionNotMet(
77-
"The Input variable X(%s) has not "
78-
"been initialized. You may need to confirm "
79-
"if you put exe.run(startup_program) "
80-
"after optimizer.minimize function.",
81-
ctx->Inputs("X").front()));
82-
83-
PADDLE_ENFORCE_NE(phi::product(y_dims), 0,
84-
platform::errors::PreconditionNotMet(
85-
"The Input variable Y(%s) has not "
86-
"been initialized. You may need to confirm "
87-
"if you put exe.run(startup_program) "
88-
"after optimizer.minimize function.",
89-
ctx->Inputs("Y").front()));
90-
// dim check
91-
PADDLE_ENFORCE_EQ(ndim_input, 2,
92-
platform::errors::InvalidArgument(
93-
"The input tensor input's dimension must be 2. "
94-
"But received input's dimension = [%s].",
95-
ndim_input));
96-
PADDLE_ENFORCE_EQ(ndim_x, 2,
97-
platform::errors::InvalidArgument(
98-
"The input tensor x's dimension must be 2. "
99-
"But received x's dimension = [%s].",
100-
ndim_x));
101-
PADDLE_ENFORCE_EQ(ndim_y, 2,
102-
platform::errors::InvalidArgument(
103-
"The input tensor y's dimension must be 2. "
104-
"But received y's dimension = [%s].",
105-
ndim_y));
106-
107-
std::vector<int64_t> output_dims;
108-
output_dims.push_back(x_dims[0]);
109-
output_dims.push_back(y_dims[1]);
110-
111-
ctx->SetOutputDim("Out", phi::make_ddim(output_dims));
112-
ctx->ShareLoD("Input", /*->*/ "Out");
113-
}
114-
11539
framework::OpKernelType GetExpectedKernelType(
11640
const framework::ExecutionContext& ctx) const {
11741
framework::LibraryType library = framework::LibraryType::kPlain;
@@ -223,9 +147,11 @@ class AddMMOpGradMaker : public framework::SingleGradOpMaker<T> {
223147
} // namespace paddle
224148

225149
namespace ops = paddle::operators;
226-
150+
DELCARE_INFER_SHAPE_FUNCTOR(addmm, AddmmInferShapeFunctor,
151+
PT_INFER_META(phi::AddmmInferMeta));
227152
REGISTER_OPERATOR(addmm, ops::AddMMOp, ops::AddMMOpMaker,
228153
ops::AddMMOpGradMaker<paddle::framework::OpDesc>,
229-
ops::AddMMOpGradMaker<paddle::imperative::OpBase>);
154+
ops::AddMMOpGradMaker<paddle::imperative::OpBase>,
155+
AddmmInferShapeFunctor);
230156

231157
REGISTER_OPERATOR(addmm_grad, ops::AddMMGradOp);

paddle/fluid/operators/cholesky_op.cc

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15+
#include "paddle/fluid/framework/infershape_utils.h"
1516
#include "paddle/fluid/framework/op_registry.h"
17+
#include "paddle/phi/core/infermeta_utils.h"
18+
#include "paddle/phi/infermeta/unary.h"
1619

1720
namespace paddle {
1821
namespace operators {
@@ -23,26 +26,6 @@ using framework::Tensor;
2326
class CholeskyOp : public framework::OperatorWithKernel {
2427
public:
2528
using framework::OperatorWithKernel::OperatorWithKernel;
26-
27-
void InferShape(framework::InferShapeContext* ctx) const override {
28-
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Cholesky");
29-
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Cholesky");
30-
auto dims = ctx->GetInputDim("X");
31-
auto rank = dims.size();
32-
PADDLE_ENFORCE_GE(rank, 2,
33-
platform::errors::InvalidArgument(
34-
"The Input(X) should have at least 2 dimensions. But "
35-
"received a %d dimension tensor.",
36-
rank));
37-
PADDLE_ENFORCE_EQ(
38-
dims[rank - 2], dims[rank - 1],
39-
platform::errors::InvalidArgument(
40-
"The inner-most 2 dimensions of Input(X) all should be symmetric "
41-
"positive-definite matrices and have the same size. But received "
42-
"X's shape[-2] = %d and shape[-1] = %d.",
43-
dims[rank - 2], dims[rank - 1]));
44-
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
45-
}
4629
};
4730

4831
class CholeskyOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -107,7 +90,10 @@ class CholeskyGradOpMaker : public framework::SingleGradOpMaker<T> {
10790
} // namespace paddle
10891

10992
namespace ops = paddle::operators;
93+
DELCARE_INFER_SHAPE_FUNCTOR(cholesky, CholeskyInferShapeFunctor,
94+
PT_INFER_META(phi::CholeskyInferMeta));
11095
REGISTER_OPERATOR(cholesky, ops::CholeskyOp, ops::CholeskyOpMaker,
11196
ops::CholeskyGradOpMaker<paddle::framework::OpDesc>,
112-
ops::CholeskyGradOpMaker<paddle::imperative::OpBase>);
97+
ops::CholeskyGradOpMaker<paddle::imperative::OpBase>,
98+
CholeskyInferShapeFunctor);
11399
REGISTER_OPERATOR(cholesky_grad, ops::CholeskyGradOp);

paddle/fluid/operators/increment_op.cc

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
#include "paddle/fluid/framework/infershape_utils.h"
1516
#include "paddle/fluid/framework/op_registry.h"
17+
#include "paddle/phi/core/infermeta_utils.h"
18+
#include "paddle/phi/infermeta/unary.h"
1619

1720
namespace paddle {
1821
namespace framework {
@@ -37,18 +40,6 @@ class IncrementOp : public framework::OperatorWithKernel {
3740
const framework::AttributeMap &attrs)
3841
: OperatorWithKernel(type, inputs, outputs, attrs) {}
3942

40-
void InferShape(framework::InferShapeContext *ctx) const override {
41-
PADDLE_ENFORCE_EQ(phi::product(ctx->GetInputDim("X")), 1UL,
42-
platform::errors::InvalidArgument(
43-
"The number of elements in Input(X) should be 1."
44-
"Now the number is %d.",
45-
phi::product(ctx->GetInputDim("X"))));
46-
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "increment");
47-
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "increment");
48-
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
49-
ctx->ShareLoD("X", "Out");
50-
}
51-
5243
protected:
5344
framework::OpKernelType GetExpectedKernelType(
5445
const framework::ExecutionContext &ctx) const override {
@@ -96,6 +87,9 @@ class IncrementGradOpMaker : public framework::SingleGradOpMaker<T> {
9687
} // namespace paddle
9788

9889
namespace ops = paddle::operators;
90+
DELCARE_INFER_SHAPE_FUNCTOR(increment, IncrementInferShapeFunctor,
91+
PT_INFER_META(phi::IncrementInferMeta));
9992
REGISTER_OPERATOR(increment, ops::IncrementOp, ops::IncrementOpMaker,
10093
ops::IncrementGradOpMaker<paddle::framework::OpDesc>,
101-
ops::IncrementGradOpMaker<paddle::imperative::OpBase>);
94+
ops::IncrementGradOpMaker<paddle::imperative::OpBase>,
95+
IncrementInferShapeFunctor);

paddle/fluid/operators/multinomial_op.cc

Lines changed: 7 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@ limitations under the License. */
1616
#include <string>
1717
#include <vector>
1818

19-
#include "paddle/fluid/framework/generator.h"
19+
#include "paddle/fluid/framework/infershape_utils.h"
2020
#include "paddle/fluid/framework/op_registry.h"
2121
#include "paddle/fluid/framework/operator.h"
22-
#include "paddle/fluid/operators/common_infer_shape_functions.h"
22+
#include "paddle/phi/core/infermeta_utils.h"
23+
#include "paddle/phi/infermeta/unary.h"
2324

2425
namespace paddle {
2526
namespace operators {
@@ -45,46 +46,17 @@ This OP returns a Tensor filled with the sampled categoris according to Multinom
4546
class MultinomialOp : public framework::OperatorWithKernel {
4647
public:
4748
using framework::OperatorWithKernel::OperatorWithKernel;
48-
49-
void InferShape(framework::InferShapeContext *ctx) const override {
50-
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Multinomial");
51-
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Multinomial");
52-
53-
auto x_dim = ctx->GetInputDim("X");
54-
int64_t x_rank = x_dim.size();
55-
PADDLE_ENFORCE_GT(x_rank, 0,
56-
platform::errors::InvalidArgument(
57-
"The number of dimensions of the input probability "
58-
"distribution should be > 0, but got %d.",
59-
x_rank));
60-
PADDLE_ENFORCE_LE(x_rank, 2,
61-
platform::errors::InvalidArgument(
62-
"The number of dimensions of the input probability "
63-
"distribution should be <= 2, but got %d.",
64-
x_rank));
65-
66-
std::vector<int64_t> out_dims(x_rank);
67-
for (int64_t i = 0; i < x_rank - 1; i++) {
68-
out_dims[i] = x_dim[i];
69-
}
70-
71-
int64_t num_samples = ctx->Attrs().Get<int>("num_samples");
72-
PADDLE_ENFORCE_GT(
73-
num_samples, 0,
74-
platform::errors::InvalidArgument(
75-
"The number of samples should be > 0, but got %d.", num_samples));
76-
out_dims[x_rank - 1] = num_samples;
77-
78-
ctx->SetOutputDim("Out", phi::make_ddim(out_dims));
79-
}
8049
};
8150

8251
} // namespace operators
8352
} // namespace paddle
8453

8554
namespace ops = paddle::operators;
8655
namespace plat = paddle::platform;
56+
DELCARE_INFER_SHAPE_FUNCTOR(multinomial, MultinomialInferShapeFunctor,
57+
PT_INFER_META(phi::MultinomialInferMeta));
8758
REGISTER_OPERATOR(
8859
multinomial, ops::MultinomialOp, ops::MultinomialOpMaker,
8960
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
90-
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
61+
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
62+
MultinomialInferShapeFunctor);

paddle/phi/infermeta/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
cc_library(infermeta SRCS nullary.cc unary.cc binary.cc multiary.cc DEPS convert_utils meta_tensor infermeta_utils)
1+
cc_library(infermeta SRCS nullary.cc unary.cc binary.cc ternary.cc multiary.cc DEPS convert_utils meta_tensor infermeta_utils)
22
cc_library(backward_infermeta SRCS backward.cc DEPS meta_tensor convert_utils)

paddle/phi/infermeta/ternary.cc

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/phi/infermeta/ternary.h"
16+
#include "paddle/phi/core/ddim.h"
17+
#include "paddle/phi/kernels/funcs/common_shape.h"
18+
19+
namespace phi {
20+
21+
void AddmmInferMeta(const MetaTensor& input,
22+
const MetaTensor& x,
23+
const MetaTensor& y,
24+
float alpha,
25+
float beta,
26+
MetaTensor* out) {
27+
auto input_dims = input.dims();
28+
auto x_dims = x.dims();
29+
auto y_dims = y.dims();
30+
31+
auto ndim_input = input_dims.size();
32+
auto ndim_x = x_dims.size();
33+
auto ndim_y = y_dims.size();
34+
35+
VLOG(3) << "addmm operator input.shape=" << input_dims
36+
<< " x.shape=" << x_dims << " y.shape=" << y_dims << " beta=" << beta
37+
<< " alpha=" << alpha << " ndim_input=" << ndim_input
38+
<< " ndim_x=" << ndim_x << " ndim_y=" << ndim_y;
39+
40+
PADDLE_ENFORCE_NE(
41+
product(input_dims),
42+
0,
43+
errors::PreconditionNotMet("The Input variable 'input' has not "
44+
"been initialized. You may need to confirm "
45+
"if you put exe.run(startup_program) "
46+
"after optimizer.minimize function."));
47+
48+
PADDLE_ENFORCE_NE(
49+
product(x_dims),
50+
0,
51+
errors::PreconditionNotMet("The Input variable 'x' has not "
52+
"been initialized. You may need to confirm "
53+
"if you put exe.run(startup_program) "
54+
"after optimizer.minimize function."));
55+
56+
PADDLE_ENFORCE_NE(
57+
product(y_dims),
58+
0,
59+
errors::PreconditionNotMet("The Input variable 'y' has not "
60+
"been initialized. You may need to confirm "
61+
"if you put exe.run(startup_program) "
62+
"after optimizer.minimize function."));
63+
// dim check
64+
PADDLE_ENFORCE_EQ(
65+
ndim_input,
66+
2,
67+
errors::InvalidArgument("The input tensor input's dimension must be 2. "
68+
"But received input's dimension = [%s].",
69+
ndim_input));
70+
PADDLE_ENFORCE_EQ(
71+
ndim_x,
72+
2,
73+
errors::InvalidArgument("The input tensor x's dimension must be 2. "
74+
"But received x's dimension = [%s].",
75+
ndim_x));
76+
PADDLE_ENFORCE_EQ(
77+
ndim_y,
78+
2,
79+
errors::InvalidArgument("The input tensor y's dimension must be 2. "
80+
"But received y's dimension = [%s].",
81+
ndim_y));
82+
83+
std::vector<int64_t> output_dims;
84+
output_dims.push_back(x_dims[0]);
85+
output_dims.push_back(y_dims[1]);
86+
87+
out->set_dims(make_ddim(output_dims));
88+
out->share_lod(input);
89+
out->set_dtype(input.dtype());
90+
}
91+
92+
} // namespace phi

0 commit comments

Comments
 (0)