Skip to content

Commit b49b752

Browse files
MasterJH5574yzh119
andcommitted
[CherryPick][Intrinsic] lower_bound and upper_bound for binary search in Sparse TIR. (apache#483) (#4)
* upd * upd * fix * upd * upd * upd * upd * upd * fix * upd * upd * upd * upd * upd * upd * upd * codegen-rule * upd * upd * test * upd * fix * two arguments Co-authored-by: Zihao Ye <expye@outlook.com>
1 parent d4b4550 commit b49b752

File tree

10 files changed

+247
-1
lines changed

10 files changed

+247
-1
lines changed

include/tvm/tir/builtin.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,16 @@ TVM_DLL const Op& tvm_warp_shuffle_up();
494494
TVM_DLL const Op& tvm_warp_shuffle_down();
495495
TVM_DLL const Op& tvm_warp_activemask();
496496

497+
/*!
498+
* \brief Lower bound function for binary search.
499+
*/
500+
TVM_DLL const Op& tvm_lower_bound();
501+
502+
/*!
503+
* \brief Upper bound function for binary search.
504+
*/
505+
TVM_DLL const Op& tvm_upper_bound();
506+
497507
/*!
498508
* \brief Initialize the global barrier.
499509
* Call this at beginning of kernel that need global barrier.

python/tvm/script/tir/intrin.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,16 @@ def max_value(dtype, span):
111111
return tvm.tir.max_value(dtype, span)
112112

113113

114+
@register
115+
def lower_bound(arr, val, l, r, span):
116+
return tvm.tir.lower_bound(arr, val, l, r, span)
117+
118+
119+
@register
120+
def upper_bound(arr, val, l, r, span):
121+
return tvm.tir.upper_bound(arr, val, l, r, span)
122+
123+
114124
@register
115125
def floordiv(x, y, span):
116126
return tvm.tir.floordiv(x, y, span)

python/tvm/tir/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from .function import PrimFunc, TensorIntrin
3737

3838
from .op import call_packed, call_intrin, call_pure_extern, call_extern
39-
from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace
39+
from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace, lower_bound, upper_bound
4040
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
4141
from .op import sin, sinh, asin, asinh
4242
from .op import cos, cosh, acos, acosh

python/tvm/tir/op.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -971,6 +971,62 @@ def ldexp(x1, x2):
971971
return call_intrin(x1.dtype, "tir.ldexp", x1, x2) # type: ignore
972972

973973

974+
def lower_bound(arr, val, l, r, span=None):
975+
"""Return the position to the first element in the arr[l:r] that is no less than val.
976+
977+
Parameters
978+
----------
979+
arr : Var
980+
Pointer to the 1D buffer to apply binary search on.
981+
982+
val : PrimExpr
983+
Value of the lower bound to search for in the buffer.
984+
985+
l : PrimExpr
986+
Start position to search for in the buffer.
987+
988+
r : PrimExpr
989+
End position to search for in the buffer.
990+
991+
span : Optional[Span]
992+
The location of this expression in the source code.
993+
994+
Returns
995+
-------
996+
PrimExpr
997+
The index of element in arr[l:r] that is no less then given value.
998+
"""
999+
return _ffi_api.lower_bound(arr, val, l, r, span) # type: ignore
1000+
1001+
1002+
def upper_bound(arr, val, l, r, span=None):
1003+
"""Return the position the first element in the arr that is greater than val.
1004+
1005+
Parameters
1006+
----------
1007+
arr : Var
1008+
Pointer to the 1D buffer to apply binary search on.
1009+
1010+
val : PrimExpr
1011+
Value of the upper bound to search for in the buffer.
1012+
1013+
l : PrimExpr
1014+
Start position to search for in the buffer.
1015+
1016+
r : PrimExpr
1017+
End position to search for in the buffer.
1018+
1019+
span : Optional[Span]
1020+
The location of this expression in the source code.
1021+
1022+
Returns
1023+
-------
1024+
PrimExpr
1025+
The index of element in arr[l:r] that is no less then given value.
1026+
"""
1027+
return _ffi_api.upper_bound(arr, val, l, r, span) # type: ignore
1028+
1029+
9741030
def isnan(x, span=None):
9751031
"""Check if input value is Nan.
9761032

src/target/source/codegen_cuda.cc

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include <utility>
3333
#include <vector>
3434

35+
#include "literal/cuda_binary_search.h"
3536
#include "literal/cuda_half_t.h"
3637
#include "ptx_mma.h"
3738

@@ -133,6 +134,10 @@ std::string CodeGenCUDA::Finish() {
133134
decl_stream << "#include <mma.h>\n";
134135
}
135136

137+
if (need_binary_search_) {
138+
decl_stream << _cuda_binary_search_def;
139+
}
140+
136141
decl_stream << "\n#ifdef _WIN32\n";
137142
decl_stream << " using uint = unsigned int;\n";
138143
decl_stream << " using uchar = unsigned char;\n";
@@ -756,6 +761,30 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
756761
a_ref, a_bias, b_ref, b_bias, c_ref, c_bias, saturate);
757762

758763
this->stream << asm_code;
764+
} else if (op->op.same_as(builtin::tvm_lower_bound())) {
765+
need_binary_search_ = true;
766+
os << "__lower_bound(";
767+
ICHECK_EQ(op->args.size(), 4U);
768+
this->PrintExpr(op->args[0], os);
769+
os << ", ";
770+
this->PrintExpr(op->args[1], os);
771+
os << ", ";
772+
this->PrintExpr(op->args[2], os);
773+
os << ", ";
774+
this->PrintExpr(op->args[3], os);
775+
os << ")";
776+
} else if (op->op.same_as(builtin::tvm_upper_bound())) {
777+
need_binary_search_ = true;
778+
os << "__upper_bound(";
779+
ICHECK_EQ(op->args.size(), 4U);
780+
this->PrintExpr(op->args[0], os);
781+
os << ", ";
782+
this->PrintExpr(op->args[1], os);
783+
os << ", ";
784+
this->PrintExpr(op->args[2], os);
785+
os << ", ";
786+
this->PrintExpr(op->args[3], os);
787+
os << ")";
759788
} else {
760789
CodeGenC::VisitExpr_(op, os);
761790
}

src/target/source/codegen_cuda.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ class CodeGenCUDA final : public CodeGenC {
9999
bool need_math_constants_h_{false};
100100
// whether need mma.h
101101
bool need_mma_h_{false};
102+
// whether need binary search
103+
bool need_binary_search_{false};
102104
// Op attribute map
103105
OpAttrMap<bool> op_need_warp_shuffle_ = Op::GetAttrMap<bool>("cuda.need_warp_shuffle");
104106

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file cuda_binary_search.h
22+
* \brief Binary search function definition for cuda codegen.
23+
*/
24+
#ifndef TVM_TARGET_SOURCE_LITERAL_CUDA_BINARY_SEARCH_H_
25+
#define TVM_TARGET_SOURCE_LITERAL_CUDA_BINARY_SEARCH_H_
26+
27+
static constexpr const char* _cuda_binary_search_def = R"(
28+
template <typename DType>
29+
__forceinline__ __device__ int32_t __lower_bound(
30+
const DType* __restrict__ arr,
31+
DType val,
32+
int32_t l,
33+
int32_t r) {
34+
int32_t low = l - 1, high = r;
35+
/* loop invariant: low < mid < high, arr[low] < val, arr[high] >= val */
36+
while (low + 1 < high) {
37+
int32_t mid = (low + high) >> 1;
38+
if (arr[mid] < val) {
39+
low = mid;
40+
} else {
41+
high = mid;
42+
}
43+
}
44+
// high = low + 1, arr[low] < val, arr[high] >= val
45+
return high;
46+
}
47+
48+
template <typename DType>
49+
__forceinline__ __device__ int32_t __upper_bound(
50+
const DType* __restrict__ arr,
51+
DType val,
52+
int32_t l,
53+
int32_t r) {
54+
int32_t low = l - 1, high = r;
55+
/* loop invariant: low < mid < high, arr[low] < val, arr[high] > val */
56+
while (low + 1 < high) {
57+
int32_t mid = (low + high) >> 1;
58+
if (arr[mid] > val) {
59+
high = mid;
60+
} else {
61+
low = mid;
62+
}
63+
}
64+
// high = low + 1, arr[low] <= val, arr[high] > val
65+
return high;
66+
}
67+
)";
68+
69+
#endif // TVM_TARGET_SOURCE_LITERAL_CUDA_BINARY_SEARCH_H_

