Skip to content

Commit

Permalink
[SparseTIR] Index Lowering (apache#8)
Browse files Browse the repository at this point in the history
* Add StmtFunctor/ExprFunctor for SparseBufferStore/Load

* Add basic index lowering

* Finish index lowering (maybe)

* Address comments

* Convert CRLF to LF
  • Loading branch information
MasterJH5574 committed Dec 22, 2021
1 parent 296b3ca commit 30533b4
Show file tree
Hide file tree
Showing 8 changed files with 379 additions and 2 deletions.
4 changes: 4 additions & 0 deletions include/tvm/tir/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
return VisitExpr_(static_cast<const VarNode*>(op), std::forward<Args>(args)...);
}
virtual R VisitExpr_(const BufferLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const SparseBufferLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ProducerLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const LoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
Expand Down Expand Up @@ -165,6 +166,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
IR_EXPR_FUNCTOR_DISPATCH(SizeVarNode);
IR_EXPR_FUNCTOR_DISPATCH(LoadNode);
IR_EXPR_FUNCTOR_DISPATCH(BufferLoadNode);
IR_EXPR_FUNCTOR_DISPATCH(SparseBufferLoadNode);
IR_EXPR_FUNCTOR_DISPATCH(ProducerLoadNode);
IR_EXPR_FUNCTOR_DISPATCH(LetNode);
IR_EXPR_FUNCTOR_DISPATCH(CallNode);
Expand Down Expand Up @@ -217,6 +219,7 @@ class TVM_DLL ExprVisitor : public ExprFunctor<void(const PrimExpr&)> {
void VisitExpr_(const SizeVarNode* op) override;
void VisitExpr_(const LoadNode* op) override;
void VisitExpr_(const BufferLoadNode* op) override;
void VisitExpr_(const SparseBufferLoadNode* op) override;
void VisitExpr_(const ProducerLoadNode* op) override;
void VisitExpr_(const LetNode* op) override;
void VisitExpr_(const CallNode* op) override;
Expand Down Expand Up @@ -264,6 +267,7 @@ class TVM_DLL ExprMutator : protected ExprFunctor<PrimExpr(const PrimExpr&)> {
PrimExpr VisitExpr_(const SizeVarNode* op) override;
PrimExpr VisitExpr_(const LoadNode* op) override;
PrimExpr VisitExpr_(const BufferLoadNode* op) override;
PrimExpr VisitExpr_(const SparseBufferLoadNode* op) override;
PrimExpr VisitExpr_(const ProducerLoadNode* op) override;
PrimExpr VisitExpr_(const LetNode* op) override;
PrimExpr VisitExpr_(const CallNode* op) override;
Expand Down
4 changes: 4 additions & 0 deletions include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const SparseBufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const ProducerStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
Expand Down Expand Up @@ -121,6 +122,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode);
IR_STMT_FUNCTOR_DISPATCH(EvaluateNode);
IR_STMT_FUNCTOR_DISPATCH(BufferStoreNode);
IR_STMT_FUNCTOR_DISPATCH(SparseBufferStoreNode);
IR_STMT_FUNCTOR_DISPATCH(BufferRealizeNode);
IR_STMT_FUNCTOR_DISPATCH(BlockNode);
IR_STMT_FUNCTOR_DISPATCH(BlockRealizeNode);
Expand Down Expand Up @@ -157,6 +159,7 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
void VisitStmt_(const AllocateNode* op) override;
void VisitStmt_(const StoreNode* op) override;
void VisitStmt_(const BufferStoreNode* op) override;
void VisitStmt_(const SparseBufferStoreNode* op) override;
void VisitStmt_(const BufferRealizeNode* op) override;
void VisitStmt_(const AssertStmtNode* op) override;
void VisitStmt_(const ProducerStoreNode* op) override;
Expand Down Expand Up @@ -257,6 +260,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
Stmt VisitStmt_(const AllocateNode* op) override;
Stmt VisitStmt_(const StoreNode* op) override;
Stmt VisitStmt_(const BufferStoreNode* op) override;
Stmt VisitStmt_(const SparseBufferStoreNode* op) override;
Stmt VisitStmt_(const BufferRealizeNode* op) override;
Stmt VisitStmt_(const AssertStmtNode* op) override;
Stmt VisitStmt_(const ProducerStoreNode* op) override;
Expand Down
6 changes: 6 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,12 @@ TVM_DLL Pass MergeDynamicSharedMemoryAllocations();
*/
TVM_DLL Pass ConvertForLoopsToSerial();

/*!
* \brief Lower SparseTIR to TIR.
* \return The pass.
*/
TVM_DLL Pass LowerSparseTIR();

} // namespace transform
} // namespace tir
} // namespace tvm
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,3 +749,14 @@ def ConvertForLoopsToSerial():
The result pass
"""
return _ffi_api.ConvertForLoopsToSerial() # type: ignore


def LowerSparseTIR():
"""Lower SparseTIR to TIR
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.LowerSparseTIR() # type: ignore
14 changes: 14 additions & 0 deletions src/tir/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ void ExprVisitor::VisitExpr_(const BufferLoadNode* op) {
VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
}

