Commit 81319b4

MasterJH5574 authored and yzh119 committed
[BugFix] Fix offset caching in lowering (#38)
* Hack compact dataflow check in a dirty way
* Add two-K square sum test
* Mark skipped tests
* Fix offset saving in lowering
1 parent 02ef77f commit 81319b4

File tree

3 files changed: +120 -12 lines changed

src/tir/schedule/analysis/analysis.cc

Lines changed: 2 additions & 1 deletion

@@ -168,7 +168,8 @@ Definition of a scope that is a stage pipeline:
   if (it_atomic != block->annotations.end()) {
     is_atomic = ((*it_atomic).second).as<IntImmNode>()->value;
   }
-  if (!is_atomic) {
+  // Todo(ruihang): Temporary hack. Deal with the "sparse" annotation later.
+  if (!is_atomic && block->annotations.find("sparse") == block->annotations.end()) {
     throw NotCompactDataFlowError(self->mod, GetRef<Stmt>(scope_root_subtree->stmt),
                                   GetRef<Block>(block));
   }
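The relaxed condition boils down to a two-clause predicate. Below is a minimal Python model of the check above, assuming the annotations behave like a dict (the helper name and the "atomic" key default are ours, not TVM API):

```python
def passes_compact_dataflow_check(annotations: dict) -> bool:
    # Model of the C++ logic: a block is exempt from NotCompactDataFlowError
    # if it is marked atomic, or (the temporary hack added in this commit)
    # if it carries a "sparse" annotation.
    is_atomic = bool(annotations.get("atomic", 0))
    return is_atomic or "sparse" in annotations

# Blocks emitted by LowerSparseTIR carry {"sparse": True} (see the lowered
# test cases below), so scheduling them no longer trips the check.
assert passes_compact_dataflow_check({"sparse": True})
assert not passes_compact_dataflow_check({})
```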

src/tir/transforms/lower_sparse_tir.cc

Lines changed: 17 additions & 10 deletions

@@ -322,10 +322,6 @@ class SparseBufferCtx {
         matches_.emplace_back(axis->name == sp_iter_var->axis->name);
       }
     }
-
-    // update offset
-    PrimExpr new_offset = AggregateOffset(offsets_.back(), axis, std::move(coordinate), ana_);
-    offsets_.emplace_back(std::move(new_offset));
   }

   /*! \brief get the axis given dimension index of current buffer. */
@@ -341,7 +337,7 @@ class SparseBufferCtx {
             AggregateOffset(add(offsets_[dim], 1), axis, Integer(0), ana_)};
   }

- private:
+ public:
   String buf_name_;
   Array<Axis> axes_;
   std::vector<PrimExpr> offsets_;
@@ -375,7 +371,12 @@ class SparseBufferCtx {
     top()->Register(dim, std::move(coordinate), std::move(orig_idx));
   }

- private:
+  void AddOffset(int dim, PrimExpr offset) {
+    ICHECK_EQ(dim + 1, static_cast<int>(top()->offsets_.size()));
+    top()->offsets_.push_back(offset);
+  }
+
+ public:
   std::vector<Scope> stack_;
   arith::Analyzer* ana_;
@@ -421,18 +422,22 @@ class IndexTransformer : public StmtExprMutator {
         auto sf_axis = axis.as<SparseFixedAxisNode>();
         PrimExpr l, r;
         std::tie(l, r) = sp_buf_ctx_.GetIndicesRange(dim);
-        offset = lower_bound(sf_axis->indices->data, coordinate, l, r);
+        offset = lower_bound(sf_axis->indices->data, coordinate, l, r) - l;
         break;
       }
       case AxisKind::kSparseVariable:
         auto sv_axis = axis.as<SparseVariableAxisNode>();
         PrimExpr l, r;
         std::tie(l, r) = sp_buf_ctx_.GetIndicesRange(dim);
-        offset = lower_bound(sv_axis->indices->data, coordinate, l, r);
+        offset = lower_bound(sv_axis->indices->data, coordinate, l, r) - l;
         break;
       }
     }

+    // update offset
+    PrimExpr new_offset = AggregateOffset(sp_buf_ctx_.top()->offsets_.back(), axis,
+                                          offset, sp_buf_ctx_.ana_);
+    sp_buf_ctx_.top()->offsets_.push_back(std::move(new_offset));
     return offset;
   }
@@ -562,7 +567,8 @@ class IndexTransformer : public StmtExprMutator {
       Axis axis = sp_it_var->axis;
       auto parent = axis->GetParentAxis();
       bool create_new_blk = false;
-      bool is_fixed_axis = axis->kind() == AxisKind::kDenseFixed || axis->kind() == AxisKind::kSparseFixed;
+      bool is_fixed_axis =
+          axis->kind() == AxisKind::kDenseFixed || axis->kind() == AxisKind::kSparseFixed;
       if (!is_fixed_axis && parent.defined()) {
         const AxisNode* parent_node = parent.value().get();
         if (in_block.find(parent_node) != in_block.end()) {
@@ -572,7 +578,8 @@ class IndexTransformer : public StmtExprMutator {
           /* parent node is in the previous blocks in the stack, no need to create new block. */
           create_new_blk = false;
         } else {
-          CHECK(false) << "The parent axis of " << axis->GetName() << " should appear before " << axis->GetName() << " when defining a sparse block.";
+          CHECK(false) << "The parent axis of " << axis->GetName() << " should appear before "
+                       << axis->GetName() << " when defining a sparse block.";
         }
       }
       if (create_new_blk) {
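The two `- l` changes and the relocated offset caching address the same bug: `lower_bound` returns an absolute position inside the flattened `indices` buffer, while `AggregateOffset` expects an offset relative to the start of the current row, and the aggregated offset must be cached only after the search result is known. A small executable model of the corrected flow, assuming a CSR-style layout for a sparse-variable axis (helper names are ours, not the pass's API):

```python
import bisect

def local_offset(indices, l, r, coordinate):
    # lower_bound(indices, coordinate, l, r) - l: position of `coordinate`
    # within the row [l, r), not within the whole indices array.
    return bisect.bisect_left(indices, coordinate, l, r) - l

def aggregate_offset(indptr, prev_offset, offset):
    # Model of AggregateOffset for a sparse-variable axis: the flattened
    # offset is the start of the row selected by the previous axis, plus
    # the in-row offset computed above.
    return indptr[prev_offset] + offset

# Two rows: row 0 stores columns [2, 5], row 1 stores columns [1, 7].
indptr, indices = [0, 2, 4], [2, 5, 1, 7]
l, r = indptr[1], indptr[2]           # row 1 spans indices[2:4]
off = local_offset(indices, l, r, 7)  # in-row offset of column 7
assert off == 1                       # without "- l" this would be 3
assert aggregate_offset(indptr, 1, off) == 3  # (row 1, col 7) is indices[3]
```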

tests/python/sparsetir/test_tir_sparse_lower.py

Lines changed: 101 additions & 1 deletion

@@ -19,6 +19,7 @@
 import tvm.tir as tir
 import scipy.sparse as sp
 import numpy as np
+import pytest
 from tvm.script import tir as T


@@ -367,7 +368,7 @@ def lowered_square_sum(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j:
     J_indices = T.match_buffer(indices_j, [nnz_j], dtype="int32")
     K_indptr = T.match_buffer(indptr_k, [nnz_j + 1], dtype="int32")
     K_indices = T.match_buffer(indices_k, [nnz_k], dtype="int32")
-
+
     for v_vi in T.serial(0, M):
         with T.block("square_sum_2"):
             vi = T.axis.spatial(M, v_vi)
@@ -391,6 +392,58 @@ def lowered_square_sum(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j:
             B_data[vi] = B_data[vi] + A_data[K_indptr[J_indptr[vi] + vj] + vk]


+@T.prim_func
+def square_sum_two_K(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: T.handle, indptr_k0: T.handle, indices_k0: T.handle, indptr_k1: T.handle, indices_k1: T.handle, nnz_j: T.int32, nnz_k: T.int32, M: T.int32, N1: T.int32, N2: T.int32):
+    # Used only for testing `GetIndicesRange()`.
+    # Currently it is ensured that `indptr_k0` is the same as `indptr_k1`, and `indices_k0` is the
+    # same as `indices_k1`.
+    I = T.dense_fixed(M)
+    J = T.sparse_variable(I, (N1, nnz_j), (indptr_j, indices_j), "int32")
+    K0 = T.sparse_variable(J, (N2, nnz_k), (indptr_k0, indices_k0), "int32")
+    K1 = T.sparse_variable(J, (N2, nnz_k), (indptr_k1, indices_k1), "int32")
+    A = T.match_sparse_buffer(a, (I, J, K0), "float32")
+    B = T.match_sparse_buffer(b, (I,), "float32")
+
+    with T.iter([I, J, K1], "SRR", "square_sum") as [vi, vj, vk]:
+        with T.init():
+            B[vi] = 0.0
+        B[vi] = B[vi] + A[vi, vj, vk]
+
+
+@T.prim_func
+def lowered_square_sum_two_K(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: T.handle, indptr_k0: T.handle, indices_k0: T.handle, indptr_k1: T.handle, indices_k1: T.handle, nnz_j: T.int32, nnz_k: T.int32, M: T.int32, N1: T.int32, N2: T.int32) -> None:
+    A_data = T.match_buffer(a, [nnz_k], dtype="float32")
+    B_data = T.match_buffer(b, [M], dtype="float32")
+    J_indptr = T.match_buffer(indptr_j, [M + 1], dtype="int32")
+    J_indices = T.match_buffer(indices_j, [nnz_j], dtype="int32")
+    K0_indptr = T.match_buffer(indptr_k0, [nnz_j + 1], dtype="int32")
+    K0_indices = T.match_buffer(indices_k0, [nnz_k], dtype="int32")
+    K1_indptr = T.match_buffer(indptr_k1, [nnz_j + 1], dtype="int32")
+    K1_indices = T.match_buffer(indices_k1, [nnz_k], dtype="int32")
+
+    for v_vi in T.serial(0, M):
+        with T.block("square_sum_2"):
+            vi = T.axis.spatial(M, v_vi)
+            T.reads([J_indptr[0 : M + 1], J_indices[0 : nnz_j], K0_indptr[0 : nnz_j + 1], K0_indices[0 : nnz_k], K1_indptr[0 : nnz_j + 1], K1_indices[0 : nnz_k], A_data[0 : nnz_k], B_data[0 : M]])
+            T.writes([B_data[0 : M]])
+            T.block_attr({"sparse":True})
+            for v_vj in T.serial(0, J_indptr[v_vi + 1] - J_indptr[v_vi]):
+                with T.block("square_sum_1"):
+                    vj = T.axis.reduce(J_indptr[v_vi + 1] - J_indptr[v_vi], v_vj)
+                    T.reads([J_indptr[0 : M + 1], J_indices[0 : nnz_j], K0_indptr[0 : nnz_j + 1], K0_indices[0 : nnz_k], K1_indptr[0 : nnz_j + 1], K1_indices[0 : nnz_k], A_data[0 : nnz_k], B_data[0 : M]])
+                    T.writes([B_data[0 : M]])
+                    T.block_attr({"sparse":True})
+                    with T.init():
+                        B_data[vi] = T.float32(0)
+                    for v_vk in T.serial(0, K1_indptr[J_indptr[v_vi] + v_vj + 1] - K1_indptr[J_indptr[v_vi] + v_vj]):
+                        with T.block("square_sum"):
+                            vk = T.axis.reduce(K1_indptr[J_indptr[v_vi] + v_vj + 1] - K1_indptr[J_indptr[v_vi] + v_vj], v_vk)
+                            T.reads([J_indptr[0 : M + 1], J_indices[0 : nnz_j], K0_indptr[0 : nnz_j + 1], K0_indices[0 : nnz_k], K1_indptr[0 : nnz_j + 1], K1_indices[0 : nnz_k], A_data[0 : nnz_k], B_data[0 : M]])
+                            T.writes([B_data[0 : M]])
+                            T.block_attr({"sparse":True})
+                            B_data[vi] = B_data[vi] + A_data[T.tvm_lower_bound(K0_indices.data, K1_indices[K1_indptr[J_indptr[vi] + vj] + vk], K0_indptr[J_indptr[vi] + vj], K0_indptr[J_indptr[vi] + vj + 1], dtype="int32")]
+
+
 def test_csrmm():
     mod = tvm.IRModule.from_expr(csrmm)
     mod = tvm.tir.transform.LowerSparseTIR()(mod)
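In the innermost statement of `lowered_square_sum_two_K`, the loop extents and coordinates come from K1's structure while A's storage follows K0, so the lowering joins the two with `T.tvm_lower_bound`: look up the K1 coordinate, then binary-search K0's indices within the same row. A NumPy model of that indexing chain, with toy values and `j_pos` standing for `J_indptr[vi] + vj`:

```python
import numpy as np

# Toy CSR structure shared by K0 and K1 in this test (two J entries).
K_indptr = np.array([0, 2, 3], dtype="int32")
K_indices = np.array([1, 4, 2], dtype="int32")
A_data = np.array([10.0, 20.0, 30.0], dtype="float32")

def a_at(j_pos, vk):
    """Mirror of the lowered access: read the K1 coordinate, then
    binary-search K0's indices within the same row to find the data slot."""
    coord = K_indices[K_indptr[j_pos] + vk]           # K1 coordinate
    l, r = K_indptr[j_pos], K_indptr[j_pos + 1]       # K0 row bounds
    pos = l + np.searchsorted(K_indices[l:r], coord)  # tvm_lower_bound
    return A_data[pos]

assert a_at(0, 1) == 20.0  # row 0, second stored column (4)
assert a_at(1, 0) == 30.0  # row 1, only stored column (2)
```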
@@ -414,13 +467,15 @@ def test_csrmm():
     tvm.testing.assert_allclose(y_ground_truth.reshape(-1), Y_nd.numpy(), rtol=1e-5, atol=1e-5)


+@pytest.mark.skip(reason="Under implementation")
 def test_csrmm_dense_iter():
     mod = tvm.IRModule.from_expr(csrmm_dense_iter)
     mod = tvm.tir.transform.LowerSparseTIR()(mod)
     # tvm.ir.assert_structural_equal(mod["main"], lowered_csrmm, True)
     # Todo


+@pytest.mark.skip(reason="Under implementation")
 def test_segment_reduce():
     mod = tvm.IRModule.from_expr(segment_reduce)
     mod = tvm.tir.transform.LowerSparseTIR()(mod)
@@ -557,6 +612,7 @@ def test_csr_element_wise():
     tvm.testing.assert_allclose(b_ground_truth.data.reshape(-1), B_nd.numpy(), rtol=1e-5, atol=1e-5)


+@pytest.mark.skip(reason="Under implementation")
 def test_bmm():
     mod = tvm.IRModule.from_expr(bmm)
     mod = tvm.tir.transform.LowerSparseTIR()(mod)
@@ -600,6 +656,49 @@ def test_square_sum():
     tvm.testing.assert_allclose(b_ground_truth, B_data.numpy(), rtol=1e-5, atol=1e-5)


+def test_square_sum_two_K():
+    mod = tvm.IRModule.from_expr(square_sum_two_K)
+    mod = tvm.tir.transform.LowerSparseTIR()(mod)
+    tvm.ir.assert_structural_equal(mod["main"], lowered_square_sum_two_K, True)
+
+    sch = tir.Schedule(mod, debug_mask="all")
+    i, = sch.get_loops(sch.get_block("square_sum_2"))
+    sch.bind(i, "threadIdx.x")
+
+    density = 0.0125
+    M = N1 = N2 = 128
+    A_J = sp.random(M, N1, dtype="float32", density=1 - (1 - density) ** N2, format="csr")
+    indptr_j = A_J.indptr
+    indices_j = A_J.indices
+    nnz_j = A_J.nnz
+    A_K = sp.random(nnz_j, N2, dtype="float32", density=density, format="csr")
+    indptr_k = A_K.indptr
+    indices_k = A_K.indices
+    nnz_k = A_K.nnz
+    data = A_K.data
+
+    b_ij = np.asarray(A_K.sum(axis=1)).squeeze()
+    A_J = sp.csr_matrix((b_ij, indices_j, indptr_j), shape=(M, N1))
+    b_ground_truth = np.asarray(A_J.sum(axis=1)).squeeze()
+    b = np.zeros((M,)).astype("float32")
+
+    v_nnz_j, v_nnz_k, v_M, v_N1, v_N2 = square_sum_two_K.params[-5:]
+    f = tvm.build(sch.mod["main"].specialize({v_nnz_j: nnz_j, v_nnz_k: nnz_k, v_M: M, v_N1: N1, v_N2: N2}), target="cuda")
+
+    ctx = tvm.device("cuda")
+    A_data = tvm.nd.array(data.astype("float32"), device=ctx)
+    A_indptr_j = tvm.nd.array(indptr_j.astype("int32"), device=ctx)
+    A_indices_j = tvm.nd.array(indices_j.astype("int32"), device=ctx)
+    A_indptr_k0 = tvm.nd.array(indptr_k.astype("int32"), device=ctx)
+    A_indices_k0 = tvm.nd.array(indices_k.astype("int32"), device=ctx)
+    A_indptr_k1 = tvm.nd.array(indptr_k.astype("int32"), device=ctx)
+    A_indices_k1 = tvm.nd.array(indices_k.astype("int32"), device=ctx)
+    B_data = tvm.nd.array(b.astype("float32"), device=ctx)
+    f(A_data, B_data, A_indptr_j, A_indices_j, A_indptr_k0, A_indices_k0, A_indptr_k1, A_indices_k1)
+
+    tvm.testing.assert_allclose(b_ground_truth, B_data.numpy(), rtol=1e-5, atol=1e-5)
+
+
 if __name__ == "__main__":
     test_csrmm()
     test_csrmm_dense_iter()
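The ground truth in `test_square_sum_two_K` folds the two reductions in stages: `b_ij` sums A over K for each stored (i, j) entry, and rebuilding a CSR matrix from J's `indptr`/`indices` lines those partial sums up under their row i, so a second row-sum yields B. A standalone sketch of the same trick with illustrative shapes:

```python
import numpy as np
import scipy.sparse as sp

M, N1, N2, density = 16, 16, 16, 0.2
A_J = sp.random(M, N1, density=density, format="csr", dtype="float32")
A_K = sp.random(A_J.nnz, N2, density=density, format="csr", dtype="float32")

# Stage 1: b_ij[e] = sum_k A[i, j, k] for the e-th stored (i, j) pair.
b_ij = np.asarray(A_K.sum(axis=1)).squeeze()
# Stage 2: reuse J's indptr/indices so that row i of the new matrix holds
# exactly the b_ij values belonging to i; summing its rows gives B[i].
B = np.asarray(
    sp.csr_matrix((b_ij, A_J.indices, A_J.indptr), shape=(M, N1)).sum(axis=1)
).squeeze()

# Same result computed directly with an explicit loop over J's rows.
B_ref = np.array([b_ij[A_J.indptr[i]:A_J.indptr[i + 1]].sum() for i in range(M)])
np.testing.assert_allclose(B, B_ref, rtol=1e-5)
```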
@@ -610,3 +709,4 @@ def test_square_sum():
     test_csr_element_wise()
     test_bmm()
     test_square_sum()
+    test_square_sum_two_K()
