Skip to content

Commit 847c969

Browse files
committed
add unique_consecutive op
1 parent 5f0fd5f commit 847c969

File tree

1 file changed

+129
-1
lines changed

1 file changed

+129
-1
lines changed

paddle/fluid/operators/unique_op.cu

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,141 @@ limitations under the License. */
2222
#include <vector>
2323
#include "paddle/fluid/framework/tensor_util.h" // TensorToVector()
2424
#include "paddle/fluid/operators/unique_op.h" // TransComute()
25-
#include "paddle/fluid/operators/unique_utils.h"
2625

2726
namespace paddle {
2827
namespace operators {
2928

3029
using Tensor = framework::Tensor;
3130

31+
// Binary function 'less than'
32+
template <typename InT>
33+
struct LessThan {
34+
int col;
35+
const InT* in_trans_data;
36+
37+
LessThan(int64_t _col, const InT* _in_trans_data)
38+
: col(_col), in_trans_data(_in_trans_data) {}
39+
40+
__device__ bool operator()(int64_t a, int64_t b) const {
41+
for (int i = 0; i < col; ++i) {
42+
InT lhs = in_trans_data[i + a * col];
43+
InT rhs = in_trans_data[i + b * col];
44+
if (lhs < rhs) {
45+
return true;
46+
} else if (lhs > rhs) {
47+
return false;
48+
}
49+
}
50+
return false;
51+
}
52+
};
53+
54+
// Binary function 'equal_to'
55+
template <typename InT>
56+
struct BinaryEqual {
57+
int64_t col;
58+
const InT* in_trans_data;
59+
60+
BinaryEqual(int64_t _col, const InT* _in_trans_data)
61+
: col(_col), in_trans_data(_in_trans_data) {}
62+
63+
__device__ bool operator()(int64_t a, int64_t b) const {
64+
for (int64_t i = 0; i < col; ++i) {
65+
InT lhs = in_trans_data[i + a * col];
66+
InT rhs = in_trans_data[i + b * col];
67+
if (lhs != rhs) {
68+
return false;
69+
}
70+
}
71+
return true;
72+
}
73+
};
74+
75+
// Binary function 'not_equal_to'
76+
template <typename InT>
77+
struct BinaryNotEqual {
78+
int64_t col;
79+
const InT* in_trans_data;
80+
81+
BinaryNotEqual(int64_t _col, const InT* _in_trans_data)
82+
: col(_col), in_trans_data(_in_trans_data) {}
83+
84+
__device__ bool operator()(int64_t a, int64_t b) const {
85+
for (int64_t i = 0; i < col; ++i) {
86+
InT lhs = in_trans_data[i + a * col];
87+
InT rhs = in_trans_data[i + b * col];
88+
if (lhs != rhs) {
89+
return true;
90+
}
91+
}
92+
return false;
93+
}
94+
};
95+
96+
// index_select() function for Tensor
97+
template <typename InT, typename IndexT>
98+
void IndexSelect(const framework::ExecutionContext& context,
99+
const Tensor& input, const Tensor& index, Tensor* output,
100+
int dim) {
101+
auto input_dim = input.dims();
102+
auto input_dim_size = input_dim.size();
103+
auto output_dim = output->dims();
104+
105+
auto slice_size = 1;
106+
for (auto i = dim + 1; i < input_dim_size; i++) {
107+
slice_size *= input_dim[i];
108+
}
109+
110+
auto input_width = slice_size * input_dim[dim];
111+
auto output_width = slice_size * output_dim[dim];
112+
113+
auto outer_nums = 1;
114+
for (auto i = 0; i < dim; i++) {
115+
outer_nums *= input_dim[i];
116+
}
117+
118+
auto index_size = index.dims()[0];
119+
120+
std::vector<InT> input_vec;
121+
std::vector<IndexT> index_vec;
122+
TensorToVector(input, context.device_context(), &input_vec);
123+
TensorToVector(index, context.device_context(), &index_vec);
124+
std::vector<InT> out_vec(output->numel());
125+
126+
for (int i = 0; i < index_size; i++) {
127+
PADDLE_ENFORCE_GE(
128+
index_vec[i], 0,
129+
platform::errors::InvalidArgument(
130+
"Variable value (index) of OP(index_select) "
131+
"expected >= 0 and < %ld, but got %ld. Please check input "
132+
"value.",
133+
input_dim[dim], index_vec[i]));
134+
PADDLE_ENFORCE_LT(
135+
index_vec[i], input_dim[dim],
136+
platform::errors::InvalidArgument(
137+
"Variable value (index) of OP(index_select) "
138+
"expected >= 0 and < %ld, but got %ld. Please check input "
139+
"value.",
140+
input_dim[dim], index_vec[i]));
141+
}
142+
143+
for (auto i = 0; i < outer_nums; i++) {
144+
auto input_start_offset = i * input_width;
145+
auto output_start_offset = i * output_width;
146+
147+
for (auto j = 0; j < index_size; j++) {
148+
IndexT index_value = index_vec[j];
149+
for (auto k = 0; k < slice_size; k++) {
150+
out_vec[output_start_offset + j * slice_size + k] =
151+
input_vec[input_start_offset + index_value * slice_size + k];
152+
}
153+
}
154+
}
155+
output->mutable_data<InT>(context.GetPlace());
156+
framework::TensorFromVector(out_vec, context.device_context(), output);
157+
output->Resize(output_dim);
158+
}
159+
32160
// The core logic of computing Unique for a flattend Tensor
33161
template <typename InT, typename IndexT, typename equal_T, typename not_equal_T>
34162
static void UniqueFlattendCUDATensor(const framework::ExecutionContext& context,

0 commit comments

Comments
 (0)