Skip to content

Commit 875cfd5

Browse files
authored
add unique_consecutive_op (#34334)
* add unique_consecutive_op * add unique_consecutive_op * add unique_consecutive_op * add unique_consecutive_op * add unique_consecutive_op * add unique_consecutive_op * add unique_consecutive_op * add unique_consecutive_op * remove unity build * add unique_consecutive op * add unique_consecutive op * add enable static * add noqa * add space line * add default case. * add comma * add space line * modify unique_consecutive unittest * optimize ut coverage * rebase develop * improve coverage * update en docs * update en docs * update en docs * update en docs * update en docs * update en doc
1 parent e29c2d1 commit 875cfd5

File tree

8 files changed

+1183
-0
lines changed

8 files changed

+1183
-0
lines changed
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/unique_consecutive_op.h"
16+
#include "paddle/fluid/framework/op_version_registry.h"
17+
18+
namespace paddle {
19+
namespace operators {
20+
21+
class UniqueConsecutiveOp : public framework::OperatorWithKernel {
22+
public:
23+
using framework::OperatorWithKernel::OperatorWithKernel;
24+
25+
void InferShape(framework::InferShapeContext* ctx) const override {
26+
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "unique_consecutive");
27+
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
28+
"unique_consecutive");
29+
30+
auto in_dims = ctx->GetInputDim("X");
31+
bool return_inverse = ctx->Attrs().Get<bool>("return_inverse");
32+
bool return_counts = ctx->Attrs().Get<bool>("return_counts");
33+
auto axis_vec = ctx->Attrs().Get<std::vector<int>>("axis");
34+
if (return_inverse) {
35+
OP_INOUT_CHECK(ctx->HasOutput("Index"), "Output", "Index",
36+
"unique_consecutive");
37+
}
38+
if (return_counts) {
39+
OP_INOUT_CHECK(ctx->HasOutput("Counts"), "Output", "Counts",
40+
"unique_consecutive");
41+
}
42+
43+
if (axis_vec.empty()) {
44+
ctx->SetOutputDim("Out", {-1});
45+
if (return_inverse) {
46+
ctx->SetOutputDim("Index", {framework::product(in_dims)});
47+
}
48+
} else {
49+
int axis = axis_vec[0];
50+
if (axis < 0) {
51+
axis += in_dims.size();
52+
}
53+
PADDLE_ENFORCE_LT(
54+
axis, in_dims.size(),
55+
platform::errors::InvalidArgument("The axis(%d) should be less than "
56+
"the dimension size(%d) of x.",
57+
axis, in_dims.size()));
58+
auto out_dims = in_dims;
59+
out_dims[axis] = -1;
60+
ctx->SetOutputDim("Out", out_dims);
61+
if (return_inverse) {
62+
ctx->SetOutputDim("Index", {in_dims[axis]});
63+
}
64+
}
65+
if (return_counts) {
66+
ctx->SetOutputDim("Counts", {-1});
67+
}
68+
}
69+
70+
protected:
71+
framework::OpKernelType GetExpectedKernelType(
72+
const framework::ExecutionContext& ctx) const override {
73+
return framework::OpKernelType(
74+
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
75+
}
76+
};
77+
78+
/**
 * Declares the inputs, outputs and attributes of the `unique_consecutive`
 * operator. The declaration order matches the original definition.
 */
class UniqueConsecutiveOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // ---- Input -----------------------------------------------------------
    AddInput("X", "The input tensor of unique_consecutive op.");

    // NOTE(review): the description says this attribute selects the data
    // type of the output index, yet the default is FP32 -- confirm that a
    // floating-point default is intended.
    AddAttr<int>("dtype", "(int, default 5(FP32)) data type for output index")
        .SetDefault(framework::proto::VarType::FP32);

    // ---- Outputs ---------------------------------------------------------
    // "Index" and "Counts" are dispensable; "Out" is not.
    AddOutput("Out", "A unique consecutive subsequence for input tensor.");
    AddOutput("Index",
              "The indices for where elements in the original input "
              "ended up in the returned unique tensor.")
        .AsDispensable();
    AddOutput("Counts", "The counts for each unique element.").AsDispensable();

    // ---- Attributes ------------------------------------------------------
    AddAttr<bool>("return_inverse",
                  "If True, also return the indices for where elements in "
                  "the original input ended up in the returned unique "
                  "tensor.")
        .SetDefault(false);
    AddAttr<bool>("return_counts",
                  "If True, also return the counts for each unique element.")
        .SetDefault(false);
    // An empty vector is the default and stands for "no axis".
    AddAttr<std::vector<int>>(
        "axis",
        "The axis to apply unique. If None, the input will be flattened.")
        .SetDefault(std::vector<int>{});

    AddComment(R"DOC(
This function is different from paddle.unique() in the sense that this
function only eliminates consecutive duplicate values.
)DOC");
  }
};
111+
} // namespace operators
112+
} // namespace paddle
113+
114+
namespace ops = paddle::operators;
115+
REGISTER_OP_WITHOUT_GRADIENT(unique_consecutive, ops::UniqueConsecutiveOp,
116+
ops::UniqueConsecutiveOpMaker);
117+
REGISTER_OP_CPU_KERNEL(
118+
unique_consecutive,
119+
ops::UniqueConsecutiveKernel<paddle::platform::CPUDeviceContext, float>,
120+
ops::UniqueConsecutiveKernel<paddle::platform::CPUDeviceContext, double>,
121+
ops::UniqueConsecutiveKernel<paddle::platform::CPUDeviceContext, int32_t>,
122+
ops::UniqueConsecutiveKernel<paddle::platform::CPUDeviceContext, int64_t>);
123+
REGISTER_OP_VERSION(unique_consecutive)
124+
.AddCheckpoint(
125+
R"ROC(
126+
Upgrade unique_consecutive, add 2 outputs [Indices, Counts] and 3 attribute
127+
[return_inverse, return_counts, axis].
128+
)ROC",
129+
paddle::framework::compatible::OpVersionDesc()
130+
.NewOutput("Counts", "The counts for each unique element.")
131+
.NewAttr("return_inverse",
132+
"If True, also return the indices for where elements"
133+
" in the original input ended up in the returned unique "
134+
"tensor.",
135+
false)
136+
.NewAttr("return_counts",
137+
"If True, also return the counts for each unique element.",
138+
false)
139+
.NewAttr("axis",
140+
"The axis to apply unique. If None, the input will be "
141+
"flattened.",
142+
std::vector<int>{}));

0 commit comments

Comments
 (0)