src/tir/op/builtin.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,12 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_thread_allreduce)
222222
TIR_DEFINE_BUILTIN_FUNC(tvm_load_matrix_sync)
223223
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kReadState));
224224

225+
TIR_DEFINE_BUILTIN_FUNC(tvm_lower_bound)
226+
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
227+
228+
TIR_DEFINE_BUILTIN_FUNC(tvm_upper_bound)
229+
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
230+
225231
TIR_DEFINE_BUILTIN_FUNC(tvm_mma_sync)
226232
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
227233

src/tir/op/op.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,16 @@ PrimExpr nearbyint(PrimExpr x, Span span) {
804804

805805
TIR_REGISTER_PURE_UNARY_OP("tir.nearbyint");
806806

807+
// lower_bound
808+
PrimExpr lower_bound(Var arr, PrimExpr val, PrimExpr l, PrimExpr r, Span span) {
809+
return tir::Call({kDLInt, 32, 1}, builtin::tvm_lower_bound(), {arr, val, l, r}, span);
810+
}
811+
812+
// upper_bound
813+
PrimExpr upper_bound(Var arr, PrimExpr val, PrimExpr l, PrimExpr r, Span span) {
814+
return tir::Call({kDLInt, 32, 1}, builtin::tvm_upper_bound(), {arr, val, l, r}, span);
815+
}
816+
807817
// trunc
808818
PrimExpr trunc(PrimExpr x, Span span) {
809819
if (x.dtype().is_int() || x.dtype().is_uint()) {
@@ -918,6 +928,10 @@ TVM_REGISTER_GLOBAL("tir.trunc").set_body_typed(tvm::trunc);
918928

919929
TVM_REGISTER_GLOBAL("tir._cast").set_body_typed(tvm::cast);
920930

931+
TVM_REGISTER_GLOBAL("tir.lower_bound").set_body_typed(tvm::lower_bound);
932+
933+
TVM_REGISTER_GLOBAL("tir.upper_bound").set_body_typed(tvm::upper_bound);
934+
921935
// operator overloading, smarter than make
922936
#define REGISTER_MAKE_BINARY_OP(Node, Func) \
923937
TVM_REGISTER_GLOBAL("tir." #Node).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { \

tests/python/unittest/test_tir_intrin.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,55 @@ def test_fma():
253253
assert mod["test_tir_fma"].body.body.value.op.name == "tir.call_llvm_pure_intrin"
254254

255255

256+
@tvm.script.tir
257+
def binary_search(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle) -> None:
258+
n = tir.var('int32')
259+
m = tir.var('int32')
260+
A = tir.match_buffer(a, (n,), dtype='int32')
261+
B = tir.match_buffer(b, (m,), dtype='int32')
262+
C = tir.match_buffer(c, (m,), dtype='int32')
263+
D = tir.match_buffer(d, (m,), dtype='int32')
264+
with tir.block([m], 'search') as [vi]:
265+
tir.reads([A[0:n], B[vi]])
266+
tir.writes([C[vi], D[vi]])
267+
C[vi] = tir.lower_bound(A.data, B[vi], 0, n)
268+
D[vi] = tir.upper_bound(A.data, B[vi], 0, n)
269+
270+
271+
def test_binary_search():
272+
sch = tir.Schedule(binary_search)
273+
b = sch.get_block('search')
274+
i, = sch.get_loops(b)
275+
io, ii = sch.split(i, [1, None])
276+
sch.bind(io, 'threadIdx.x')
277+
sch.bind(ii, 'blockIdx.x')
278+
f = tvm.build(sch.mod['main'], target='cuda')
279+
# print(f.imported_modules[0].get_source())
280+
281+
x = np.arange(-128, 128).astype(np.int32)
282+
y = np.random.randint(-200, 200, size=1024).astype(np.int32)
283+
a = np.zeros((1024,)).astype(np.int32)
284+
b = np.zeros((1024,)).astype(np.int32)
285+
286+
# numpy results
287+
np_a = np.searchsorted(x, y, side='left').astype(np.int32)
288+
np_b = np.searchsorted(x, y, side='right').astype(np.int32)
289+
290+
# tvm results
291+
dev = tvm.cuda(0)
292+
x_array = tvm.nd.array(x, device=dev)
293+
y_array = tvm.nd.array(y, device=dev)
294+
a_array = tvm.nd.array(a, device=dev)
295+
b_array = tvm.nd.array(b, device=dev)
296+
f(x_array, y_array, a_array, b_array)
297+
tvm_a = a_array.numpy()
298+
tvm_b = b_array.numpy()
299+
300+
# verify result
301+
tvm.testing.assert_allclose(np_a, tvm_a)
302+
tvm.testing.assert_allclose(np_b, tvm_b)
303+
304+
256305
if __name__ == "__main__":
257306
test_nearbyint()
258307
test_unary_intrin()
@@ -261,3 +310,4 @@ def test_fma():
261310
test_ldexp()
262311
test_clz()
263312
test_fma()
313+
test_binary_search()

0 commit comments

Comments
 (0)