@@ -408,35 +408,43 @@ def test_get_sparse_block():
408408 assert block .same_as (csrmm .body )
409409
410410
411+ def test_get_sp_iters ():
412+ sch = tir .Schedule (csrmm , debug_mask = "all" )
413+ block = sch .get_sparse_block ("csrmm" )
414+ vi , vj , vk = sch .get_sp_iters (block )
415+ assert vi .same_as (csrmm .body .sp_iter_vars [0 ])
416+ assert vj .same_as (csrmm .body .sp_iter_vars [1 ])
417+ assert vk .same_as (csrmm .body .sp_iter_vars [2 ])
418+
419+
411420def test_reorder ():
412421 sch = tir .Schedule (bsrmm , debug_mask = "all" )
413- block_rv = sch .get_sparse_block ("bsrmm" )
414- block = sch .get (block_rv )
415- i , j , bi , bj , f = block .sp_iter_vars
416- sch .sparse_reorder (block_rv , [bi , bj , i , j , f ])
422+ block = sch .get_sparse_block ("bsrmm" )
423+ i , j , bi , bj , f = sch .get_sp_iters (block )
424+ sch .sparse_reorder (block , [bi , bj , i , j , f ])
417425 tvm .ir .assert_structural_equal (sch .mod ["main" ], reordered_bsrmm , True )
426+ assert sch .get (block ).name == "bsrmm"
418427
419428
420429def test_reorder_fail_on_dependency ():
421430 sch = tir .Schedule (bsrmm , debug_mask = "all" )
422- block_rv = sch .get_sparse_block ("bsrmm" )
423- block = sch .get (block_rv )
424- i , j , bi , bj , f = block .sp_iter_vars
431+ block = sch .get_sparse_block ("bsrmm" )
432+ i , j , bi , bj , f = sch .get_sp_iters (block )
425433 with pytest .raises (tvm .tir .ScheduleError ):
426- sch .sparse_reorder (block_rv , [bi , bj , j , i , f ])
434+ sch .sparse_reorder (block , [bi , bj , j , i , f ])
427435
428436
429437def test_reorder_fail_on_new_order_length ():
430438 sch = tir .Schedule (bsrmm , debug_mask = "all" )
431- block_rv = sch .get_sparse_block ("bsrmm" )
432- block = sch .get (block_rv )
433- i , j , bi , bj , f = block .sp_iter_vars
439+ block = sch .get_sparse_block ("bsrmm" )
440+ i , j , bi , bj , f = sch .get_sp_iters (block )
434441 with pytest .raises (tvm .tir .ScheduleError ):
435- sch .sparse_reorder (block_rv , [bi , bj , i , j ])
442+ sch .sparse_reorder (block , [bi , bj , i , j ])
436443
437444
438445if __name__ == "__main__" :
439446 test_get_sparse_block ()
447+ test_get_sp_iters ()
440448 test_reorder ()
441449 test_reorder_fail_on_dependency ()
442450 test_reorder_fail_on_new_order_length ()
0 commit comments