Skip to content

Commit 63fcd98

Browse files
MasterJH5574yzh119
authored andcommitted
[SparseTIR][Schedule] GetSpIters (#24)
1 parent c4e039d commit 63fcd98

File tree

8 files changed

+60
-12
lines changed

8 files changed

+60
-12
lines changed

include/tvm/tir/schedule/schedule.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,12 @@ class ScheduleNode : public runtime::Object {
547547
* \note Indexing error is raised if 0 or multiple blocks exist with the specific name
548548
*/
549549
virtual SparseBlockRV GetSparseBlock(const String& name, const String& func_name = "main") = 0;
550+
/*!
551+
* \brief Retrieve the sparse iterators of a given sparse block
552+
* \param block_rv The block to be queried
553+
* \return The sparse iterators of the input sparse block
554+
*/
555+
virtual Array<SpIterVar> GetSpIters(const SparseBlockRV& block_rv) = 0;
550556
/*!
551557
* \brief Reorder a list of sparse iterators. It requires the new order to not break the iterator
552558
* dependency.

python/tvm/tir/schedule/schedule.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1930,6 +1930,24 @@ def get_sparse_block(
19301930
func_name,
19311931
)
19321932

1933+
def get_sp_iters(self, block: SparseBlockRV) -> List[SpIterVar]:
1934+
"""Retrieve the sparse iterators of a given sparse block
1935+
1936+
Parameters
1937+
----------
1938+
block : SparseBlockRV
1939+
The block to be queried
1940+
1941+
Returns
1942+
-------
1943+
sp_iters : List[SpIterVar]
1944+
The sparse iterators of the input sparse block
1945+
"""
1946+
return _ffi_api.ScheduleGetSpIters( # type: ignore # pylint: disable=no-member
1947+
self,
1948+
block,
1949+
)
1950+
19331951
def sparse_reorder(self, block: SparseBlockRV, new_order: List[SpIterVar]) -> None:
19341952
"""Reorder a list of sparse iterators. It requires the new order to not break the iterator
19351953
dependency.

src/tir/schedule/concrete_schedule.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,10 @@ SparseBlockRV ConcreteScheduleNode::GetSparseBlock(const String& name, const Str
700700
return CreateRV(GetRef<SparseBlock>(block));
701701
}
702702

703+
Array<SpIterVar> ConcreteScheduleNode::GetSpIters(const SparseBlockRV& block_rv) {
704+
return this->Get(block_rv)->sp_iter_vars;
705+
}
706+
703707
void ConcreteScheduleNode::SparseReorder(const SparseBlockRV& block_rv,
704708
const Array<SpIterVar>& new_order) {
705709
SparseBlock old_block = this->Get(block_rv);

src/tir/schedule/concrete_schedule.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ class ConcreteScheduleNode : public ScheduleNode {
135135
void EnterPostproc() override {}
136136
/******** Schedule: SparseTIR schedules ********/
137137
SparseBlockRV GetSparseBlock(const String& name, const String& func_name = "main") override;
138+
Array<SpIterVar> GetSpIters(const SparseBlockRV& block_rv) override;
138139
void SparseReorder(const SparseBlockRV& block_rv, const Array<SpIterVar>& new_order) override;
139140

140141
protected:

src/tir/schedule/schedule.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleEnterPostproc")
228228
/******** (FFI) SparseTIR schedules ********/
229229
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetSparseBlock")
230230
.set_body_method<Schedule>(&ScheduleNode::GetSparseBlock);
231+
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetSpIters")
232+
.set_body_method<Schedule>(&ScheduleNode::GetSpIters);
231233
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSparseReorder")
232234
.set_body_method<Schedule>(&ScheduleNode::SparseReorder);
233235

src/tir/schedule/traced_schedule.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,14 @@ SparseBlockRV TracedScheduleNode::GetSparseBlock(const String& name, const Strin
416416
return result;
417417
}
418418

419+
Array<SpIterVar> TracedScheduleNode::GetSpIters(const SparseBlockRV& block_rv) {
420+
Array<SpIterVar> result = ConcreteScheduleNode::GetSpIters(block_rv);
421+
422+
// Do not support traced schedule so far.
423+
424+
return result;
425+
}
426+
419427
void TracedScheduleNode::SparseReorder(const SparseBlockRV& block_rv,
420428
const Array<SpIterVar>& new_order) {
421429
ConcreteScheduleNode::SparseReorder(block_rv, new_order);

src/tir/schedule/traced_schedule.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ class TracedScheduleNode : public ConcreteScheduleNode {
9696
void EnterPostproc() final;
9797
/******** Schedule: SparseTIR schedules ********/
9898
SparseBlockRV GetSparseBlock(const String& name, const String& func_name = "main") final;
99+
Array<SpIterVar> GetSpIters(const SparseBlockRV& block_rv) final;
99100
void SparseReorder(const SparseBlockRV& block, const Array<SpIterVar>& new_order) final;
100101
};
101102

tests/python/sparsetir/test_tir_sparse_schedule.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -408,35 +408,43 @@ def test_get_sparse_block():
408408
assert block.same_as(csrmm.body)
409409

410410

411+
def test_get_sp_iters():
412+
sch = tir.Schedule(csrmm, debug_mask="all")
413+
block = sch.get_sparse_block("csrmm")
414+
vi, vj, vk = sch.get_sp_iters(block)
415+
assert vi.same_as(csrmm.body.sp_iter_vars[0])
416+
assert vj.same_as(csrmm.body.sp_iter_vars[1])
417+
assert vk.same_as(csrmm.body.sp_iter_vars[2])
418+
419+
411420
def test_reorder():
412421
sch = tir.Schedule(bsrmm, debug_mask="all")
413-
block_rv = sch.get_sparse_block("bsrmm")
414-
block = sch.get(block_rv)
415-
i, j, bi, bj, f = block.sp_iter_vars
416-
sch.sparse_reorder(block_rv, [bi, bj, i, j, f])
422+
block = sch.get_sparse_block("bsrmm")
423+
i, j, bi, bj, f = sch.get_sp_iters(block)
424+
sch.sparse_reorder(block, [bi, bj, i, j, f])
417425
tvm.ir.assert_structural_equal(sch.mod["main"], reordered_bsrmm, True)
426+
assert sch.get(block).name == "bsrmm"
418427

419428

420429
def test_reorder_fail_on_dependency():
421430
sch = tir.Schedule(bsrmm, debug_mask="all")
422-
block_rv = sch.get_sparse_block("bsrmm")
423-
block = sch.get(block_rv)
424-
i, j, bi, bj, f = block.sp_iter_vars
431+
block = sch.get_sparse_block("bsrmm")
432+
i, j, bi, bj, f = sch.get_sp_iters(block)
425433
with pytest.raises(tvm.tir.ScheduleError):
426-
sch.sparse_reorder(block_rv, [bi, bj, j, i, f])
434+
sch.sparse_reorder(block, [bi, bj, j, i, f])
427435

428436

429437
def test_reorder_fail_on_new_order_length():
430438
sch = tir.Schedule(bsrmm, debug_mask="all")
431-
block_rv = sch.get_sparse_block("bsrmm")
432-
block = sch.get(block_rv)
433-
i, j, bi, bj, f = block.sp_iter_vars
439+
block = sch.get_sparse_block("bsrmm")
440+
i, j, bi, bj, f = sch.get_sp_iters(block)
434441
with pytest.raises(tvm.tir.ScheduleError):
435-
sch.sparse_reorder(block_rv, [bi, bj, i, j])
442+
sch.sparse_reorder(block, [bi, bj, i, j])
436443

437444

438445
if __name__ == "__main__":
439446
test_get_sparse_block()
447+
test_get_sp_iters()
440448
test_reorder()
441449
test_reorder_fail_on_dependency()
442450
test_reorder_fail_on_new_order_length()

0 commit comments

Comments
 (0)