1616
1717#include  < vector> 
1818
19+ #include  " paddle/fluid/framework/infershape_utils.h" 
20+ #include  " paddle/phi/core/infermeta_utils.h" 
21+ #include  " paddle/phi/infermeta/backward.h" 
22+ #include  " paddle/phi/infermeta/binary.h" 
23+ 
1924namespace  paddle  {
2025namespace  operators  {
2126
@@ -24,62 +29,6 @@ class BmmOp : public framework::OperatorWithKernel {
2429  using  framework::OperatorWithKernel::OperatorWithKernel;
2530
2631 protected: 
27-   void  InferShape (framework::InferShapeContext* ctx) const  override  {
28-     PADDLE_ENFORCE_EQ (
29-         ctx->HasInput (" X" 
30-         true ,
31-         platform::errors::NotFound (" Input(X) of BmmOp should not be null" 
32-     PADDLE_ENFORCE_EQ (
33-         ctx->HasInput (" Y" 
34-         true ,
35-         platform::errors::NotFound (" Input(Y) of BmmOp should not be null" 
36-     PADDLE_ENFORCE_EQ (
37-         ctx->HasOutput (" Out" 
38-         true ,
39-         platform::errors::NotFound (" Output(Out) of BmmOp should not be null." 
40- 
41-     auto  x_dims = ctx->GetInputDim (" X" 
42-     auto  y_dims = ctx->GetInputDim (" Y" 
43- 
44-     PADDLE_ENFORCE_EQ (x_dims.size (),
45-                       3 ,
46-                       platform::errors::InvalidArgument (
47-                           " Input(X) of BmmOp must be 3-dimensional in BmmOp, " 
48-                           " but received X's shape: [%s]." 
49-                           x_dims));
50-     PADDLE_ENFORCE_EQ (y_dims.size (),
51-                       3 ,
52-                       platform::errors::InvalidArgument (
53-                           " Input(Y) of BmmOp must be 3-dimensional in BmmOp, " 
54-                           " but received Y's shape: [%s]." 
55-                           y_dims));
56-     PADDLE_ENFORCE_EQ (
57-         x_dims[0 ],
58-         y_dims[0 ],
59-         platform::errors::InvalidArgument (
60-             " Input(X) and Input(Y) must have the same batch size in BmmOp, " 
61-             " but received X's batch size: [%s]," 
62-             " Y's batch size [%s]" 
63-             x_dims[0 ],
64-             y_dims[0 ]));
65-     PADDLE_ENFORCE_EQ (
66-         x_dims[2 ],
67-         y_dims[1 ],
68-         platform::errors::InvalidArgument (
69-             " Input(X)'s width must be equal with Input(Y)'s height in BmmOp," 
70-             " but receive X's width: [%s]," 
71-             " Y's height: [%s]." 
72-             x_dims[2 ],
73-             y_dims[1 ]));
74- 
75-     std::vector<int64_t > dim_out;
76-     dim_out.push_back (x_dims[0 ]);
77-     dim_out.push_back (x_dims[1 ]);
78-     dim_out.push_back (y_dims[2 ]);
79-     ctx->SetOutputDim (" Out" phi::make_ddim (dim_out));
80-     ctx->ShareLoD (" X" /* ->*/ " Out" 
81-   }
82- 
8332  framework::OpKernelType GetExpectedKernelType (
8433      const  framework::ExecutionContext& ctx) const  override  {
8534    auto  data_type = OperatorWithKernel::IndicateVarDataType (ctx, " X" 
@@ -110,33 +59,6 @@ class BmmOpGrad : public framework::OperatorWithKernel {
11059  using  framework::OperatorWithKernel::OperatorWithKernel;
11160
11261 protected: 
113-   void  InferShape (framework::InferShapeContext* ctx) const  override  {
114-     PADDLE_ENFORCE_EQ (
115-         ctx->HasInput (" X" 
116-         true ,
117-         platform::errors::NotFound (" Input(X) of BmmOp should not be null" 
118-     PADDLE_ENFORCE_EQ (
119-         ctx->HasInput (" Y" 
120-         true ,
121-         platform::errors::NotFound (" Input(Y) of BmmOp should not be null" 
122-     PADDLE_ENFORCE_EQ (ctx->HasInput (framework::GradVarName (" Out" 
123-                       true ,
124-                       platform::errors::NotFound (
125-                           " Output(Out@GRAD) of BmmOp should not be null." 
126- 
127-     auto  x_dims = ctx->GetInputDim (" X" 
128-     auto  y_dims = ctx->GetInputDim (" Y" 
129- 
130-     auto  x_grad_name = framework::GradVarName (" X" 
131-     auto  y_grad_name = framework::GradVarName (" Y" 
132- 
133-     if  (ctx->HasOutput (x_grad_name)) {
134-       ctx->SetOutputDim (x_grad_name, x_dims);
135-     }
136-     if  (ctx->HasOutput (y_grad_name)) {
137-       ctx->SetOutputDim (y_grad_name, y_dims);
138-     }
139-   }
14062  framework::OpKernelType GetExpectedKernelType (
14163      const  framework::ExecutionContext& ctx) const  override  {
14264    return  framework::OpKernelType (OperatorWithKernel::IndicateVarDataType (
@@ -166,15 +88,16 @@ class BmmOpGradMaker : public framework::SingleGradOpMaker<T> {
16688
16789namespace  ops  =  paddle::operators;
16890
91+ DECLARE_INFER_SHAPE_FUNCTOR (bmm,
92+                             BmmInferShapeFunctor,
93+                             PD_INFER_META (phi::BmmInferMeta));
94+ DECLARE_INFER_SHAPE_FUNCTOR (bmm_grad,
95+                             BmmGradInferShapeFunctor,
96+                             PD_INFER_META (phi::BmmGradInferMeta));
16997REGISTER_OPERATOR (bmm,
17098                  ops::BmmOp,
17199                  ops::BmmOpMaker,
172100                  ops::BmmOpGradMaker<paddle::framework::OpDesc>,
173-                   ops::BmmOpGradMaker<paddle::imperative::OpBase>);
174- REGISTER_OPERATOR (bmm_grad, ops::BmmOpGrad);
175- REGISTER_OP_CPU_KERNEL (bmm,
176-                        ops::BmmKernel<phi::CPUContext, float >,
177-                        ops::BmmKernel<phi::CPUContext, double >);
178- REGISTER_OP_CPU_KERNEL (bmm_grad,
179-                        ops::BmmGradKernel<phi::CPUContext, float >,
180-                        ops::BmmGradKernel<phi::CPUContext, double >);
101+                   ops::BmmOpGradMaker<paddle::imperative::OpBase>,
102+                   BmmInferShapeFunctor);
103+ REGISTER_OPERATOR (bmm_grad, ops::BmmOpGrad, BmmGradInferShapeFunctor);
0 commit comments