|
12 | 12 | // See the License for the specific language governing permissions and
|
13 | 13 | // limitations under the License.
|
14 | 14 |
|
15 |
| -#include "paddle/fluid/operators/reduce_ops/logsumexp_op.h" |
16 | 15 | #include <algorithm>
|
17 | 16 | #include <string>
|
18 | 17 | #include <vector>
|
| 18 | +#include "paddle/fluid/framework/infershape_utils.h" |
| 19 | +#include "paddle/fluid/operators/reduce_ops/reduce_op_function.h" |
| 20 | +#include "paddle/phi/core/infermeta_utils.h" |
| 21 | +#include "paddle/phi/infermeta/unary.h" |
19 | 22 |
|
20 | 23 | namespace paddle {
|
21 | 24 | namespace operators {
|
22 | 25 |
|
23 | 26 | class LogsumexpOp : public framework::OperatorWithKernel {
|
24 | 27 | public:
|
25 | 28 | using framework::OperatorWithKernel::OperatorWithKernel;
|
26 |
| - |
27 |
| - void InferShape(framework::InferShapeContext* ctx) const override { |
28 |
| - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "logsumexp"); |
29 |
| - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "logsumexp"); |
30 |
| - auto x_dims = ctx->GetInputDim("X"); |
31 |
| - auto x_rank = x_dims.size(); |
32 |
| - PADDLE_ENFORCE_LE(x_rank, 4, |
33 |
| - platform::errors::InvalidArgument( |
34 |
| - "The input tensor X's dimensions of logsumexp " |
35 |
| - "should be less or equal than 4. But received X's " |
36 |
| - "dimensions = %d, X's shape = [%s].", |
37 |
| - x_rank, x_dims)); |
38 |
| - auto axis = ctx->Attrs().Get<std::vector<int>>("axis"); |
39 |
| - PADDLE_ENFORCE_GT( |
40 |
| - axis.size(), 0, |
41 |
| - platform::errors::InvalidArgument( |
42 |
| - "The size of axis of logsumexp " |
43 |
| - "should be greater than 0. But received the size of axis " |
44 |
| - "of logsumexp is %d.", |
45 |
| - axis.size())); |
46 |
| - |
47 |
| - for (size_t i = 0; i < axis.size(); i++) { |
48 |
| - PADDLE_ENFORCE_LT(axis[i], x_rank, |
49 |
| - platform::errors::InvalidArgument( |
50 |
| - "axis[%d] should be in the " |
51 |
| - "range [-D, D), where D is the dimensions of X and " |
52 |
| - "D is %d. But received axis[%d] = %d.", |
53 |
| - i, x_rank, i, axis[i])); |
54 |
| - PADDLE_ENFORCE_GE(axis[i], -x_rank, |
55 |
| - platform::errors::InvalidArgument( |
56 |
| - "axis[%d] should be in the " |
57 |
| - "range [-D, D), where D is the dimensions of X and " |
58 |
| - "D is %d. But received axis[%d] = %d.", |
59 |
| - i, x_rank, i, axis[i])); |
60 |
| - if (axis[i] < 0) { |
61 |
| - axis[i] += x_rank; |
62 |
| - } |
63 |
| - } |
64 |
| - |
65 |
| - bool keepdim = ctx->Attrs().Get<bool>("keepdim"); |
66 |
| - bool reduce_all = ctx->Attrs().Get<bool>("reduce_all"); |
67 |
| - auto dims_vector = vectorize(x_dims); |
68 |
| - if (reduce_all) { |
69 |
| - if (keepdim) |
70 |
| - ctx->SetOutputDim("Out", |
71 |
| - phi::make_ddim(std::vector<int64_t>(x_rank, 1))); |
72 |
| - else |
73 |
| - ctx->SetOutputDim("Out", {1}); |
74 |
| - } else { |
75 |
| - auto dims_vector = vectorize(x_dims); |
76 |
| - if (keepdim) { |
77 |
| - for (size_t i = 0; i < axis.size(); ++i) { |
78 |
| - dims_vector[axis[i]] = 1; |
79 |
| - } |
80 |
| - } else { |
81 |
| - const int kDelFlag = -1; |
82 |
| - for (size_t i = 0; i < axis.size(); ++i) { |
83 |
| - dims_vector[axis[i]] = kDelFlag; |
84 |
| - } |
85 |
| - dims_vector.erase( |
86 |
| - std::remove(dims_vector.begin(), dims_vector.end(), kDelFlag), |
87 |
| - dims_vector.end()); |
88 |
| - } |
89 |
| - if (!keepdim && dims_vector.size() == 0) { |
90 |
| - dims_vector.push_back(1); |
91 |
| - } |
92 |
| - auto out_dims = phi::make_ddim(dims_vector); |
93 |
| - ctx->SetOutputDim("Out", out_dims); |
94 |
| - if (axis.size() > 0 && axis[0] != 0) { |
95 |
| - // Only pass LoD when not reducing on the first dim. |
96 |
| - ctx->ShareLoD("X", /*->*/ "Out"); |
97 |
| - } |
98 |
| - } |
99 |
| - } |
100 | 29 | };
|
101 | 30 |
|
102 | 31 | class LogsumexpOpMaker : public framework::OpProtoAndCheckerMaker {
|
@@ -164,16 +93,10 @@ class LogsumexpGradOpMaker : public framework::SingleGradOpMaker<T> {
|
164 | 93 | } // namespace paddle
|
165 | 94 |
|
166 | 95 | namespace ops = paddle::operators;
|
167 |
| - |
| 96 | +DECLARE_INFER_SHAPE_FUNCTOR(logsumexp, LogsumexpInferShapeFunctor, |
| 97 | + PD_INFER_META(phi::LogsumexpInferMeta)); |
168 | 98 | REGISTER_OPERATOR(logsumexp, ops::LogsumexpOp, ops::LogsumexpOpMaker,
|
169 | 99 | ops::LogsumexpGradOpMaker<paddle::framework::OpDesc>,
|
170 |
| - ops::LogsumexpGradOpMaker<paddle::imperative::OpBase>); |
| 100 | + ops::LogsumexpGradOpMaker<paddle::imperative::OpBase>, |
| 101 | + LogsumexpInferShapeFunctor); |
171 | 102 | REGISTER_OPERATOR(logsumexp_grad, ops::LogsumexpGrapOp);
|
172 |
| - |
173 |
| -REGISTER_OP_CPU_KERNEL( |
174 |
| - logsumexp, ops::LogsumexpKernel<paddle::platform::CPUDeviceContext, float>, |
175 |
| - ops::LogsumexpKernel<paddle::platform::CPUDeviceContext, double>); |
176 |
| -REGISTER_OP_CPU_KERNEL( |
177 |
| - logsumexp_grad, |
178 |
| - ops::LogsumexpGradKernel<paddle::platform::CPUDeviceContext, float>, |
179 |
| - ops::LogsumexpGradKernel<paddle::platform::CPUDeviceContext, double>); |
0 commit comments