Skip to content

Commit 124eaaf

Browse files
masahiylc
authored and committed
[Contrib] Support fp16 input in cpu sort (apache#8672)
1 parent f5dc9ed commit 124eaaf

File tree

3 files changed

+83
-32
lines changed

3 files changed

+83
-32
lines changed

src/runtime/contrib/sort/sort.cc

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
* \file Use standard C library call.
2222
*/
2323

24+
#include <builtin_fp16.h>
2425
#include <dlpack/dlpack.h>
2526
#include <tvm/runtime/registry.h>
2627

@@ -42,6 +43,24 @@ bool CompareDescend(const std::pair<int64_t, DType>& lhs, const std::pair<int64_
4243
return lhs.second > rhs.second;
4344
}
4445

46+
struct float16 {
47+
uint16_t bits;
48+
float to_float() const {
49+
return __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(bits);
50+
}
51+
};
52+
53+
template <>
54+
bool CompareAscend(const std::pair<int64_t, float16>& lhs, const std::pair<int64_t, float16>& rhs) {
55+
return lhs.second.to_float() < rhs.second.to_float();
56+
}
57+
58+
template <>
59+
bool CompareDescend(const std::pair<int64_t, float16>& lhs,
60+
const std::pair<int64_t, float16>& rhs) {
61+
return lhs.second.to_float() > rhs.second.to_float();
62+
}
63+
4564
// Argsort implemented C library sort for nms.
4665
// Return indices of sorted tensor.
4766
// By default, the last axis will be used to sort.
@@ -125,7 +144,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms").set_body([](TVMArgs args, TV
125144
});
126145

127146
template <typename DataType, typename OutType>
128-
void sort_impl(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend, bool is_argsort) {
147+
void sort_impl(
148+
DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend,
149+
std::function<void(OutType*, size_t, const std::pair<int64_t, DataType>&)> epilogue) {
129150
auto data_ptr = static_cast<DataType*>(input->data);
130151
auto out_ptr = static_cast<OutType*>(output->data);
131152
std::vector<std::pair<int64_t, DataType>> sorter;
@@ -153,27 +174,29 @@ void sort_impl(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend,
153174
} else {
154175
std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<DataType>);
155176
}
156-
if (is_argsort) {
157-
for (int64_t k = 0; k < input->shape[axis]; ++k) {
158-
out_ptr[base_idx + k * axis_mul_after] = static_cast<OutType>(sorter[k].first);
159-
}
160-
} else {
161-
for (int64_t k = 0; k < input->shape[axis]; ++k) {
162-
out_ptr[base_idx + k * axis_mul_after] = static_cast<OutType>(sorter[k].second);
163-
}
177+
for (int64_t k = 0; k < input->shape[axis]; ++k) {
178+
epilogue(out_ptr, base_idx + k * axis_mul_after, sorter[k]);
164179
}
165180
}
166181
}
167182
}
168183

169184
template <typename DataType, typename OutType>
170185
void argsort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) {
171-
return sort_impl<DataType, OutType>(input, output, axis, is_ascend, true);
186+
return sort_impl<DataType, OutType>(
187+
input, output, axis, is_ascend,
188+
[](OutType* out_ptr, size_t index, const std::pair<int64_t, DataType>& sort_pair) {
189+
out_ptr[index] = static_cast<OutType>(sort_pair.first);
190+
});
172191
}
173192

174193
template <typename DataType>
175194
void sort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) {
176-
return sort_impl<DataType, DataType>(input, output, axis, is_ascend, false);
195+
return sort_impl<DataType, DataType>(
196+
input, output, axis, is_ascend,
197+
[](DataType* out_ptr, size_t index, const std::pair<int64_t, DataType>& sort_pair) {
198+
out_ptr[index] = sort_pair.second;
199+
});
177200
}
178201

