Skip to content

Commit e2b64ef

Browse files
authored
[SparseTIR] Hack IsAffineBinding check (#27)
* [TensorIR][Schedule] Inherit block anotation upon creating new blocks * Fix SDDMM test * Hack IsAffineBinding for sparse blocks
1 parent 0118eb3 commit e2b64ef

File tree

6 files changed

+110
-9
lines changed

6 files changed

+110
-9
lines changed

include/tvm/tir/schedule/state.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,13 @@ class ScheduleStateNode : public Object {
162162
* \return A boolean flag indicating if the block has quasi-affine bindings
163163
*/
164164
bool IsAffineBlockBinding(const StmtSRef& block_sref) const {
165+
// (SparseTIR Hack) Always return true for sparse blocks.
166+
const auto* block = block_sref->StmtAs<BlockNode>();
167+
Optional<ObjectRef> sparse_attr = block != nullptr ? block->annotations.Get("sparse") : NullOpt;
168+
if (sparse_attr.defined() && sparse_attr.as<IntImmNode>()->value == 1) {
169+
return true;
170+
}
171+
165172
return GetBlockInfo(block_sref).affine_binding;
166173
}
167174
/*!

src/tir/schedule/analysis/analysis.cc

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -168,12 +168,7 @@ Definition of a scope that is a stage pipeline:
168168
if (it_atomic != block->annotations.end()) {
169169
is_atomic = ((*it_atomic).second).as<IntImmNode>()->value;
170170
}
171-
auto&& it_sparse = block->annotations.find("sparse");
172-
bool is_sparse = false;
173-
if (it_sparse != block->annotations.end()) {
174-
is_sparse = ((*it_sparse).second).as<IntImmNode>()->value;
175-
}
176-
if (!is_sparse && !is_atomic) {
171+
if (!is_atomic) {
177172
throw NotCompactDataFlowError(self->mod, GetRef<Stmt>(scope_root_subtree->stmt),
178173
GetRef<Block>(block));
179174
}
@@ -445,6 +440,12 @@ void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref,
445440

446441
bool IsAffineBinding(const BlockRealize& realize, const Map<Var, Range>& loop_var_ranges,
447442
arith::Analyzer* analyzer) {
443+
// (SparseTIR Hack) Always return true for sparse blocks.
444+
Optional<ObjectRef> sparse_attr = realize->block->annotations.Get("sparse");
445+
if (sparse_attr.defined() && sparse_attr.as<IntImmNode>()->value == 1) {
446+
return true;
447+
}
448+
448449
if (loop_var_ranges.empty()) {
449450
return true;
450451
}

src/tir/schedule/primitive/reduction.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
212212
ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
213213
ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
214214
init_block->name_hint = block->name_hint + "_init";
215+
init_block->annotations = block->annotations;
215216
init_realize->iter_values = {};
216217
init_realize->block = Block(init_block);
217218
// Step 1. Create new block vars and their bindings
@@ -580,7 +581,10 @@ class BaseBlockCreator {
580581
/*body=*/new_reduction_update_,
581582
/*init=*/
582583
BufferStore(new_reduction_update_->buffer, reducer_->identity_element[0],
583-
new_reduction_update_->indices));
584+
new_reduction_update_->indices),
585+
/*alloc_buffers=*/{},
586+
/*match_buffers=*/{},
587+
/*annotations=*/old_block_realize_->block->annotations);
584588
new_block_realize_ = BlockRealize(iter_values_, predicate, new_block_);
585589
}
586590

tests/python/sparsetir/test_tir_sparse_correctness.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,10 +267,9 @@ def test_sddmm():
267267
)
268268
blk = sch.get_block("sddmm")
269269
ij, k = sch.get_loops(blk)
270-
# TODO(zihao): fix the behavior in the future.
271-
# sch.decompose_reduction(blk, ij)
272270
sch.bind(ij, "blockIdx.x")
273271
sch.bind(k, "threadIdx.x")
272+
sch.decompose_reduction(blk, k)
274273

275274
# convert numpy tensor to tvm ndarray
276275
C_indices = tvm.nd.array(indices.astype("int32"), device=tvm.cuda(0))

tests/python/unittest/test_tir_schedule_reduction.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,39 @@ def matmul_decompose4(a: T.handle, b: T.handle, c: T.handle) -> None:
152152
C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
153153

154154