void ExprVisitor::VisitExpr_(const SparseBufferLoadNode* op) {
VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
}

void ExprVisitor::VisitExpr_(const ProducerLoadNode* op) {
VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
}
Expand Down Expand Up @@ -146,6 +150,16 @@ PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) {
}
}

PrimExpr ExprMutator::VisitExpr_(const SparseBufferLoadNode* op) {
auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); };
Array<PrimExpr> indices = MutateArray(op->indices, fmutate);
if (indices.same_as(op->indices)) {
return GetRef<PrimExpr>(op);
} else {
return SparseBufferLoad(op->buffer, indices);
}
};

PrimExpr ExprMutator::VisitExpr_(const ProducerLoadNode* op) {
auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); };
Array<PrimExpr> indices = MutateArray(op->indices, fmutate);
Expand Down
17 changes: 15 additions & 2 deletions src/tir/ir/sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
* \file sparse.cc
* \brief buffers and formats in sparse tir.
*/
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/sparse.h>
Expand Down Expand Up @@ -158,9 +159,21 @@ SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, bool is_
Optional<Axis> axis) {
ObjectPtr<SpIterVarNode> node = make_object<SpIterVarNode>();

arith::Analyzer ana;
if (axis.defined()) {
CHECK(ana.CanProveEqual(axis.value()->length, max_extent));
}
if (kind != SpIterKind::kDenseFixed) {
CHECK(axis.defined()) << "ValueError: To create a SpIterVar that is not fixed-dense, one must "
"specify the axis over which the SpIterVar iterates";
const char* err_str = "ValueError: The given kind doesn't match the type of the given axis";
if (kind == SpIterKind::kDenseVariable) {
CHECK(axis.value()->IsInstance<DenseFixedAxisNode>()) << err_str;
} else if (kind == SpIterKind::kSparseFixed) {
CHECK(axis.value()->IsInstance<SparseFixedAxisNode>()) << err_str;
} else if (kind == SpIterKind::kSparseVariable) {
CHECK(axis.value()->IsInstance<SparseVariableAxisNode>()) << err_str;
}
}

node->var = Var(std::move(name));
Expand All @@ -174,9 +187,9 @@ SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, bool is_
TVM_REGISTER_NODE_TYPE(SpIterVarNode);

TVM_REGISTER_GLOBAL("tir.sparse.SpIterVar")
.set_body_typed([](String name, PrimExpr max_extent, SpIterKind kind, bool is_reduction,
.set_body_typed([](String name, PrimExpr max_extent, int kind, bool is_reduction,
Optional<Axis> axis) {
return SpIterVar(name, max_extent, kind, is_reduction, axis);
return SpIterVar(name, max_extent, SpIterKind(kind), is_reduction, axis);
});

} // namespace tir
Expand Down
19 changes: 19 additions & 0 deletions src/tir/ir/stmt_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ void StmtVisitor::VisitStmt_(const BufferStoreNode* op) {
VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
}

void StmtVisitor::VisitStmt_(const SparseBufferStoreNode* op) {
this->VisitExpr(op->value);
VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
}

void StmtVisitor::VisitStmt_(const BufferRealizeNode* op) {
VisitArray(op->bounds, [this](const Range& r) {
this->VisitExpr(r->min);
Expand Down Expand Up @@ -367,6 +372,20 @@ Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) {
}
}

Stmt StmtMutator::VisitStmt_(const SparseBufferStoreNode* op) {
PrimExpr value = this->VisitExpr(op->value);
Array<PrimExpr> indices = Internal::Mutate(this, op->indices);

if (value.same_as(op->value) && indices.same_as(op->indices)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->value = std::move(value);
n->indices = std::move(indices);
return Stmt(n);
}
}

Stmt StmtMutator::VisitStmt_(const BufferRealizeNode* op) {
Region bounds = Internal::Mutate(this, op->bounds);
PrimExpr condition = this->VisitExpr(op->condition);
Expand Down
Loading

0 comments on commit 30533b4

Please sign in to comment.