@@ -253,6 +253,55 @@ def test_fma():
253253 assert mod ["test_tir_fma" ].body .body .value .op .name == "tir.call_llvm_pure_intrin"
254254
255255
256+ @tvm .script .tir
257+ def binary_search (a : ty .handle , b : ty .handle , c : ty .handle , d : ty .handle ) -> None :
258+ n = tir .var ('int32' )
259+ m = tir .var ('int32' )
260+ A = tir .match_buffer (a , (n ,), dtype = 'int32' )
261+ B = tir .match_buffer (b , (m ,), dtype = 'int32' )
262+ C = tir .match_buffer (c , (m ,), dtype = 'int32' )
263+ D = tir .match_buffer (d , (m ,), dtype = 'int32' )
264+ with tir .block ([m ], 'search' ) as [vi ]:
265+ tir .reads ([A [0 :n ], B [vi ]])
266+ tir .writes ([C [vi ], D [vi ]])
267+ C [vi ] = tir .lower_bound (A .data , B [vi ], 0 , n )
268+ D [vi ] = tir .upper_bound (A .data , B [vi ], 0 , n )
269+
270+
271+ def test_binary_search ():
272+ sch = tir .Schedule (binary_search )
273+ b = sch .get_block ('search' )
274+ i , = sch .get_loops (b )
275+ io , ii = sch .split (i , [1 , None ])
276+ sch .bind (io , 'threadIdx.x' )
277+ sch .bind (ii , 'blockIdx.x' )
278+ f = tvm .build (sch .mod ['main' ], target = 'cuda' )
279+ # print(f.imported_modules[0].get_source())
280+
281+ x = np .arange (- 128 , 128 ).astype (np .int32 )
282+ y = np .random .randint (- 200 , 200 , size = 1024 ).astype (np .int32 )
283+ a = np .zeros ((1024 ,)).astype (np .int32 )
284+ b = np .zeros ((1024 ,)).astype (np .int32 )
285+
286+ # numpy results
287+ np_a = np .searchsorted (x , y , side = 'left' ).astype (np .int32 )
288+ np_b = np .searchsorted (x , y , side = 'right' ).astype (np .int32 )
289+
290+ # tvm results
291+ dev = tvm .cuda (0 )
292+ x_array = tvm .nd .array (x , device = dev )
293+ y_array = tvm .nd .array (y , device = dev )
294+ a_array = tvm .nd .array (a , device = dev )
295+ b_array = tvm .nd .array (b , device = dev )
296+ f (x_array , y_array , a_array , b_array )
297+ tvm_a = a_array .numpy ()
298+ tvm_b = b_array .numpy ()
299+
300+ # verify result
301+ tvm .testing .assert_allclose (np_a , tvm_a )
302+ tvm .testing .assert_allclose (np_b , tvm_b )
303+
304+
256305if __name__ == "__main__" :
257306 test_nearbyint ()
258307 test_unary_intrin ()
@@ -261,3 +310,4 @@ def test_fma():
261310 test_ldexp ()
262311 test_clz ()
263312 test_fma ()
313+ test_binary_search ()
0 commit comments