179202
// Argsort implemented C library sort.
@@ -254,6 +277,18 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort").set_body([](TVMArgs args, TVMRet
254277
} else {
255278
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
256279
}
280+
} else if (data_dtype == "float16") {
281+
if (out_dtype == "int32") {
282+
argsort<float16, int32_t>(input, output, axis, is_ascend);
283+
} else if (out_dtype == "int64") {
284+
argsort<float16, int64_t>(input, output, axis, is_ascend);
285+
} else if (out_dtype == "float32") {
286+
argsort<float16, float>(input, output, axis, is_ascend);
287+
} else if (out_dtype == "float64") {
288+
argsort<float16, double>(input, output, axis, is_ascend);
289+
} else {
290+
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
291+
}
257292
} else {
258293
LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
259294
}
@@ -295,6 +330,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.sort").set_body([](TVMArgs args, TVMRetVal
295330
sort<int32_t>(input, output, axis, is_ascend);
296331
} else if (data_dtype == "int64") {
297332
sort<int64_t>(input, output, axis, is_ascend);
333+
} else if (data_dtype == "float16") {
334+
sort<float16>(input, output, axis, is_ascend);
298335
} else {
299336
LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
300337
}
@@ -432,6 +469,18 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk").set_body([](TVMArgs args, TVMRetVal
432469
} else {
433470
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
434471
}
472+
} else if (data_dtype == "float16") {
473+
if (out_dtype == "int32") {
474+
topk<float16, int32_t>(input, values_out, indices_out, k, axis, is_ascend);
475+
} else if (out_dtype == "int64") {
476+
topk<float16, int64_t>(input, values_out, indices_out, k, axis, is_ascend);
477+
} else if (out_dtype == "float32") {
478+
topk<float16, float>(input, values_out, indices_out, k, axis, is_ascend);
479+
} else if (out_dtype == "float64") {
480+
topk<float16, double>(input, values_out, indices_out, k, axis, is_ascend);
481+
} else {
482+
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
483+
}
435484
} else {
436485
LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
437486
}

tests/python/relay/test_op_level6.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,23 @@
1616
# under the License.
1717
""" Support level6 operator test cases.
1818
"""
19+
import pytest
1920
import numpy as np
2021
import tvm
21-
from tvm import te
2222
from tvm import relay
2323
import tvm.testing
2424

2525

2626
@tvm.testing.uses_gpu
2727
def test_sort():
28-
def verify_sort(shape, axis, is_ascend, is_dyn=False):
29-
28+
def verify_sort(shape, axis, is_ascend, is_dyn=False, in_dtype="float32"):
3029
if is_dyn:
31-
x = relay.var("x", relay.TensorType([relay.Any()] * len(shape), "float32"))
30+
x = relay.var("x", relay.TensorType([relay.Any()] * len(shape), in_dtype))
3231
else:
33-
x = relay.var("x", relay.TensorType(shape, "float32"))
32+
x = relay.var("x", relay.TensorType(shape, in_dtype))
3433
z = relay.sort(x, axis=axis, is_ascend=is_ascend)
3534
func = relay.Function([x], z)
36-
x_data = np.random.uniform(size=shape).astype("float32")
35+
x_data = np.random.uniform(size=shape).astype(in_dtype)
3736
if is_ascend:
3837
ref_res = np.sort(x_data, axis=axis)
3938
else:
@@ -56,18 +55,19 @@ def verify_sort(shape, axis, is_ascend, is_dyn=False):
5655
verify_sort((3, 5, 6), axis=-1, is_ascend=False, is_dyn=is_dyn)
5756
verify_sort((3, 2000, 6), axis=1, is_ascend=False, is_dyn=is_dyn)
5857
verify_sort((1, 122640), axis=1, is_ascend=False, is_dyn=is_dyn)
58+
verify_sort((1, 122640), axis=1, is_ascend=False, is_dyn=is_dyn, in_dtype="float16")
5959

6060

