|
12 | 12 | // See the License for the specific language governing permissions and
|
13 | 13 | // limitations under the License.
|
14 | 14 |
|
| 15 | +#include "paddle/fluid/framework/infershape_utils.h" |
15 | 16 | #include "paddle/fluid/framework/op_registry.h"
|
| 17 | +#include "paddle/phi/core/infermeta_utils.h" |
| 18 | +#include "paddle/phi/infermeta/ternary.h" |
16 | 19 |
|
17 | 20 | namespace paddle {
|
18 | 21 | namespace operators {
|
19 | 22 |
|
20 | 23 | class LerpOp : public framework::OperatorWithKernel {
|
21 | 24 | public:
|
22 | 25 | using framework::OperatorWithKernel::OperatorWithKernel;
|
23 |
| - |
24 |
| - void InferShape(framework::InferShapeContext* ctx) const override { |
25 |
| - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "lerp"); |
26 |
| - OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "lerp"); |
27 |
| - OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "lerp"); |
28 |
| - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "lerp"); |
29 |
| - |
30 |
| - auto x_dims = ctx->GetInputDim("X"); |
31 |
| - auto y_dims = ctx->GetInputDim("Y"); |
32 |
| - auto w_dims = ctx->GetInputDim("Weight"); |
33 |
| - framework::DDim out_dims; |
34 |
| - out_dims = GetOutputDims(x_dims, y_dims); |
35 |
| - if (w_dims.size() > 1 || w_dims[0] != 1) { |
36 |
| - out_dims = GetOutputDims(out_dims, w_dims); |
37 |
| - } |
38 |
| - |
39 |
| - ctx->SetOutputDim("Out", out_dims); |
40 |
| - ctx->ShareLoD("X", /*->*/ "Out"); |
41 |
| - } |
42 |
| - |
43 |
| - private: |
44 |
| - framework::DDim GetOutputDims(const framework::DDim& s_dims, |
45 |
| - const framework::DDim& l_dims) const { |
46 |
| - if (s_dims.size() > l_dims.size()) { |
47 |
| - return GetOutputDims(l_dims, s_dims); |
48 |
| - } |
49 |
| - std::vector<int64_t> shapes = phi::vectorize<int64_t>(l_dims); |
50 |
| - for (int i = s_dims.size() - 1, j = l_dims.size() - 1; i >= 0; --i, --j) { |
51 |
| - int64_t s = s_dims[i]; |
52 |
| - int64_t l = l_dims[j]; |
53 |
| - if (s != l) { |
54 |
| - if (l == 1) { |
55 |
| - shapes[j] = s; |
56 |
| - } else if (s != 1) { |
57 |
| - PADDLE_THROW(platform::errors::InvalidArgument( |
58 |
| - "The shape of tensor a %s:%d must match shape of tensor b " |
59 |
| - "%s:%d.", |
60 |
| - s_dims.to_str(), i, l_dims.to_str(), j)); |
61 |
| - } |
62 |
| - } |
63 |
| - } |
64 |
| - return phi::make_ddim(shapes); |
65 |
| - } |
66 | 26 | };
|
67 | 27 |
|
68 | 28 | class LerpOpMaker : public framework::OpProtoAndCheckerMaker {
|
@@ -125,10 +85,12 @@ DECLARE_INPLACE_OP_INFERER(LerpInplaceInferer, {"X", "Out"});
|
125 | 85 | } // namespace operators
|
126 | 86 | } // namespace paddle
|
127 | 87 |
|
| 88 | +DELCARE_INFER_SHAPE_FUNCTOR(lerp, LerpInferShapeFunctor, |
| 89 | + PT_INFER_META(phi::LerpInferMeta)); |
128 | 90 | REGISTER_OPERATOR(
|
129 | 91 | lerp, paddle::operators::LerpOp, paddle::operators::LerpOpMaker,
|
130 | 92 | paddle::operators::LerpOpGradMaker<paddle::framework::OpDesc>,
|
131 | 93 | paddle::operators::LerpOpGradMaker<paddle::imperative::OpBase>,
|
132 |
| - paddle::operators::LerpInplaceInferer); |
| 94 | + paddle::operators::LerpInplaceInferer, LerpInferShapeFunctor); |
133 | 95 |
|
134 | 96 | REGISTER_OPERATOR(lerp_grad, paddle::operators::LerpGradOp);
|
0 commit comments