Skip to content

Commit 3ecc489

Browse files
committed
Fused SDDMM example. (#46)
* upd
* wip
* fix
1 parent d032263 commit 3ecc489

File tree

3 files changed

+50
-20
lines changed

3 files changed

+50
-20
lines changed

include/tvm/tir/sparse.h

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -214,6 +214,11 @@ class FusedAxis;
214214
/*! \brief Derivation axis, constructed by T.fuse(axis1, axis2, ...) */
215215
class FusedAxisNode : public DenseFixedAxisNode {
216216
public:
217+
/* The group of axes to be fused. */
218+
Array<Axis> group;
219+
/* The index of current FusedAxis in the group. */
220+
int index;
221+
217222
void VisitAttrs(AttrVisitor* v) {
218223
DenseFixedAxisNode::VisitAttrs(v);
219224
v->Visit("group", &group);
@@ -231,11 +236,6 @@ class FusedAxisNode : public DenseFixedAxisNode {
231236
hash_reduce(index);
232237
}
233238

234-
/* The group of axes to be fused. */
235-
Array<Axis> group;
236-
/* The index of current FusedAxis in the group. */
237-
int index;
238-
239239
static constexpr const char* _type_key = "tir.sparse.FusedAxis";
240240
TVM_DECLARE_FINAL_OBJECT_INFO(FusedAxisNode, DenseFixedAxisNode);
241241
};
@@ -313,8 +313,6 @@ class AttachedAxisNode : public DenseVariableAxisNode {
313313

314314
PrimExpr Aggregate(SparseCtx* ctx, PrimExpr index) const;
315315

316-
PrimExpr Aggregate(SparseCtx* ctx, PrimExpr index) const;
317-
318316
static constexpr const char* _type_key = "tir.sparse.AttachedAxis";
319317
TVM_DECLARE_FINAL_OBJECT_INFO(AttachedAxisNode, DenseVariableAxisNode);
320318
};

src/tir/transforms/lower_sparse_tir.cc

Lines changed: 43 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -101,15 +101,12 @@ class SparseBlockCtx : public SparseCtx {
101101
struct Scope {
102102
explicit Scope(SparseBlock sp_block) : sp_block(std::move(sp_block)) {
103103
for (const SpIterVar& sp_iter_var : this->sp_block->sp_iter_vars) {
104-
axis2sp_iter.Set(sp_iter_var->axis, sp_iter_var);
105104
sp_iter_var_map.Set(sp_iter_var->var, sp_iter_var);
106105
}
107106
}
108107

109108
/*! \brief The sparse block */
110109
SparseBlock sp_block;
111-
/*! \brief A mapping from axes to the sparse iterators that go over them */
112-
Map<Axis, SpIterVar> axis2sp_iter;
113110
/*! \brief A mapping from the internal variables of sparse iterators to the iterators */
114111
Map<Var, SpIterVar> sp_iter_var_map;
115112
/*! \brief The stored offsets of the axis in the sparse block */
@@ -125,14 +122,40 @@ class SparseBlockCtx : public SparseCtx {
125122
stack_.emplace_back(GetRef<SparseBlock>(sp_block));
126123
/* Compute offsets and coordinates */
127124
size_t n_iters = sp_block->sp_iter_vars.size();
128-
for (size_t i = 0; i < n_iters; ++i) {
125+
for (size_t i = 0; i < n_iters;) {
129126
SpIterVar sp_iter_var = sp_block->sp_iter_vars[i];
130127
Axis axis = sp_iter_var->axis;
131128

132-
PrimExpr offset = AggregateOffset(this, axis, sp_iter_var->var, ana_);
133-
SetOffset(axis, offset);
134-
PrimExpr coordinate = axis->Decompress(this, offset, sp_iter_var->var);
135-
SetCoordinate(axis, coordinate);
129+
PrimExpr offset, index;
130+
if (auto fused_axis = axis.as<FusedAxisNode>()) {
131+
auto group = fused_axis->group;
132+
offset = sp_block->sp_iter_vars[i + group.size() - 1]->var;
133+
for (int j = group.size() - 1; j >= 0; --j) {
134+
Axis orig = group[j];
135+
SetOffset(orig, offset);
136+
if (j > 0) {
137+
// TODO(zihao): support more than sv axis.
138+
offset = lower_bound(Downcast<SparseVariableAxis>(orig)->indptr->data, offset,
139+
Integer(0), orig->GetNNZ());
140+
}
141+
}
142+
for (size_t j = 0; j < group.size(); ++j) {
143+
Axis orig = group[j];
144+
offset = GetOffset(orig);
145+
PrimExpr lb = std::get<0>(orig->GetOffsetExtent(this));
146+
index = offset - lb;
147+
PrimExpr coordinate = orig->Decompress(this, offset, index);
148+
SetCoordinate(orig, coordinate);
149+
i++;
150+
}
151+
} else {
152+
offset = AggregateOffset(this, axis, sp_iter_var->var, ana_);
153+
index = sp_iter_var->var;
154+
PrimExpr coordinate = axis->Decompress(this, offset, index);
155+
SetOffset(axis, offset);
156+
SetCoordinate(axis, coordinate);
157+
i++;
158+
}
136159
}
137160
}
138161

@@ -164,7 +187,7 @@ class SparseBlockCtx : public SparseCtx {
164187
*/
165188
PrimExpr GetOffset(Axis axis) const {
166189
Optional<PrimExpr> try_offset = top()->cached_offsets.Get(axis);
167-
CHECK(try_offset.defined()) << "The offset of axis not defined yet.";
190+
CHECK(try_offset.defined()) << "The offset of axis " << axis->name << " not defined yet.";
168191
PrimExpr offset = try_offset.value();
169192
return std::move(offset);
170193
}
@@ -202,7 +225,7 @@ class SparseBlockCtx : public SparseCtx {
202225
}
203226

204227
Optional<Axis> MatchAxis(SparseCtx* buf_ctx, Axis axis) {
205-
if (!top()->axis2sp_iter.Get(axis).defined()) {
228+
if (!top()->cached_offsets.Get(axis).defined()) {
206229
return NullOpt;
207230
} else {
208231
Axis axis_ = axis;
@@ -233,7 +256,11 @@ class SparseBlockCtx : public SparseCtx {
233256
if (!try_sp_iter_var.defined()) {
234257
return false;
235258
}
236-
return try_sp_iter_var.value()->axis == matched_axis.value();
259+
Axis axis = try_sp_iter_var.value()->axis;
260+
if (auto fused_axis = axis.as<FusedAxisNode>()) {
261+
axis = fused_axis->group[fused_axis->index];
262+
}
263+
return axis == matched_axis.value();
237264
}
238265

239266
private:
@@ -403,7 +430,11 @@ class IndexTransformer : public StmtExprMutator {
403430
auto try_sp_iter = sp_blk_ctx_.GetSparseIterVar(var);
404431
if (try_sp_iter.defined()) {
405432
SpIterVar sp_iter = try_sp_iter.value();
406-
return sp_blk_ctx_.GetCoordinate(sp_iter->axis);
433+
Axis axis = sp_iter->axis;
434+
if (auto fused_axis = axis.as<FusedAxisNode>()) {
435+
axis = fused_axis->group[fused_axis->index];
436+
}
437+
return sp_blk_ctx_.GetCoordinate(axis);
407438
} else {
408439
return GetRef<PrimExpr>(var);
409440
}

tests/python/sparsetir/test_tir_sparse_lower.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -662,6 +662,7 @@ def test_sddmm():
662662
def test_fused_sddmm():
663663
mod = tvm.IRModule.from_expr(fused_sddmm)
664664
mod = tvm.tir.transform.LowerSparseTIR()(mod)
665+
print(mod["main"].script())
665666
# TODO
666667

667668

@@ -754,7 +755,7 @@ def test_square_sum_two_K():
754755
test_ellpack_mm()
755756
test_csr_element_wise()
756757
test_sddmm()
757-
# test_fused_sddmm()
758+
test_fused_sddmm()
758759
test_bmm()
759760
test_square_sum()
760761
test_square_sum_two_K()

0 commit comments

Comments
 (0)