21
21
* \file Sorting routines implemented with standard C library calls.
22
22
*/
23
23
24
+ #include < builtin_fp16.h>
24
25
#include < dlpack/dlpack.h>
25
26
#include < tvm/runtime/registry.h>
26
27
@@ -42,6 +43,24 @@ bool CompareDescend(const std::pair<int64_t, DType>& lhs, const std::pair<int64_
42
43
return lhs.second > rhs.second ;
43
44
}
44
45
46
+ struct float16 {
47
+ uint16_t bits;
48
+ float to_float () const {
49
+ return __extendXfYf2__<uint16_t , uint16_t , 10 , float , uint32_t , 23 >(bits);
50
+ }
51
+ };
52
+
53
+ template <>
54
+ bool CompareAscend (const std::pair<int64_t , float16>& lhs, const std::pair<int64_t , float16>& rhs) {
55
+ return lhs.second .to_float () < rhs.second .to_float ();
56
+ }
57
+
58
+ template <>
59
+ bool CompareDescend (const std::pair<int64_t , float16>& lhs,
60
+ const std::pair<int64_t , float16>& rhs) {
61
+ return lhs.second .to_float () > rhs.second .to_float ();
62
+ }
63
+
45
64
// Argsort implemented C library sort for nms.
46
65
// Return indices of sorted tensor.
47
66
// By default, the last axis will be used to sort.
@@ -125,7 +144,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms").set_body([](TVMArgs args, TV
125
144
});
126
145
127
146
template <typename DataType, typename OutType>
128
- void sort_impl (DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend, bool is_argsort) {
147
+ void sort_impl (
148
+ DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend,
149
+ std::function<void (OutType*, size_t , const std::pair<int64_t , DataType>&)> epilogue) {
129
150
auto data_ptr = static_cast <DataType*>(input->data );
130
151
auto out_ptr = static_cast <OutType*>(output->data );
131
152
std::vector<std::pair<int64_t , DataType>> sorter;
@@ -153,27 +174,29 @@ void sort_impl(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend,
153
174
} else {
154
175
std::stable_sort (sorter.begin (), sorter.end (), CompareDescend<DataType>);
155
176
}
156
- if (is_argsort) {
157
- for (int64_t k = 0 ; k < input->shape [axis]; ++k) {
158
- out_ptr[base_idx + k * axis_mul_after] = static_cast <OutType>(sorter[k].first );
159
- }
160
- } else {
161
- for (int64_t k = 0 ; k < input->shape [axis]; ++k) {
162
- out_ptr[base_idx + k * axis_mul_after] = static_cast <OutType>(sorter[k].second );
163
- }
177
+ for (int64_t k = 0 ; k < input->shape [axis]; ++k) {
178
+ epilogue (out_ptr, base_idx + k * axis_mul_after, sorter[k]);
164
179
}
165
180
}
166
181
}
167
182
}
168
183
169
184
template <typename DataType, typename OutType>
170
185
void argsort (DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) {
171
- return sort_impl<DataType, OutType>(input, output, axis, is_ascend, true );
186
+ return sort_impl<DataType, OutType>(
187
+ input, output, axis, is_ascend,
188
+ [](OutType* out_ptr, size_t index, const std::pair<int64_t , DataType>& sort_pair) {
189
+ out_ptr[index] = static_cast <OutType>(sort_pair.first );
190
+ });
172
191
}
173
192
174
193
template <typename DataType>
175
194
void sort (DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) {
176
- return sort_impl<DataType, DataType>(input, output, axis, is_ascend, false );
195
+ return sort_impl<DataType, DataType>(
196
+ input, output, axis, is_ascend,
197
+ [](DataType* out_ptr, size_t index, const std::pair<int64_t , DataType>& sort_pair) {
198
+ out_ptr[index] = sort_pair.second ;
199
+ });
177
200
}
178
201
179
202
// Argsort implemented C library sort.
@@ -254,6 +277,18 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort").set_body([](TVMArgs args, TVMRet
254
277
} else {
255
278
LOG (FATAL) << " Unsupported output dtype: " << out_dtype;
256
279
}
280
+ } else if (data_dtype == " float16" ) {
281
+ if (out_dtype == " int32" ) {
282
+ argsort<float16, int32_t >(input, output, axis, is_ascend);
283
+ } else if (out_dtype == " int64" ) {
284
+ argsort<float16, int64_t >(input, output, axis, is_ascend);
285
+ } else if (out_dtype == " float32" ) {
286
+ argsort<float16, float >(input, output, axis, is_ascend);
287
+ } else if (out_dtype == " float64" ) {
288
+ argsort<float16, double >(input, output, axis, is_ascend);
289
+ } else {
290
+ LOG (FATAL) << " Unsupported output dtype: " << out_dtype;
291
+ }
257
292
} else {
258
293
LOG (FATAL) << " Unsupported input dtype: " << data_dtype;
259
294
}
@@ -295,6 +330,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.sort").set_body([](TVMArgs args, TVMRetVal
295
330
sort<int32_t >(input, output, axis, is_ascend);
296
331
} else if (data_dtype == " int64" ) {
297
332
sort<int64_t >(input, output, axis, is_ascend);
333
+ } else if (data_dtype == " float16" ) {
334
+ sort<float16>(input, output, axis, is_ascend);
298
335
} else {
299
336
LOG (FATAL) << " Unsupported input dtype: " << data_dtype;
300
337
}
@@ -432,6 +469,18 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk").set_body([](TVMArgs args, TVMRetVal
432
469
} else {
433
470
LOG (FATAL) << " Unsupported output dtype: " << out_dtype;
434
471
}
472
+ } else if (data_dtype == " float16" ) {
473
+ if (out_dtype == " int32" ) {
474
+ topk<float16, int32_t >(input, values_out, indices_out, k, axis, is_ascend);
475
+ } else if (out_dtype == " int64" ) {
476
+ topk<float16, int64_t >(input, values_out, indices_out, k, axis, is_ascend);
477
+ } else if (out_dtype == " float32" ) {
478
+ topk<float16, float >(input, values_out, indices_out, k, axis, is_ascend);
479
+ } else if (out_dtype == " float64" ) {
480
+ topk<float16, double >(input, values_out, indices_out, k, axis, is_ascend);
481
+ } else {
482
+ LOG (FATAL) << " Unsupported output dtype: " << out_dtype;
483
+ }
435
484
} else {
436
485
LOG (FATAL) << " Unsupported input dtype: " << data_dtype;
437
486
}
0 commit comments