6161
@tvm.testing.uses_gpu
6262
def test_argsort():
63-
def verify_argsort(shape, axis, is_ascend, dtype, is_dyn=False):
63+
def verify_argsort(shape, axis, is_ascend, dtype, is_dyn=False, in_dtype="float32"):
6464
if is_dyn:
65-
x = relay.var("x", relay.TensorType([relay.Any()] * len(shape), "float32"))
65+
x = relay.var("x", relay.TensorType([relay.Any()] * len(shape), in_dtype))
6666
else:
67-
x = relay.var("x", relay.TensorType(shape, "float32"))
67+
x = relay.var("x", relay.TensorType(shape, in_dtype))
6868
z = relay.argsort(x, axis=axis, is_ascend=is_ascend, dtype=dtype)
6969
func = relay.Function([x], z)
70-
x_data = np.random.uniform(size=shape).astype("float32")
70+
x_data = np.random.uniform(size=shape).astype(in_dtype)
7171
if is_ascend:
7272
ref_res = np.argsort(x_data, axis=axis, kind="stable")
7373
else:
@@ -93,31 +93,34 @@ def verify_argsort(shape, axis, is_ascend, dtype, is_dyn=False):
9393
verify_argsort((3, 6000, 6), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
9494
verify_argsort((1000, 1, 1), axis=0, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
9595
verify_argsort((1, 122640), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn)
96+
verify_argsort(
97+
(1, 122640), axis=1, is_ascend=False, dtype=dtype, is_dyn=is_dyn, in_dtype="float16"
98+
)
9699

97100

98101
@tvm.testing.uses_gpu
99102
def test_topk():
100-
def verify_topk(k, axis, ret_type, is_ascend, dtype):
103+
def verify_topk(k, axis, ret_type, is_ascend, dtype, in_dtype="float32"):
101104
shape = (20, 100)
102-
x = relay.var("x", relay.TensorType(shape, "float32"))
105+
x = relay.var("x", relay.TensorType(shape, in_dtype))
103106
out = relay.topk(x, k, axis, ret_type, is_ascend, dtype)
104107
if isinstance(out, relay.expr.TupleWrapper):
105108
out = out.astuple()
106109
func = relay.Function([x], out)
107-
np_data = np.random.uniform(size=shape).astype("float32")
110+
np_data = np.random.uniform(size=shape).astype(in_dtype)
108111
if is_ascend:
109-
np_indices = np.argsort(np_data, axis=axis)
112+
np_indices = np.argsort(np_data, axis=axis, kind="stable")
110113
else:
111-
np_indices = np.argsort(-np_data, axis=axis)
114+
np_indices = np.argsort(-np_data, axis=axis, kind="stable")
112115
kk = k if k >= 1 else shape[axis]
113116
if axis == 0:
114117
np_indices = np_indices[:kk, :]
115-
np_values = np.zeros(np_indices.shape).astype("float32")
118+
np_values = np.zeros(np_indices.shape).astype(in_dtype)
116119
for i in range(shape[1]):
117120
np_values[:, i] = np_data[np_indices[:, i], i]
118121
else:
119122
np_indices = np_indices[:, :kk]
120-
np_values = np.zeros(np_indices.shape).astype("float32")
123+
np_values = np.zeros(np_indices.shape).astype(in_dtype)
121124
for i in range(shape[0]):
122125
np_values[i, :] = np_data[i, np_indices[i, :]]
123126
np_indices = np_indices.astype(dtype)
@@ -140,9 +143,8 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype):
140143
for ret_type in ["both", "values", "indices"]:
141144
verify_topk(k, axis, ret_type, True, "int64")
142145
verify_topk(k, axis, ret_type, False, "float32")
146+
verify_topk(k, axis, ret_type, False, "int64", "float16")
143147

144148

145149
if __name__ == "__main__":
146-
test_sort()
147-
test_argsort()
148-
test_topk()
150+
pytest.main([__file__])

web/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
TVM_ROOT=$(shell cd ..; pwd)
1919

2020
INCLUDE_FLAGS = -I$(TVM_ROOT) -I$(TVM_ROOT)/include\
21-
-I$(TVM_ROOT)/3rdparty/dlpack/include -I$(TVM_ROOT)/3rdparty/dmlc-core/include
21+
-I$(TVM_ROOT)/3rdparty/dlpack/include -I$(TVM_ROOT)/3rdparty/dmlc-core/include -I$(TVM_ROOT)/3rdparty/compiler-rt
2222

2323
.PHONY: clean all rmtypedep preparetest
2424

0 commit comments

Comments
 (0)