
Commit 9c0eaad

[Phi] trans logsumexp op (#40790)

* [Phi] trans logsumexp op
* fix bugs
* fix bugs
* fix bugs
* fix bugs
* fix bugs
* fix bugs
* add sig
* fix sig bugs
* fix sig bugs
* fix xpu bugs
* fix review bugs
* test=develop

1 parent b532315 commit 9c0eaad

14 files changed: +440, -270 lines
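For reference, the operator being migrated reduces its input over the axes in `axis`; each output element is, for the corresponding reduced slice x, the value below. Kernels typically evaluate the max-shifted form on the right for numerical stability.

```latex
% logsumexp over a reduced slice x; m is the usual max-shift.
\operatorname{logsumexp}(x)
  = \log \sum_{i} e^{x_i}
  = m + \log \sum_{i} e^{x_i - m},
\qquad m = \max_{i} x_i
```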

paddle/fluid/operators/reduce_ops/logsumexp_op.cc

Lines changed: 8 additions & 85 deletions
```diff
@@ -12,91 +12,20 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/reduce_ops/logsumexp_op.h"
 #include <algorithm>
 #include <string>
 #include <vector>
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op_function.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/unary.h"
 
 namespace paddle {
 namespace operators {
 
 class LogsumexpOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "logsumexp");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "logsumexp");
-    auto x_dims = ctx->GetInputDim("X");
-    auto x_rank = x_dims.size();
-    PADDLE_ENFORCE_LE(x_rank, 4,
-                      platform::errors::InvalidArgument(
-                          "The input tensor X's dimensions of logsumexp "
-                          "should be less or equal than 4. But received X's "
-                          "dimensions = %d, X's shape = [%s].",
-                          x_rank, x_dims));
-    auto axis = ctx->Attrs().Get<std::vector<int>>("axis");
-    PADDLE_ENFORCE_GT(
-        axis.size(), 0,
-        platform::errors::InvalidArgument(
-            "The size of axis of logsumexp "
-            "should be greater than 0. But received the size of axis "
-            "of logsumexp is %d.",
-            axis.size()));
-
-    for (size_t i = 0; i < axis.size(); i++) {
-      PADDLE_ENFORCE_LT(axis[i], x_rank,
-                        platform::errors::InvalidArgument(
-                            "axis[%d] should be in the "
-                            "range [-D, D), where D is the dimensions of X and "
-                            "D is %d. But received axis[%d] = %d.",
-                            i, x_rank, i, axis[i]));
-      PADDLE_ENFORCE_GE(axis[i], -x_rank,
-                        platform::errors::InvalidArgument(
-                            "axis[%d] should be in the "
-                            "range [-D, D), where D is the dimensions of X and "
-                            "D is %d. But received axis[%d] = %d.",
-                            i, x_rank, i, axis[i]));
-      if (axis[i] < 0) {
-        axis[i] += x_rank;
-      }
-    }
-
-    bool keepdim = ctx->Attrs().Get<bool>("keepdim");
-    bool reduce_all = ctx->Attrs().Get<bool>("reduce_all");
-    auto dims_vector = vectorize(x_dims);
-    if (reduce_all) {
-      if (keepdim)
-        ctx->SetOutputDim("Out",
-                          phi::make_ddim(std::vector<int64_t>(x_rank, 1)));
-      else
-        ctx->SetOutputDim("Out", {1});
-    } else {
-      auto dims_vector = vectorize(x_dims);
-      if (keepdim) {
-        for (size_t i = 0; i < axis.size(); ++i) {
-          dims_vector[axis[i]] = 1;
-        }
-      } else {
-        const int kDelFlag = -1;
-        for (size_t i = 0; i < axis.size(); ++i) {
-          dims_vector[axis[i]] = kDelFlag;
-        }
-        dims_vector.erase(
-            std::remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
-            dims_vector.end());
-      }
-      if (!keepdim && dims_vector.size() == 0) {
-        dims_vector.push_back(1);
-      }
-      auto out_dims = phi::make_ddim(dims_vector);
-      ctx->SetOutputDim("Out", out_dims);
-      if (axis.size() > 0 && axis[0] != 0) {
-        // Only pass LoD when not reducing on the first dim.
-        ctx->ShareLoD("X", /*->*/ "Out");
-      }
-    }
-  }
 };
 
 class LogsumexpOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -164,16 +93,10 @@ class LogsumexpGradOpMaker : public framework::SingleGradOpMaker<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-
+DECLARE_INFER_SHAPE_FUNCTOR(logsumexp, LogsumexpInferShapeFunctor,
+                            PD_INFER_META(phi::LogsumexpInferMeta));
 REGISTER_OPERATOR(logsumexp, ops::LogsumexpOp, ops::LogsumexpOpMaker,
                   ops::LogsumexpGradOpMaker<paddle::framework::OpDesc>,
-                  ops::LogsumexpGradOpMaker<paddle::imperative::OpBase>);
+                  ops::LogsumexpGradOpMaker<paddle::imperative::OpBase>,
+                  LogsumexpInferShapeFunctor);
 REGISTER_OPERATOR(logsumexp_grad, ops::LogsumexpGrapOp);
-
-REGISTER_OP_CPU_KERNEL(
-    logsumexp, ops::LogsumexpKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::LogsumexpKernel<paddle::platform::CPUDeviceContext, double>);
-REGISTER_OP_CPU_KERNEL(
-    logsumexp_grad,
-    ops::LogsumexpGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::LogsumexpGradKernel<paddle::platform::CPUDeviceContext, double>);
```
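The `REGISTER_OP_CPU_KERNEL` registrations deleted above move into the Phi kernel library, in files that are part of this commit but not shown in this excerpt. A minimal sketch of the Phi-style replacement, assuming the usual `PD_REGISTER_KERNEL` pattern; the kernel name, header path, and signature are assumptions, not confirmed by this excerpt:

```cpp
// Sketch only: Phi-side CPU registration presumably replacing the
// REGISTER_OP_CPU_KERNEL lines deleted above.
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/logsumexp_kernel.h"  // assumed header

// Assumed kernel signature, matching LogsumexpInferMeta's attributes:
// template <typename T, typename Context>
// void LogsumexpKernel(const Context& dev_ctx, const DenseTensor& x,
//                      const std::vector<int64_t>& axis, bool keepdim,
//                      bool reduce_all, DenseTensor* out);

PD_REGISTER_KERNEL(
    logsumexp, CPU, ALL_LAYOUT, phi::LogsumexpKernel, float, double) {}
```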

paddle/fluid/operators/reduce_ops/logsumexp_op.h

Lines changed: 0 additions & 170 deletions
This file was deleted.
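The "add sig" step in the commit message refers to a kernel-signature mapping file, also among the changed files not shown here. A sketch of what such a paddle/phi/ops/compat/logsumexp_sig.cc typically contains; the exact contents are assumed:

```cpp
// Sketch of the assumed op-signature mapping ("add sig" in the commit
// message): routes the fluid "logsumexp" op's inputs, attributes, and
// outputs to the Phi kernel. Actual file contents are not shown here.
#include "paddle/phi/core/compat/op_utils.h"

namespace phi {

KernelSignature LogsumexpOpArgumentMapping(const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "logsumexp", {"X"}, {"axis", "keepdim", "reduce_all"}, {"Out"});
}

}  // namespace phi

PD_REGISTER_ARG_MAPPING_FN(logsumexp, phi::LogsumexpOpArgumentMapping);
```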

paddle/fluid/operators/reduce_ops/logsumexp_op_xpu.cc

Lines changed: 1 addition & 1 deletion
```diff
@@ -14,7 +14,7 @@
 
 #ifdef PADDLE_WITH_XPU
 
-#include "paddle/fluid/operators/reduce_ops/logsumexp_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op_function.h"
 #include "paddle/fluid/platform/device/xpu/xpu_header.h"
 #include "paddle/fluid/platform/device_context.h"
```
2020

paddle/phi/infermeta/unary.cc

Lines changed: 85 additions & 0 deletions
```diff
@@ -804,6 +804,91 @@ void KthvalueInferMeta(const MetaTensor& x,
   indices->set_dtype(x.dtype());
 }
 
+void LogsumexpInferMeta(const MetaTensor& input,
+                        const std::vector<int64_t>& axis,
+                        bool keepdim,
+                        bool reduce_all,
+                        MetaTensor* out) {
+  auto x_dims = input.dims();
+  auto x_rank = x_dims.size();
+  std::vector<int64_t> formated_axis = axis;
+  PADDLE_ENFORCE_LE(x_rank,
+                    4,
+                    errors::InvalidArgument(
+                        "The input tensor X's dimensions of logsumexp "
+                        "should be less or equal than 4. But received X's "
+                        "dimensions = %d, X's shape = [%s].",
+                        x_rank,
+                        x_dims));
+  PADDLE_ENFORCE_GT(
+      axis.size(),
+      0,
+      errors::InvalidArgument(
+          "The size of axis of logsumexp "
+          "should be greater than 0. But received the size of axis "
+          "of logsumexp is %d.",
+          axis.size()));
+
+  for (size_t i = 0; i < axis.size(); i++) {
+    PADDLE_ENFORCE_LT(axis[i],
+                      x_rank,
+                      errors::InvalidArgument(
+                          "axis[%d] should be in the "
+                          "range [-D, D), where D is the dimensions of X and "
+                          "D is %d. But received axis[%d] = %d.",
+                          i,
+                          x_rank,
+                          i,
+                          axis[i]));
+    PADDLE_ENFORCE_GE(axis[i],
+                      -x_rank,
+                      errors::InvalidArgument(
+                          "axis[%d] should be in the "
+                          "range [-D, D), where D is the dimensions of X and "
+                          "D is %d. But received axis[%d] = %d.",
+                          i,
+                          x_rank,
+                          i,
+                          axis[i]));
+    if (axis[i] < 0) {
+      formated_axis[i] += x_rank;
+    }
+  }
+
+  auto dims_vector = vectorize(x_dims);
+  if (reduce_all) {
+    if (keepdim)
+      out->set_dims(phi::make_ddim(std::vector<int64_t>(x_rank, 1)));
+    else
+      out->set_dims({1});
+  } else {
+    auto dims_vector = vectorize(x_dims);
+    if (keepdim) {
+      for (size_t i = 0; i < formated_axis.size(); ++i) {
+        dims_vector[formated_axis[i]] = 1;
+      }
+    } else {
+      const int kDelFlag = -1;
+      for (size_t i = 0; i < formated_axis.size(); ++i) {
+        dims_vector[formated_axis[i]] = kDelFlag;
+      }
+      dims_vector.erase(
+          std::remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
+          dims_vector.end());
+    }
+    if (!keepdim && dims_vector.size() == 0) {
+      dims_vector.push_back(1);
+    }
+    auto out_dims = phi::make_ddim(dims_vector);
+    out->set_dims(out_dims);
+    if (formated_axis.size() > 0 && formated_axis[0] != 0) {
+      // Only pass LoD when not reducing on the first dim.
+      out->share_lod(input);
+    }
+  }
+  out->set_dtype(input.dtype());
+}
+
 void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out) {
   auto dims = x.dims();
   auto n_dim = dims.size();
```
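To make the shape rule in `LogsumexpInferMeta` concrete: for x of shape [2, 3, 4] and axis = {1}, the output is [2, 1, 4] with keepdim=true and [2, 4] with keepdim=false; reduce_all yields [1, 1, 1] or [1]. A standalone sketch of the non-reduce_all branch follows; `ReducedDims` is a hypothetical helper, not Paddle code, and assumes axis entries are already shifted into [0, rank) as `formated_axis` is:

```cpp
// Hypothetical illustration of the output-shape rule implemented by the
// non-reduce_all branch of LogsumexpInferMeta above.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

std::vector<std::int64_t> ReducedDims(std::vector<std::int64_t> dims,
                                      const std::vector<std::int64_t>& axis,
                                      bool keepdim) {
  const std::int64_t kDelFlag = -1;
  // Reduced axes collapse to 1 (keepdim) or are marked for removal.
  for (std::int64_t a : axis) dims[a] = keepdim ? 1 : kDelFlag;
  if (!keepdim) {
    dims.erase(std::remove(dims.begin(), dims.end(), kDelFlag), dims.end());
    if (dims.empty()) dims.push_back(1);  // fully reduced -> shape [1]
  }
  return dims;
}

int main() {
  // x: [2, 3, 4], axis = {1}, keepdim = false: prints "2 4".
  for (std::int64_t d : ReducedDims({2, 3, 4}, {1}, /*keepdim=*/false)) {
    std::printf("%lld ", static_cast<long long>(d));
  }
  std::printf("\n");
  return 0;
}
```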
