Skip to content

Commit 1c20588

Browse files
authored
move eye, lerp infershape to phi (#40105)
1 parent 167d511 commit 1c20588

File tree

9 files changed

+75
-65
lines changed

9 files changed

+75
-65
lines changed

paddle/fluid/operators/eye_op.cc

+7-19
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15+
#include "paddle/fluid/framework/infershape_utils.h"
1516
#include "paddle/fluid/framework/op_registry.h"
17+
#include "paddle/phi/core/infermeta_utils.h"
18+
#include "paddle/phi/infermeta/nullary.h"
1619

1720
namespace paddle {
1821
namespace operators {
@@ -21,24 +24,6 @@ class EyeOp : public framework::OperatorWithKernel {
2124
public:
2225
using framework::OperatorWithKernel::OperatorWithKernel;
2326

24-
void InferShape(framework::InferShapeContext* ctx) const override {
25-
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
26-
platform::errors::InvalidArgument(
27-
"Output(Out) of EyeOP should not be null."));
28-
auto num_rows = ctx->Attrs().Get<int64_t>("num_rows");
29-
PADDLE_ENFORCE_EQ(
30-
num_rows >= 0, true,
31-
platform::errors::InvalidArgument(
32-
"The value of Input(num_rows) should be non-negative int."));
33-
auto num_columns = ctx->Attrs().Get<int64_t>("num_columns");
34-
if (num_columns == -1) num_columns = num_rows;
35-
PADDLE_ENFORCE_EQ(
36-
num_columns >= 0, true,
37-
platform::errors::InvalidArgument(
38-
"The value of Input(num_columns) should be non-negative int."));
39-
ctx->SetOutputDim("Out", {num_rows, num_columns});
40-
}
41-
4227
protected:
4328
framework::OpKernelType GetExpectedKernelType(
4429
const framework::ExecutionContext& ctx) const override {
@@ -82,8 +67,11 @@ Return an identity tensor whose shape is [num_rows, num_columns].
8267
} // namespace paddle
8368

8469
namespace ops = paddle::operators;
70+
DELCARE_INFER_SHAPE_FUNCTOR(eye, EyeInferShapeFunctor,
71+
PT_INFER_META(phi::EyeInferMeta));
8572

8673
REGISTER_OPERATOR(
8774
eye, ops::EyeOp, ops::EyeOpMaker, ops::EyeOpVarTypeInference,
8875
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
89-
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
76+
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
77+
EyeInferShapeFunctor);

paddle/fluid/operators/lerp_op.cc

+6-44
Original file line numberDiff line numberDiff line change
@@ -12,57 +12,17 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
#include "paddle/fluid/framework/infershape_utils.h"
1516
#include "paddle/fluid/framework/op_registry.h"
17+
#include "paddle/phi/core/infermeta_utils.h"
18+
#include "paddle/phi/infermeta/ternary.h"
1619

1720
namespace paddle {
1821
namespace operators {
1922

2023
class LerpOp : public framework::OperatorWithKernel {
2124
public:
2225
using framework::OperatorWithKernel::OperatorWithKernel;
23-
24-
void InferShape(framework::InferShapeContext* ctx) const override {
25-
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "lerp");
26-
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "lerp");
27-
OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "lerp");
28-
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "lerp");
29-
30-
auto x_dims = ctx->GetInputDim("X");
31-
auto y_dims = ctx->GetInputDim("Y");
32-
auto w_dims = ctx->GetInputDim("Weight");
33-
framework::DDim out_dims;
34-
out_dims = GetOutputDims(x_dims, y_dims);
35-
if (w_dims.size() > 1 || w_dims[0] != 1) {
36-
out_dims = GetOutputDims(out_dims, w_dims);
37-
}
38-
39-
ctx->SetOutputDim("Out", out_dims);
40-
ctx->ShareLoD("X", /*->*/ "Out");
41-
}
42-
43-
private:
44-
framework::DDim GetOutputDims(const framework::DDim& s_dims,
45-
const framework::DDim& l_dims) const {
46-
if (s_dims.size() > l_dims.size()) {
47-
return GetOutputDims(l_dims, s_dims);
48-
}
49-
std::vector<int64_t> shapes = phi::vectorize<int64_t>(l_dims);
50-
for (int i = s_dims.size() - 1, j = l_dims.size() - 1; i >= 0; --i, --j) {
51-
int64_t s = s_dims[i];
52-
int64_t l = l_dims[j];
53-
if (s != l) {
54-
if (l == 1) {
55-
shapes[j] = s;
56-
} else if (s != 1) {
57-
PADDLE_THROW(platform::errors::InvalidArgument(
58-
"The shape of tensor a %s:%d must match shape of tensor b "
59-
"%s:%d.",
60-
s_dims.to_str(), i, l_dims.to_str(), j));
61-
}
62-
}
63-
}
64-
return phi::make_ddim(shapes);
65-
}
6626
};
6727

6828
class LerpOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -125,10 +85,12 @@ DECLARE_INPLACE_OP_INFERER(LerpInplaceInferer, {"X", "Out"});
12585
} // namespace operators
12686
} // namespace paddle
12787