155+
@T.prim_func
156+
def matmul_with_annotation(a: T.handle, b: T.handle, c: T.handle) -> None:
157+
A = T.match_buffer(a, [128, 128])
158+
B = T.match_buffer(b, [128, 128])
159+
C = T.match_buffer(c, [128, 128])
160+
for i, j, k in T.grid(128, 128, 128):
161+
with T.block("update"):
162+
T.block_attr({"test_annotation": 1})
163+
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
164+
with T.init():
165+
C[vi, vj] = 0.0
166+
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
167+
168+
169+
@T.prim_func
170+
def matmul_decompose_with_annotation(a: T.handle, b: T.handle, c: T.handle) -> None:
171+
A = T.match_buffer(a, [128, 128])
172+
B = T.match_buffer(b, [128, 128])
173+
C = T.match_buffer(c, [128, 128])
174+
175+
for i, j in T.grid(128, 128):
176+
with T.block("init"):
177+
T.block_attr({"test_annotation": 1})
178+
vi, vj = T.axis.remap("SS", [i, j])
179+
C[vi, vj] = 0.0
180+
181+
for i, j, k in T.grid(128, 128, 128):
182+
with T.block("update"):
183+
T.block_attr({"test_annotation": 1})
184+
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
185+
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
186+
187+
155188
# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
156189

157190

@@ -201,5 +234,14 @@ def test_reduction_decompose4():
201234
verify_trace_roundtrip(s, mod=matmul)
202235

203236

237+
def test_reduction_decompose_with_annotation():
238+
s = tir.Schedule(matmul_with_annotation, debug_mask="all")
239+
C = s.get_block("update")
240+
i, j, k = s.get_loops(C)
241+
s.decompose_reduction(C, i)
242+
tvm.ir.assert_structural_equal(matmul_decompose_with_annotation, s.mod["main"])
243+
verify_trace_roundtrip(s, mod=matmul_with_annotation)
244+
245+
204246
if __name__ == "__main__":
205247
sys.exit(pytest.main([__file__] + sys.argv[1:]))

tests/python/unittest/test_tir_schedule_rfactor.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,43 @@ def square_sum_square_root_rfactor(a: T.handle, d: T.handle) -> None:
214214
D[b_2] = T.sqrt(C[b_2], dtype="float32")
215215

216216

217+
@T.prim_func
218+
def square_sum_with_annotation(a: T.handle, c: T.handle) -> None:
219+
A = T.match_buffer(a, [16, 256, 256])
220+
C = T.match_buffer(c, [16])
221+
222+
for b0, i0, j0 in T.grid(16, 256, 256):
223+
with T.block("C"):
224+
T.block_attr({"test_annotation": 1})
225+
b, i, j = T.axis.remap("SRR", [b0, i0, j0])
226+
with T.init():
227+
C[b] = 0.0
228+
C[b] = C[b] + A[b, i, j] * A[b, i, j]
229+
230+
231+
@T.prim_func
232+
def square_sum_with_annotation_rfactor(a: T.handle, c: T.handle) -> None:
233+
A = T.match_buffer(a, [16, 256, 256])
234+
C = T.match_buffer(c, [16])
235+
C_rf = T.alloc_buffer([16, 256])
236+
237+
for i0, i1, i2 in T.grid(16, 256, 256):
238+
with T.block("C_rf"):
239+
T.block_attr({"test_annotation": 1})
240+
vi2, b, i = T.axis.remap("SSR", [i2, i0, i1])
241+
with T.init():
242+
C_rf[b, vi2] = 0.0
243+
C_rf[b, vi2] = C_rf[b, vi2] + (A[b, i, vi2] * A[b, i, vi2])
244+
245+
for i0_1, i2_1 in T.grid(16, 256):
246+
with T.block("C"):
247+
T.block_attr({"test_annotation": 1})
248+
vi2_1, b_1 = T.axis.remap("RS", [i2_1, i0_1])
249+
with T.init():
250+
C[b_1] = 0.0
251+
C[b_1] = C[b_1] + C_rf[b_1, vi2_1]
252+
253+
217254
@T.prim_func
218255
def element_wise(a: T.handle, b: T.handle) -> None:
219256
A = T.match_buffer(a, (128, 128))
@@ -660,5 +697,16 @@ def test_reduction_rfactor_predicate(): # pylint: disable=invalid-name
660697
verify_trace_roundtrip(s, mod=rowsum_predicate)
661698

662699

700+
def test_reduction_rfactor_with_annotation():
701+
s = tir.Schedule(square_sum_with_annotation, debug_mask="all")
702+
C = s.get_block("C")
703+
_, _, j = s.get_loops(C)
704+
rf_block = s.rfactor(j, 1)
705+
tvm.ir.assert_structural_equal(s.mod["main"], square_sum_with_annotation_rfactor)
706+
assert s.get(rf_block).same_as(s.get(s.get_block("C_rf")))
707+
assert s.get(C).same_as(s.get(s.get_block("C")))
708+
verify_trace_roundtrip(s, mod=square_sum_with_annotation)
709+
710+
663711
if __name__ == "__main__":
664712
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 commit comments

Comments
 (0)