Skip to content

Commit 675fdf9

Browse files
MasterJH5574yzh119
authored andcommitted
[SparseTIR] SparseBlock on C++/Python side (#11)
* Fix a bug in the last commit * SparseBlock on C++ & Python side
1 parent 257f769 commit 675fdf9

File tree

4 files changed

+170
-27
lines changed

4 files changed

+170
-27
lines changed

include/tvm/tir/stmt.h

Lines changed: 55 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -327,28 +327,6 @@ class BufferStore : public Stmt {
327327
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode);
328328
};
329329

330-
/*!
331-
* \brief Sparse Block node.
332-
*/
333-
class SparseBlockNode : public StmtNode {
334-
public:
335-
/*! \brief The sparse iteration variables of the block. */
336-
Array<SpIterVar> sp_iter_vars;
337-
/*! \brief The sparse buffers defined in the block. */
338-
Array<SparseBuffer> sp_buffers;
339-
/*! \brief The body of the block */
340-
Stmt body;
341-
342-
static constexpr const char* _type_key = "tir.SparseBlock";
343-
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBlockNode, StmtNode);
344-
};
345-
346-
class SparseBlock : public Stmt {
347-
public:
348-
TVM_DEFINE_OBJECT_REF_METHODS(SparseBlock, Stmt, SparseBlockNode);
349-
};
350-
351-
352330
/*!
353331
* \brief Store value to the high dimension sparse buffer.
354332
*
@@ -1300,6 +1278,61 @@ class BlockRealize : public Stmt {
13001278
TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockRealizeNode);
13011279
};
13021280

1281+
/*!
1282+
* \brief Sparse Block node.
1283+
*/
1284+
class SparseBlockNode : public StmtNode {
1285+
public:
1286+
/*! \brief The sparse iteration variables of the block. */
1287+
Array<SpIterVar> sp_iter_vars;
1288+
/*! \brief The sparse buffers defined in the block. */
1289+
Array<SparseBuffer> sp_buffers;
1290+
/*! \brief The name of the block */
1291+
String name;
1292+
/*! \brief The body of the block */
1293+
Stmt body;
1294+
/*! \brief The init statement of the block */
1295+
Optional<Stmt> init;
1296+
1297+
void VisitAttrs(AttrVisitor* v) {
1298+
v->Visit("sp_iter_vars", &sp_iter_vars);
1299+
v->Visit("sp_buffers", &sp_buffers);
1300+
v->Visit("name", &name);
1301+
v->Visit("body", &body);
1302+
v->Visit("init", &init);
1303+
}
1304+
1305+
bool SEqualReduce(const SparseBlockNode* other, SEqualReducer equal) const {
1306+
return equal(sp_iter_vars, other->sp_iter_vars) && equal(sp_buffers, other->sp_buffers) &&
1307+
equal(name, other->name) && equal(body, other->body) && equal(init, other->init);
1308+
}
1309+
1310+
void SHashReduce(SHashReducer hash_reduce) const {
1311+
hash_reduce(sp_iter_vars);
1312+
hash_reduce(sp_buffers);
1313+
hash_reduce(name);
1314+
hash_reduce(body);
1315+
hash_reduce(init);
1316+
}
1317+
1318+
static constexpr const char* _type_key = "tir.SparseBlock";
1319+
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBlockNode, StmtNode);
1320+
};
1321+
1322+
/*!
1323+
* \brief Managed reference to SparseBufferNode
1324+
* \sa SparseBufferNode
1325+
*/
1326+
class SparseBlock : public Stmt {
1327+
public:
1328+
TVM_DLL explicit SparseBlock(Array<SpIterVar> sp_iter_vars, Array<SparseBuffer> sp_buffers,
1329+
String name, Stmt body, Optional<Stmt> init = NullOpt,
1330+
Span span = Span());
1331+
1332+
TVM_DEFINE_OBJECT_REF_METHODS(SparseBlock, Stmt, SparseBlockNode);
1333+
TVM_DEFINE_OBJECT_REF_COW_METHOD(SparseBlockNode);
1334+
};
1335+
13031336
/*! \brief namespace of possible attribute sin AttrStmt.attr_key */
13041337
namespace attr {
13051338
// The above attr does not pass to ir stage.

python/tvm/tir/sparse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ class SpIterVar(Object):
236236
SparseFixed = 2
237237
SparseVariable = 3
238238

239-
def __init__(self, var, max_extent, kind, axis=None):
239+
def __init__(self, var, max_extent, kind, is_reduction, axis=None):
240240
self.__init_handle_by_constructor__(
241241
_ffi_api.SpIterVar, var, max_extent, kind, is_reduction, axis # type: ignore
242242
)

python/tvm/tir/stmt.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from . import _ffi_api
3737
from .buffer import Buffer
3838
from .expr import IterVar
39+
from .sparse import SpIterVar, SparseBuffer
3940

4041

4142
class Stmt(Object):
@@ -614,6 +615,58 @@ def __init__(
614615
) # type: ignore
615616

616617

618+
@tvm._ffi.register_object("tir.SparseBlock")
619+
class SparseBlock(Stmt):
620+
"""SparseBlock node.
621+
622+
Parameters
623+
----------
624+
sp_iter_vars : List[SpIterVar]
625+
The sparse iteration variables of the block.
626+
627+
sp_buffers : List[SparseBuffer]
628+
The sparse buffers defined in the block.
629+
630+
name : str
631+
The name of the block.
632+
633+
body : Stmt
634+
The body of the block.
635+
636+
init : Optional[Stmt]
637+
The init statement of the block.
638+
639+
span : Optional[Span]
640+
The location of this block in the source code.
641+
"""
642+
643+
sp_iter_vars: List[SpIterVar]
644+
sp_buffers: List[SparseBuffer]
645+
name: str
646+
body: Stmt
647+
init: Optional[Stmt]
648+
span: Optional[Span]
649+
650+
def __init__(
651+
self,
652+
sp_iter_vars: List[SpIterVar],
653+
sp_buffers: List[SparseBuffer],
654+
name: str,
655+
body: Stmt,
656+
init: Optional[Stmt] = None,
657+
span: Optional[Span] = None,
658+
):
659+
self.__init_handle_by_constructor__(
660+
_ffi_api.SparseBlock, # type: ignore
661+
sp_iter_vars,
662+
sp_buffers,
663+
name,
664+
body,
665+
init,
666+
span,
667+
) # type: ignore
668+
669+
617670
@tvm._ffi.register_object("tir.BlockRealize")
618671
class BlockRealize(Stmt):
619672
"""BlockRealize node.

src/tir/ir/stmt.cc

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -876,17 +876,21 @@ void PrintBlockSignature(const BlockNode* op, ReprPrinter* p) {
876876
}
877877
}
878878

879-
void PrintBlockBody(const BlockNode* op, ReprPrinter* p) {
880-
// Print init
881-
if (op->init.defined()) {
879+
void PrintInitStmt(const Optional<Stmt>& init, ReprPrinter* p) {
880+
if (init.defined()) {
882881
p->PrintIndent();
883882
p->stream << "with init() {\n";
884883
p->indent += 2;
885-
p->Print(op->init.value());
884+
p->Print(init.value());
886885
p->indent -= 2;
887886
p->PrintIndent();
888887
p->stream << "}\n";
889888
}
889+
}
890+
891+
void PrintBlockBody(const BlockNode* op, ReprPrinter* p) {
892+
// Print init
893+
PrintInitStmt(op->init, p);
890894
// Print body
891895
p->Print(op->body);
892896
}
@@ -964,6 +968,59 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
964968
p->stream << "}\n";
965969
});
966970

971+
SparseBlock::SparseBlock(Array<SpIterVar> sp_iter_vars, Array<SparseBuffer> sp_buffers, String name,
972+
Stmt body, Optional<Stmt> init, Span span) {
973+
ObjectPtr<SparseBlockNode> node = make_object<SparseBlockNode>();
974+
node->sp_iter_vars = std::move(sp_iter_vars);
975+
node->sp_buffers = std::move(sp_buffers);
976+
node->name = std::move(name);
977+
node->body = std::move(body);
978+
node->init = std::move(init);
979+
node->span = std::move(span);
980+
data_ = std::move(node);
981+
}
982+
983+
TVM_REGISTER_GLOBAL("tir.SparseBlock")
984+
.set_body_typed([](Array<SpIterVar> sp_iter_vars, Array<SparseBuffer> sp_buffers, String name,
985+
Stmt body, Optional<Stmt> init, Span span) {
986+
return SparseBlock(sp_iter_vars, sp_buffers, name, body, init, span);
987+
});
988+
989+
TVM_REGISTER_NODE_TYPE(SparseBlockNode);
990+
991+
void PrintSparseBlockTitle(const SparseBlockNode* op, ReprPrinter* p) {
992+
p->stream << "sparse_block " << op->name << "(";
993+
for (int i = 0; i < static_cast<int>(op->sp_iter_vars.size()); ++i) {
994+
p->Print(op->sp_iter_vars[i]);
995+
if (i < static_cast<int>(op->sp_iter_vars.size()) - 1) {
996+
p->stream << ", ";
997+
}
998+
}
999+
p->stream << ")";
1000+
}
1001+
1002+
void PrintSparseBlockBody(const SparseBlockNode* op, ReprPrinter* p) {
1003+
// Print init
1004+
PrintInitStmt(op->init, p);
1005+
// Print body
1006+
p->Print(op->body);
1007+
}
1008+
1009+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
1010+
.set_dispatch<SparseBlockNode>([](const ObjectRef& node, ReprPrinter* p) {
1011+
auto* op = static_cast<const SparseBlockNode*>(node.get());
1012+
p->PrintIndent();
1013+
PrintSparseBlockTitle(op, p);
1014+
p->stream << " {\n";
1015+
p->indent += 2;
1016+
1017+
PrintSparseBlockBody(op, p);
1018+
1019+
p->indent -= 2;
1020+
p->PrintIndent();
1021+
p->stream << "}\n";
1022+
});
1023+
9671024
PrimExpr TypeAnnotation(DataType dtype, Span span) {
9681025
static auto op = Op::Get("tir.type_annotation");
9691026
return tir::Call(dtype, op, {}, span);

0 commit comments

Comments
 (0)