Skip to content

Commit cc894db

Browse files
committed
[SparseTIR] SparseBlock on C++/Python side (#11)
* Fix a bug in the last commit * SparseBlock on C++ & Python side
1 parent d472fd6 commit cc894db

File tree

4 files changed

+170
-27
lines changed

4 files changed

+170
-27
lines changed

include/tvm/tir/stmt.h

Lines changed: 55 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -327,28 +327,6 @@ class BufferStore : public Stmt {
327327
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode);
328328
};
329329

330-
/*!
331-
* \brief Sparse Block node.
332-
*/
333-
class SparseBlockNode : public StmtNode {
334-
public:
335-
/*! \brief The sparse iteration variables of the block. */
336-
Array<SpIterVar> sp_iter_vars;
337-
/*! \brief The sparse buffers defined in the block. */
338-
Array<SparseBuffer> sp_buffers;
339-
/*! \brief The body of the block */
340-
Stmt body;
341-
342-
static constexpr const char* _type_key = "tir.SparseBlock";
343-
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBlockNode, StmtNode);
344-
};
345-
346-
class SparseBlock : public Stmt {
347-
public:
348-
TVM_DEFINE_OBJECT_REF_METHODS(SparseBlock, Stmt, SparseBlockNode);
349-
};
350-
351-
352330
/*!
353331
* \brief Store value to the high dimension sparse buffer.
354332
*
@@ -1300,6 +1278,61 @@ class BlockRealize : public Stmt {
13001278
TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockRealizeNode);
13011279
};
13021280

1281+
/*!
1282+
* \brief Sparse Block node.
1283+
*/
1284+
class SparseBlockNode : public StmtNode {
1285+
public:
1286+
/*! \brief The sparse iteration variables of the block. */
1287+
Array<SpIterVar> sp_iter_vars;
1288+
/*! \brief The sparse buffers defined in the block. */
1289+
Array<SparseBuffer> sp_buffers;
1290+
/*! \brief The name of the block */
1291+
String name;
1292+
/*! \brief The body of the block */
1293+
Stmt body;
1294+
/*! \brief The init statement of the block */
1295+
Optional<Stmt> init;
1296+
1297+
void VisitAttrs(AttrVisitor* v) {
1298+
v->Visit("sp_iter_vars", &sp_iter_vars);
1299+
v->Visit("sp_buffers", &sp_buffers);
1300+
v->Visit("name", &name);
1301+
v->Visit("body", &body);
1302+
v->Visit("init", &init);
1303+
}
1304+
1305+
bool SEqualReduce(const SparseBlockNode* other, SEqualReducer equal) const {
1306+
return equal(sp_iter_vars, other->sp_iter_vars) && equal(sp_buffers, other->sp_buffers) &&
1307+
equal(name, other->name) && equal(body, other->body) && equal(init, other->init);
1308+
}
1309+
1310+
void SHashReduce(SHashReducer hash_reduce) const {
1311+
hash_reduce(sp_iter_vars);
1312+
hash_reduce(sp_buffers);
1313+
hash_reduce(name);
1314+
hash_reduce(body);
1315+
hash_reduce(init);
1316+
}
1317+
1318+
static constexpr const char* _type_key = "tir.SparseBlock";
1319+
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBlockNode, StmtNode);
1320+
};
1321+
1322+
/*!
1323+
* \brief Managed reference to SparseBufferNode
1324+
* \sa SparseBufferNode
1325+
*/
1326+
class SparseBlock : public Stmt {
1327+
public:
1328+
TVM_DLL explicit SparseBlock(Array<SpIterVar> sp_iter_vars, Array<SparseBuffer> sp_buffers,
1329+
String name, Stmt body, Optional<Stmt> init = NullOpt,
1330+
Span span = Span());
1331+
1332+
TVM_DEFINE_OBJECT_REF_METHODS(SparseBlock, Stmt, SparseBlockNode);
1333+
TVM_DEFINE_OBJECT_REF_COW_METHOD(SparseBlockNode);
1334+
};
1335+
13031336
/*! \brief namespace of possible attribute sin AttrStmt.attr_key */
13041337
namespace attr {
13051338
// The above attr does not pass to ir stage.

python/tvm/tir/sparse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ class SpIterVar(Object):
236236
SparseFixed = 2
237237
SparseVariable = 3
238238

239-
def __init__(self, var, max_extent, kind, axis=None):
239+
def __init__(self, var, max_extent, kind, is_reduction, axis=None):
240240
self.__init_handle_by_constructor__(
241241
_ffi_api.SpIterVar, var, max_extent, kind, is_reduction, axis # type: ignore
242242
)

python/tvm/tir/stmt.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from . import _ffi_api
3737
from .buffer import Buffer
3838
from .expr import IterVar
39+
from .sparse import SpIterVar, SparseBuffer
3940

4041

4142
class Stmt(Object):
@@ -614,6 +615,58 @@ def __init__(
614615
) # type: ignore
615616

616617

618+
@tvm._ffi.register_object("tir.SparseBlock")
619+
class SparseBlock(Stmt):
620+
"""SparseBlock node.
621+
622+
Parameters
623+
----------
624+
sp_iter_vars : List[SpIterVar]
625+
The sparse iteration variables of the block.
626+
627+
sp_buffers : List[SparseBuffer]
628+
The sparse buffers defined in the block.
629+
630+
name : str
631+
The name of the block.
632+
633+
body : Stmt
634+
The body of the block.
635+
636+
init : Optional[Stmt]
637+
The init statement of the block.
638+
639+
span : Optional[Span]
640+
The location of this block in the source code.
641+
"""
642+
643+
sp_iter_vars: List[SpIterVar]
644+
sp_buffers: List[SparseBuffer]
645+
name: str
646+
body: Stmt
647+
init: Optional[Stmt]
648+
span: Optional[Span]
649+
650+
def __init__(
651+
self,
652+
sp_iter_vars: List[SpIterVar],
653+
sp_buffers: List[SparseBuffer],
654+
name: str,
655+
body: Stmt,
656+
init: Optional[Stmt] = None,
657+
span: Optional[Span] = None,
658+
):
659+
self.__init_handle_by_constructor__(
660+
_ffi_api.SparseBlock, # type: ignore
661+
sp_iter_vars,
662+
sp_buffers,
663+
name,
664+
body,
665+
init,
666+
span,
667+
) # type: ignore
668+
669+
617670
@tvm._ffi.register_object("tir.BlockRealize")
618671
class BlockRealize(Stmt):
619672
"""BlockRealize node.

src/tir/ir/stmt.cc

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -883,17 +883,21 @@ void PrintBlockSignature(const BlockNode* op, ReprPrinter* p) {
883883
}
884884
}
885885

886-
void PrintBlockBody(const BlockNode* op, ReprPrinter* p) {
887-
// Print init
888-
if (op->init.defined()) {
886+
void PrintInitStmt(const Optional<Stmt>& init, ReprPrinter* p) {
887+
if (init.defined()) {
889888
p->PrintIndent();
890889
p->stream << "with init() {\n";
891890
p->indent += 2;
892-
p->Print(op->init.value());
891+
p->Print(init.value());
893892
p->indent -= 2;
894893
p->PrintIndent();
895894
p->stream << "}\n";
896895
}
896+
}
897+
898+
void PrintBlockBody(const BlockNode* op, ReprPrinter* p) {
899+
// Print init
900+
PrintInitStmt(op->init, p);
897901
// Print body
898902
p->Print(op->body);
899903
}
@@ -971,6 +975,59 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
971975
p->stream << "}\n";
972976
});
973977

978+
SparseBlock::SparseBlock(Array<SpIterVar> sp_iter_vars, Array<SparseBuffer> sp_buffers, String name,
979+
Stmt body, Optional<Stmt> init, Span span) {
980+
ObjectPtr<SparseBlockNode> node = make_object<SparseBlockNode>();
981+
node->sp_iter_vars = std::move(sp_iter_vars);
982+
node->sp_buffers = std::move(sp_buffers);
983+
node->name = std::move(name);
984+
node->body = std::move(body);
985+
node->init = std::move(init);
986+
node->span = std::move(span);
987+
data_ = std::move(node);
988+
}
989+
990+
TVM_REGISTER_GLOBAL("tir.SparseBlock")
991+
.set_body_typed([](Array<SpIterVar> sp_iter_vars, Array<SparseBuffer> sp_buffers, String name,
992+
Stmt body, Optional<Stmt> init, Span span) {
993+
return SparseBlock(sp_iter_vars, sp_buffers, name, body, init, span);
994+
});
995+
996+
TVM_REGISTER_NODE_TYPE(SparseBlockNode);
997+
998+
void PrintSparseBlockTitle(const SparseBlockNode* op, ReprPrinter* p) {
999+
p->stream << "sparse_block " << op->name << "(";
1000+
for (int i = 0; i < static_cast<int>(op->sp_iter_vars.size()); ++i) {
1001+
p->Print(op->sp_iter_vars[i]);
1002+
if (i < static_cast<int>(op->sp_iter_vars.size()) - 1) {
1003+
p->stream << ", ";
1004+
}
1005+
}
1006+
p->stream << ")";
1007+
}
1008+
1009+
void PrintSparseBlockBody(const SparseBlockNode* op, ReprPrinter* p) {
1010+
// Print init
1011+
PrintInitStmt(op->init, p);
1012+
// Print body
1013+
p->Print(op->body);
1014+
}
1015+
1016+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
1017+
.set_dispatch<SparseBlockNode>([](const ObjectRef& node, ReprPrinter* p) {
1018+
auto* op = static_cast<const SparseBlockNode*>(node.get());
1019+
p->PrintIndent();
1020+
PrintSparseBlockTitle(op, p);
1021+
p->stream << " {\n";
1022+
p->indent += 2;
1023+
1024+
PrintSparseBlockBody(op, p);
1025+
1026+
p->indent -= 2;
1027+
p->PrintIndent();
1028+
p->stream << "}\n";
1029+
});
1030+
9741031
PrimExpr TypeAnnotation(DataType dtype, Span span) {
9751032
static auto op = Op::Get("tir.type_annotation");
9761033
return tir::Call(dtype, op, {}, span);

0 commit comments

Comments
 (0)