@@ -22,13 +22,141 @@ limitations under the License. */
2222#include < vector>
2323#include " paddle/fluid/framework/tensor_util.h" // TensorToVector()
2424#include " paddle/fluid/operators/unique_op.h" // TransComute()
25- #include " paddle/fluid/operators/unique_utils.h"
2625
2726namespace paddle {
2827namespace operators {
2928
3029using Tensor = framework::Tensor;
3130
31+ // Binary function 'less than'
32+ template <typename InT>
33+ struct LessThan {
34+ int col;
35+ const InT* in_trans_data;
36+
37+ LessThan (int64_t _col, const InT* _in_trans_data)
38+ : col(_col), in_trans_data(_in_trans_data) {}
39+
40+ __device__ bool operator ()(int64_t a, int64_t b) const {
41+ for (int i = 0 ; i < col; ++i) {
42+ InT lhs = in_trans_data[i + a * col];
43+ InT rhs = in_trans_data[i + b * col];
44+ if (lhs < rhs) {
45+ return true ;
46+ } else if (lhs > rhs) {
47+ return false ;
48+ }
49+ }
50+ return false ;
51+ }
52+ };
53+
54+ // Binary function 'equal_to'
55+ template <typename InT>
56+ struct BinaryEqual {
57+ int64_t col;
58+ const InT* in_trans_data;
59+
60+ BinaryEqual (int64_t _col, const InT* _in_trans_data)
61+ : col(_col), in_trans_data(_in_trans_data) {}
62+
63+ __device__ bool operator ()(int64_t a, int64_t b) const {
64+ for (int64_t i = 0 ; i < col; ++i) {
65+ InT lhs = in_trans_data[i + a * col];
66+ InT rhs = in_trans_data[i + b * col];
67+ if (lhs != rhs) {
68+ return false ;
69+ }
70+ }
71+ return true ;
72+ }
73+ };
74+
75+ // Binary function 'not_equal_to'
76+ template <typename InT>
77+ struct BinaryNotEqual {
78+ int64_t col;
79+ const InT* in_trans_data;
80+
81+ BinaryNotEqual (int64_t _col, const InT* _in_trans_data)
82+ : col(_col), in_trans_data(_in_trans_data) {}
83+
84+ __device__ bool operator ()(int64_t a, int64_t b) const {
85+ for (int64_t i = 0 ; i < col; ++i) {
86+ InT lhs = in_trans_data[i + a * col];
87+ InT rhs = in_trans_data[i + b * col];
88+ if (lhs != rhs) {
89+ return true ;
90+ }
91+ }
92+ return false ;
93+ }
94+ };
95+
96+ // index_select() function for Tensor
97+ template <typename InT, typename IndexT>
98+ void IndexSelect (const framework::ExecutionContext& context,
99+ const Tensor& input, const Tensor& index, Tensor* output,
100+ int dim) {
101+ auto input_dim = input.dims ();
102+ auto input_dim_size = input_dim.size ();
103+ auto output_dim = output->dims ();
104+
105+ auto slice_size = 1 ;
106+ for (auto i = dim + 1 ; i < input_dim_size; i++) {
107+ slice_size *= input_dim[i];
108+ }
109+
110+ auto input_width = slice_size * input_dim[dim];
111+ auto output_width = slice_size * output_dim[dim];
112+
113+ auto outer_nums = 1 ;
114+ for (auto i = 0 ; i < dim; i++) {
115+ outer_nums *= input_dim[i];
116+ }
117+
118+ auto index_size = index.dims ()[0 ];
119+
120+ std::vector<InT> input_vec;
121+ std::vector<IndexT> index_vec;
122+ TensorToVector (input, context.device_context (), &input_vec);
123+ TensorToVector (index, context.device_context (), &index_vec);
124+ std::vector<InT> out_vec (output->numel ());
125+
126+ for (int i = 0 ; i < index_size; i++) {
127+ PADDLE_ENFORCE_GE (
128+ index_vec[i], 0 ,
129+ platform::errors::InvalidArgument (
130+ " Variable value (index) of OP(index_select) "
131+ " expected >= 0 and < %ld, but got %ld. Please check input "
132+ " value." ,
133+ input_dim[dim], index_vec[i]));
134+ PADDLE_ENFORCE_LT (
135+ index_vec[i], input_dim[dim],
136+ platform::errors::InvalidArgument (
137+ " Variable value (index) of OP(index_select) "
138+ " expected >= 0 and < %ld, but got %ld. Please check input "
139+ " value." ,
140+ input_dim[dim], index_vec[i]));
141+ }
142+
143+ for (auto i = 0 ; i < outer_nums; i++) {
144+ auto input_start_offset = i * input_width;
145+ auto output_start_offset = i * output_width;
146+
147+ for (auto j = 0 ; j < index_size; j++) {
148+ IndexT index_value = index_vec[j];
149+ for (auto k = 0 ; k < slice_size; k++) {
150+ out_vec[output_start_offset + j * slice_size + k] =
151+ input_vec[input_start_offset + index_value * slice_size + k];
152+ }
153+ }
154+ }
155+ output->mutable_data <InT>(context.GetPlace ());
156+ framework::TensorFromVector (out_vec, context.device_context (), output);
157+ output->Resize (output_dim);
158+ }
159+
32160// The core logic of computing Unique for a flattend Tensor
33161template <typename InT, typename IndexT, typename equal_T, typename not_equal_T>
34162static void UniqueFlattendCUDATensor (const framework::ExecutionContext& context,
0 commit comments