Skip to content

Commit a42bf57

Browse files
authored
Change the order of generated blocks for block isolation. (apache#35)
* upd * upd * upd
1 parent 8620909 commit a42bf57

File tree

2 files changed

+77
-82
lines changed

2 files changed

+77
-82
lines changed

src/tir/transforms/lower_sparse_tir.cc

Lines changed: 74 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
#include <tvm/tir/transform.h>
2828

2929
#include <set>
30-
#include <stack>
3130
#include <utility>
3231

3332
#include "../../support/utils.h"
@@ -521,84 +520,84 @@ class IndexTransformer : public StmtExprMutator {
521520
var_map[sp_iter_var->var.get()] = loop_var;
522521
}
523522

524-
// Step 4. Collet block iters and iter bindings.
525-
std::set<const AxisNode*> in_stack;
523+
// Step 4. Collect block iters and iter bindings.
524+
/* Whether the axis appears in the stack. */
525+
std::unordered_set<const AxisNode*> in_stack;
526526
/* A stack that stores block itervars in each block. */
527-
std::stack<Array<IterVar>> block_iters_st;
527+
std::vector<Array<IterVar>> block_iters_st;
528528
/* A stack that stores itervar bindings in each block. */
529-
std::stack<Array<PrimExpr>> iter_bindings_st;
529+
std::vector<Array<PrimExpr>> iter_bindings_st;
530530
/* A stack that stores generated loop vars in each block. */
531-
std::stack<Array<Var>> loop_vars_st;
531+
std::vector<Array<Var>> loop_vars_st;
532532
/* A stack that stores whether to place init block in each block. */
533-
std::stack<bool> place_init_st;
533+
std::vector<bool> place_init_st;
534534
/* An indicator that records whether init block has been set. */
535535
bool init_set = false;
536-
do {
537-
/* Block itervars of current block. */
538-
Array<IterVar> block_iters;
539-
/* Itervar bindings of current block. */
540-
Array<PrimExpr> iter_bindings;
541-
/* Axis names of current block. */
542-
Array<Axis> blk_axes;
543-
/* Generated loop vars of current block. */
544-
Array<Var> loop_vars;
545-
/* An indicator that records whether there is reduction axis in current block. */
546-
bool has_reduction_var = false;
547-
for (int i = 0; i < n_iter; ++i) {
548-
SpIterVar sp_it_var = sp_block->sp_iter_vars[i];
549-
Axis axis = sp_it_var->axis;
550-
551-
/* Add itervar to current block when
552-
* - it's not used yet (not in stack) and
553-
* - it's parent axis was used in outer blocks or
554-
* - it's an iterator to a fixed axis.
555-
*/
556-
auto parent = axis->GetParentAxis();
557-
bool emit_iter_var = true;
558-
if (in_stack.find(axis.get()) !=
559-
in_stack.end()) { // the iter var has already been emitted.
560-
emit_iter_var = false;
536+
/* Block itervars of current block. */
537+
Array<IterVar> block_iters;
538+
/* Itervar bindings of current block. */
539+
Array<PrimExpr> iter_bindings;
540+
/* Generated loop vars of current block. */
541+
Array<Var> loop_vars;
542+
/* Whether the axis appears in the current block. */
543+
std::unordered_set<const AxisNode*> in_block;
544+
/* An indicator that records whether there is reduction axis in current block. */
545+
bool has_reduction_var = false;
546+
547+
auto UpdateStack = [&]() {
548+
block_iters_st.emplace_back(std::move(block_iters));
549+
iter_bindings_st.emplace_back(std::move(iter_bindings));
550+
loop_vars_st.emplace_back(std::move(loop_vars));
551+
if (init_set) {
552+
place_init_st.emplace_back(false);
553+
} else {
554+
place_init_st.emplace_back(has_reduction_var);
555+
init_set |= has_reduction_var;
556+
}
557+
};
558+
559+
for (int i = 0; i < n_iter; ++i) {
560+
SpIterVar sp_it_var = sp_block->sp_iter_vars[i];
561+
Axis axis = sp_it_var->axis;
562+
auto parent = axis->GetParentAxis();
563+
bool create_new_blk = false;
564+
bool is_fixed_axis = axis->kind() == AxisKind::kDenseFixed || axis->kind() == AxisKind::kSparseFixed;
565+
if (!is_fixed_axis && parent.defined()) {
566+
const AxisNode* parent_node = parent.value().get();
567+
if (in_block.find(parent_node) != in_block.end()) {
568+
/* parent node is in the current block, need to create new block. */
569+
create_new_blk = true;
570+
} else if (in_stack.find(parent_node) != in_stack.end()) {
571+
/* parent node is in the previous blocks in the stack, no need to create new block. */
572+
create_new_blk = false;
561573
} else {
562-
if (parent.defined()) { // has parent
563-
if (in_stack.find(parent.value().get()) == in_stack.end()) { // parent not emitted yet
564-
if (axis->kind() == AxisKind::kDenseVariable ||
565-
axis->kind() == AxisKind::kSparseVariable) { // is not fixed axis.
566-
emit_iter_var = false;
567-
}
568-
}
569-
}
574+
CHECK(false) << "The parent axis of " << axis->GetName() << " should appear before " << axis->GetName() << " when defining a sparse block.";
570575
}
571-
// LOG(INFO) << axis->name << " " << (parent.defined() ? parent.value()->name : "no-parent")
572-
// << " " << emit_iter_var;
573-
if (emit_iter_var) {
574-
loop_vars.push_back(all_loop_vars[i]);
575-
blk_axes.push_back(axis);
576-
block_iters.push_back(SpIterVarToIterVar(sp_it_var, var_map));
577-
iter_bindings.push_back(all_loop_vars[i]);
578-
has_reduction_var |= sp_it_var->is_reduction;
576+
}
577+
if (create_new_blk) {
578+
/* update in stack set. */
579+
for (const AxisNode* node : in_block) {
580+
in_stack.insert(node);
579581
}
582+
/* Update stack. */
583+
UpdateStack();
584+
/* Reset block states. */
585+
loop_vars = {};
586+
block_iters = {};
587+
iter_bindings = {};
588+
has_reduction_var = false;
589+
in_block.clear();
580590
}
581591

582-
/* Tag axes in current block as "in-stack". */
583-
for (const Axis&& axis : blk_axes) {
584-
in_stack.insert(axis.get());
585-
}
592+
loop_vars.push_back(all_loop_vars[i]);
593+
block_iters.push_back(SpIterVarToIterVar(sp_it_var, var_map));
594+
iter_bindings.push_back(all_loop_vars[i]);
595+
has_reduction_var |= sp_it_var->is_reduction;
596+
in_block.insert(axis.get());
597+
}
586598

587-
/* Update stack. */
588-
if (!block_iters.empty()) {
589-
block_iters_st.push(std::move(block_iters));
590-
iter_bindings_st.push(std::move(iter_bindings));
591-
loop_vars_st.push(std::move(loop_vars));
592-
if (init_set) {
593-
place_init_st.push(false);
594-
} else {
595-
place_init_st.push(has_reduction_var);
596-
init_set |= has_reduction_var;
597-
}
598-
} else {
599-
break;
600-
}
601-
} while (true);
599+
// Update the last block.
600+
UpdateStack();
602601

603602
// Step 5. Generate the read-region and write-region of the block.
604603
Array<BufferRegion> reads{};
@@ -608,14 +607,14 @@ class IndexTransformer : public StmtExprMutator {
608607
// Step 6. Generate nested blocks and loops from innermost to outermost.
609608
int blk_counter = 0;
610609
while (!block_iters_st.empty()) {
611-
Array<IterVar> block_iters = std::move(block_iters_st.top());
612-
Array<PrimExpr> iter_bindings = std::move(iter_bindings_st.top());
613-
Array<Var> loop_vars = std::move(loop_vars_st.top());
614-
bool place_init = place_init_st.top();
615-
block_iters_st.pop();
616-
iter_bindings_st.pop();
617-
loop_vars_st.pop();
618-
place_init_st.pop();
610+
Array<IterVar> block_iters = std::move(block_iters_st.back());
611+
Array<PrimExpr> iter_bindings = std::move(iter_bindings_st.back());
612+
Array<Var> loop_vars = std::move(loop_vars_st.back());
613+
bool place_init = place_init_st.back();
614+
block_iters_st.pop_back();
615+
iter_bindings_st.pop_back();
616+
loop_vars_st.pop_back();
617+
place_init_st.pop_back();
619618

620619
Map<String, ObjectRef> mapping;
621620
mapping.Set("sparse", Bool(true));

tests/python/sparsetir/test_tir_sparse_lower.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def csrmm(
4040
A = T.match_sparse_buffer(a, (I, J), "float32")
4141
B = T.match_sparse_buffer(b, (T.dense(J), K), "float32")
4242
C = T.match_sparse_buffer(c, (I, K), "float32")
43-
with T.iter([I, J, K], "SRS", "csrmm") as [vi, vj, vk]:
43+
with T.iter([I, K, J], "SSR", "csrmm") as [vi, vk, vj]:
4444
with T.init():
4545
C[vi, vk] = 0.0
4646
C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk]
@@ -180,12 +180,12 @@ def bsrmm(
180180
B = T.match_sparse_buffer(b, (T.dense(J), BJ, F), "float32")
181181
C = T.match_sparse_buffer(c, (I, BI, F), "float32")
182182

183-
with T.iter([I, J, BI, BJ, F], "SRSRS", "bsrmm") as [
183+
with T.iter([I, BI, BJ, F, J], "SSRSR", "bsrmm") as [
184184
vi,
185-
vj,
186185
vbi,
187186
vbj,
188187
vf,
188+
vj,
189189
]:
190190
with T.init():
191191
C[vi, vbi, vf] = 0.0
@@ -314,7 +314,6 @@ def lowered_csr_element_wise(a: T.handle, b: T.handle, indptr: T.handle, indices
314314
def test_csrmm():
315315
mod = tvm.IRModule.from_expr(csrmm)
316316
mod = tvm.tir.transform.LowerSparseTIR()(mod)
317-
print(mod["main"].script())
318317
tvm.ir.assert_structural_equal(mod["main"], lowered_csrmm, True)
319318

320319
A = sp.random(512, 512, dtype="float32", density=0.0125, format="csr")
@@ -338,14 +337,12 @@ def test_csrmm():
338337
def test_csrmm_dense_iter():
339338
mod = tvm.IRModule.from_expr(csrmm_dense_iter)
340339
mod = tvm.tir.transform.LowerSparseTIR()(mod)
341-
print(mod["main"].script())
342340
# tvm.ir.assert_structural_equal(mod["main"], lowered_csrmm, True)
343341

344342

345343
def test_segment_reduce():
346344
mod = tvm.IRModule.from_expr(segment_reduce)
347345
mod = tvm.tir.transform.LowerSparseTIR()(mod)
348-
print(mod["main"].script())
349346

350347

351348
def test_csr_reduce():
@@ -412,7 +409,6 @@ def test_bsrmm():
412409
def test_ellpack_mm():
413410
mod = tvm.IRModule.from_expr(ellpack_mm)
414411
mod = tvm.tir.transform.LowerSparseTIR()(mod)
415-
print(mod["main"].script())
416412
tvm.ir.assert_structural_equal(mod["main"], lowered_ellpack_mm, True)
417413

418414
nnz_cols = 4

0 commit comments

Comments
 (0)