88+
DELCARE_INFER_SHAPE_FUNCTOR(lerp, LerpInferShapeFunctor,
89+
PT_INFER_META(phi::LerpInferMeta));
12890
REGISTER_OPERATOR(
12991
lerp, paddle::operators::LerpOp, paddle::operators::LerpOpMaker,
13092
paddle::operators::LerpOpGradMaker<paddle::framework::OpDesc>,
13193
paddle::operators::LerpOpGradMaker<paddle::imperative::OpBase>,
132-
paddle::operators::LerpInplaceInferer);
94+
paddle::operators::LerpInplaceInferer, LerpInferShapeFunctor);
13395

13496
REGISTER_OPERATOR(lerp_grad, paddle::operators::LerpGradOp);

paddle/phi/infermeta/nullary.cc

+8
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,12 @@ void CreateInferMeta(const ScalarArray& shape,
3232
CreateInferMetaBase(shape.GetData(), dtype, DataLayout::NCHW, out);
3333
}
3434

35+
void EyeInferMeta(int64_t num_rows,
36+
int64_t num_columns,
37+
DataType dtype,
38+
MetaTensor* out) {
39+
if (num_columns == -1) num_columns = num_rows;
40+
out->set_dims({num_rows, num_columns});
41+
out->set_dtype(dtype);
42+
}
3543
} // namespace phi

paddle/phi/infermeta/nullary.h

+5
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,9 @@ void CreateInferMetaBase(const std::vector<int64_t>& shape,
3535

3636
void CreateInferMeta(const ScalarArray& shape, DataType dtype, MetaTensor* out);
3737

38+
void EyeInferMeta(int64_t num_rows,
39+
int64_t num_columns,
40+
DataType dtype,
41+
MetaTensor* out);
42+
3843
} // namespace phi

paddle/phi/infermeta/ternary.cc

+17
Original file line numberDiff line numberDiff line change
@@ -89,4 +89,21 @@ void AddmmInferMeta(const MetaTensor& input,
8989
out->set_dtype(input.dtype());
9090
}
9191

92+
void LerpInferMeta(const MetaTensor& x,
93+
const MetaTensor& y,
94+
const MetaTensor& weight,
95+
MetaTensor* out) {
96+
auto x_dims = x.dims();
97+
auto y_dims = y.dims();
98+
auto w_dims = weight.dims();
99+
DDim out_dims;
100+
out_dims = funcs::GetOutputDims(x_dims, y_dims);
101+
if (w_dims.size() > 1 || w_dims[0] != 1) {
102+
out_dims = funcs::GetOutputDims(out_dims, w_dims);
103+
}
104+
out->set_dims(out_dims);
105+
out->set_dtype(x.dtype());
106+
out->share_lod(x);
107+
}
108+
92109
} // namespace phi

paddle/phi/infermeta/ternary.h

+5
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,9 @@ void AddmmInferMeta(const MetaTensor& input,
3737
float beta,
3838
MetaTensor* out);
3939

40+
void LerpInferMeta(const MetaTensor& x,
41+
const MetaTensor& y,
42+
const MetaTensor& weight,
43+
MetaTensor* out);
44+
4045
} // namespace phi

paddle/phi/kernels/eye_kernel.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ template <typename T, typename Context>
2222
void EyeKernel(const Context& ctx,
2323
int64_t num_rows,
2424
int64_t num_columns,
25-
int dtype,
25+
DataType dtype,
2626
DenseTensor* out);
2727

2828
} // namespace phi

paddle/phi/kernels/funcs/common_shape.h

+25
Original file line numberDiff line numberDiff line change
@@ -140,5 +140,30 @@ inline bool CheckDims(const DDim &dims_x, const DDim &dims_y) {
140140
return true;
141141
}
142142

143+
inline DDim GetOutputDims(const DDim &s_dims, const DDim &l_dims) {
144+
if (s_dims.size() > l_dims.size()) {
145+
return GetOutputDims(l_dims, s_dims);
146+
}
147+
std::vector<int64_t> shapes = phi::vectorize<int64_t>(l_dims);
148+
for (int i = s_dims.size() - 1, j = l_dims.size() - 1; i >= 0; --i, --j) {
149+
int64_t s = s_dims[i];
150+
int64_t l = l_dims[j];
151+
if (s != l) {
152+
if (l == 1) {
153+
shapes[j] = s;
154+
} else if (s != 1) {
155+
PADDLE_THROW(errors::InvalidArgument(
156+
"The shape of tensor a %s:%d must match shape of tensor b "
157+
"%s:%d.",
158+
s_dims.to_str(),
159+
i,
160+
l_dims.to_str(),
161+
j));
162+
}
163+
}
164+
}
165+
return phi::make_ddim(shapes);
166+
}
167+
143168
} // namespace funcs
144169
} // namespace phi

paddle/phi/kernels/impl/eye_kernel_impl.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ template <typename T, typename Context>
3636
void EyeKernel(const Context& ctx,
3737
int64_t num_rows,
3838
int64_t num_columns,
39-
int dtype,
39+
DataType dtype,
4040
DenseTensor* out) {
4141
auto num = num_columns;
4242
if (num == -1) {

0 commit comments

Comments
 (0)