Skip to content

Commit affb1af

Browse files
MasterJH5574yzh119
authored andcommitted
[SparseTIR] Index Lowering (#8)
* Add StmtFunctor/ExprFunctor for SparseBufferStore/Load * Add basic index lowering * Finish index lowering (maybe) * Address comments * Convert CRLF to LF
1 parent 65ce747 commit affb1af

File tree

8 files changed

+379
-2
lines changed

8 files changed

+379
-2
lines changed

include/tvm/tir/expr_functor.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
119119
return VisitExpr_(static_cast<const VarNode*>(op), std::forward<Args>(args)...);
120120
}
121121
virtual R VisitExpr_(const BufferLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
122+
virtual R VisitExpr_(const SparseBufferLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
122123
virtual R VisitExpr_(const ProducerLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
123124
virtual R VisitExpr_(const LoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
124125
virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
@@ -165,6 +166,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
165166
IR_EXPR_FUNCTOR_DISPATCH(SizeVarNode);
166167
IR_EXPR_FUNCTOR_DISPATCH(LoadNode);
167168
IR_EXPR_FUNCTOR_DISPATCH(BufferLoadNode);
169+
IR_EXPR_FUNCTOR_DISPATCH(SparseBufferLoadNode);
168170
IR_EXPR_FUNCTOR_DISPATCH(ProducerLoadNode);
169171
IR_EXPR_FUNCTOR_DISPATCH(LetNode);
170172
IR_EXPR_FUNCTOR_DISPATCH(CallNode);
@@ -217,6 +219,7 @@ class TVM_DLL ExprVisitor : public ExprFunctor<void(const PrimExpr&)> {
217219
void VisitExpr_(const SizeVarNode* op) override;
218220
void VisitExpr_(const LoadNode* op) override;
219221
void VisitExpr_(const BufferLoadNode* op) override;
222+
void VisitExpr_(const SparseBufferLoadNode* op) override;
220223
void VisitExpr_(const ProducerLoadNode* op) override;
221224
void VisitExpr_(const LetNode* op) override;
222225
void VisitExpr_(const CallNode* op) override;
@@ -264,6 +267,7 @@ class TVM_DLL ExprMutator : protected ExprFunctor<PrimExpr(const PrimExpr&)> {
264267
PrimExpr VisitExpr_(const SizeVarNode* op) override;
265268
PrimExpr VisitExpr_(const LoadNode* op) override;
266269
PrimExpr VisitExpr_(const BufferLoadNode* op) override;
270+
PrimExpr VisitExpr_(const SparseBufferLoadNode* op) override;
267271
PrimExpr VisitExpr_(const ProducerLoadNode* op) override;
268272
PrimExpr VisitExpr_(const LetNode* op) override;
269273
PrimExpr VisitExpr_(const CallNode* op) override;

include/tvm/tir/stmt_functor.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
8989
virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
9090
virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
9191
virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
92+
virtual R VisitStmt_(const SparseBufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
9293
virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
9394
virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
9495
virtual R VisitStmt_(const ProducerStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
@@ -121,6 +122,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
121122
IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode);
122123
IR_STMT_FUNCTOR_DISPATCH(EvaluateNode);
123124
IR_STMT_FUNCTOR_DISPATCH(BufferStoreNode);
125+
IR_STMT_FUNCTOR_DISPATCH(SparseBufferStoreNode);
124126
IR_STMT_FUNCTOR_DISPATCH(BufferRealizeNode);
125127
IR_STMT_FUNCTOR_DISPATCH(BlockNode);
126128
IR_STMT_FUNCTOR_DISPATCH(BlockRealizeNode);
@@ -157,6 +159,7 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
157159
void VisitStmt_(const AllocateNode* op) override;
158160
void VisitStmt_(const StoreNode* op) override;
159161
void VisitStmt_(const BufferStoreNode* op) override;
162+
void VisitStmt_(const SparseBufferStoreNode* op) override;
160163
void VisitStmt_(const BufferRealizeNode* op) override;
161164
void VisitStmt_(const AssertStmtNode* op) override;
162165
void VisitStmt_(const ProducerStoreNode* op) override;
@@ -257,6 +260,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
257260
Stmt VisitStmt_(const AllocateNode* op) override;
258261
Stmt VisitStmt_(const StoreNode* op) override;
259262
Stmt VisitStmt_(const BufferStoreNode* op) override;
263+
Stmt VisitStmt_(const SparseBufferStoreNode* op) override;
260264
Stmt VisitStmt_(const BufferRealizeNode* op) override;
261265
Stmt VisitStmt_(const AssertStmtNode* op) override;
262266
Stmt VisitStmt_(const ProducerStoreNode* op) override;

include/tvm/tir/transform.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,12 @@ TVM_DLL Pass ConvertForLoopsToSerial();
500500
*/
501501
TVM_DLL Pass UnifiedStaticMemoryPlanner();
502502

503+
/*!
504+
* \brief Lower SparseTIR to TIR.
505+
* \return The pass.
506+
*/
507+
TVM_DLL Pass LowerSparseTIR();
508+
503509
} // namespace transform
504510
} // namespace tir
505511
} // namespace tvm

python/tvm/tir/transform/transform.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -760,3 +760,14 @@ def ConvertForLoopsToSerial():
760760
The result pass
761761
"""
762762
return _ffi_api.ConvertForLoopsToSerial() # type: ignore
763+
764+
765+
def LowerSparseTIR():
766+
"""Lower SparseTIR to TIR
767+
768+
Returns
769+
-------
770+
fpass : tvm.transform.Pass
771+
The result pass
772+
"""
773+
return _ffi_api.LowerSparseTIR() # type: ignore

src/tir/ir/expr_functor.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ void ExprVisitor::VisitExpr_(const BufferLoadNode* op) {
4343
VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
4444
}
4545

46+
void ExprVisitor::VisitExpr_(const SparseBufferLoadNode* op) {
47+
VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
48+
}
49+
4650
void ExprVisitor::VisitExpr_(const ProducerLoadNode* op) {
4751
VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
4852
}
@@ -146,6 +150,16 @@ PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) {
146150
}
147151
}
148152

153+
PrimExpr ExprMutator::VisitExpr_(const SparseBufferLoadNode* op) {
154+
auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); };
155+
Array<PrimExpr> indices = MutateArray(op->indices, fmutate);
156+
if (indices.same_as(op->indices)) {
157+
return GetRef<PrimExpr>(op);
158+
} else {
159+
return SparseBufferLoad(op->buffer, indices);
160+
}
161+
};
162+
149163
PrimExpr ExprMutator::VisitExpr_(const ProducerLoadNode* op) {
150164
auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); };
151165
Array<PrimExpr> indices = MutateArray(op->indices, fmutate);

src/tir/ir/sparse.cc

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
* \file sparse.cc
2222
* \brief buffers and formats in sparse tir.
2323
*/
24+
#include <tvm/arith/analyzer.h>
2425
#include <tvm/runtime/registry.h>
2526
#include <tvm/tir/buffer.h>
2627
#include <tvm/tir/sparse.h>
@@ -158,9 +159,21 @@ SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, bool is_
158159
Optional<Axis> axis) {
159160
ObjectPtr<SpIterVarNode> node = make_object<SpIterVarNode>();
160161

162+
arith::Analyzer ana;
163+
if (axis.defined()) {
164+
CHECK(ana.CanProveEqual(axis.value()->length, max_extent));
165+
}
161166
if (kind != SpIterKind::kDenseFixed) {
162167
CHECK(axis.defined()) << "ValueError: To create a SpIterVar that is not fixed-dense, one must "
163168
"specify the axis over which the SpIterVar iterates";
169+
const char* err_str = "ValueError: The given kind doesn't match the type of the given axis";
170+
if (kind == SpIterKind::kDenseVariable) {
171+
CHECK(axis.value()->IsInstance<DenseFixedAxisNode>()) << err_str;
172+
} else if (kind == SpIterKind::kSparseFixed) {
173+
CHECK(axis.value()->IsInstance<SparseFixedAxisNode>()) << err_str;
174+
} else if (kind == SpIterKind::kSparseVariable) {
175+
CHECK(axis.value()->IsInstance<SparseVariableAxisNode>()) << err_str;
176+
}
164177
}
165178

166179
node->var = Var(std::move(name));
@@ -174,9 +187,9 @@ SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, bool is_
174187
TVM_REGISTER_NODE_TYPE(SpIterVarNode);
175188

176189
TVM_REGISTER_GLOBAL("tir.sparse.SpIterVar")
177-
.set_body_typed([](String name, PrimExpr max_extent, SpIterKind kind, bool is_reduction,
190+
.set_body_typed([](String name, PrimExpr max_extent, int kind, bool is_reduction,
178191
Optional<Axis> axis) {
179-
return SpIterVar(name, max_extent, kind, is_reduction, axis);
192+
return SpIterVar(name, max_extent, SpIterKind(kind), is_reduction, axis);
180193
});
181194

182195
} // namespace tir

src/tir/ir/stmt_functor.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,11 @@ void StmtVisitor::VisitStmt_(const BufferStoreNode* op) {
6969
VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
7070
}
7171

72+
void StmtVisitor::VisitStmt_(const SparseBufferStoreNode* op) {
73+
this->VisitExpr(op->value);
74+
VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
75+
}
76+
7277
void StmtVisitor::VisitStmt_(const BufferRealizeNode* op) {
7378
VisitArray(op->bounds, [this](const Range& r) {
7479
this->VisitExpr(r->min);
@@ -367,6 +372,20 @@ Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) {
367372
}
368373
}
369374

375+
Stmt StmtMutator::VisitStmt_(const SparseBufferStoreNode* op) {
376+
PrimExpr value = this->VisitExpr(op->value);
377+
Array<PrimExpr> indices = Internal::Mutate(this, op->indices);
378+
379+
if (value.same_as(op->value) && indices.same_as(op->indices)) {
380+
return GetRef<Stmt>(op);
381+
} else {
382+
auto n = CopyOnWrite(op);
383+
n->value = std::move(value);
384+
n->indices = std::move(indices);
385+
return Stmt(n);
386+
}
387+
}
388+
370389
Stmt StmtMutator::VisitStmt_(const BufferRealizeNode* op) {
371390
Region bounds = Internal::Mutate(this, op->bounds);
372391
PrimExpr condition = this->VisitExpr(op->condition);

0 commit comments

Comments
 (0)