@@ -327,28 +327,6 @@ class BufferStore : public Stmt {
327327 TVM_DEFINE_OBJECT_REF_COW_METHOD (BufferStoreNode);
328328};
329329
330- /* !
331- * \brief Sparse Block node.
332- */
333- class SparseBlockNode : public StmtNode {
334- public:
335- /* ! \brief The sparse iteration variables of the block. */
336- Array<SpIterVar> sp_iter_vars;
337- /* ! \brief The sparse buffers defined in the block. */
338- Array<SparseBuffer> sp_buffers;
339- /* ! \brief The body of the block */
340- Stmt body;
341-
342- static constexpr const char * _type_key = " tir.SparseBlock" ;
343- TVM_DECLARE_FINAL_OBJECT_INFO (SparseBlockNode, StmtNode);
344- };
345-
346- class SparseBlock : public Stmt {
347- public:
348- TVM_DEFINE_OBJECT_REF_METHODS (SparseBlock, Stmt, SparseBlockNode);
349- };
350-
351-
352330/* !
353331 * \brief Store value to the high dimension sparse buffer.
354332 *
@@ -1300,6 +1278,61 @@ class BlockRealize : public Stmt {
13001278 TVM_DEFINE_OBJECT_REF_COW_METHOD (BlockRealizeNode);
13011279};
13021280
1281+ /* !
1282+ * \brief Sparse Block node.
1283+ */
1284+ class SparseBlockNode : public StmtNode {
1285+ public:
1286+ /* ! \brief The sparse iteration variables of the block. */
1287+ Array<SpIterVar> sp_iter_vars;
1288+ /* ! \brief The sparse buffers defined in the block. */
1289+ Array<SparseBuffer> sp_buffers;
1290+ /* ! \brief The name of the block */
1291+ String name;
1292+ /* ! \brief The body of the block */
1293+ Stmt body;
1294+ /* ! \brief The init statement of the block */
1295+ Optional<Stmt> init;
1296+
1297+ void VisitAttrs (AttrVisitor* v) {
1298+ v->Visit (" sp_iter_vars" , &sp_iter_vars);
1299+ v->Visit (" sp_buffers" , &sp_buffers);
1300+ v->Visit (" name" , &name);
1301+ v->Visit (" body" , &body);
1302+ v->Visit (" init" , &init);
1303+ }
1304+
1305+ bool SEqualReduce (const SparseBlockNode* other, SEqualReducer equal) const {
1306+ return equal (sp_iter_vars, other->sp_iter_vars ) && equal (sp_buffers, other->sp_buffers ) &&
1307+ equal (name, other->name ) && equal (body, other->body ) && equal (init, other->init );
1308+ }
1309+
1310+ void SHashReduce (SHashReducer hash_reduce) const {
1311+ hash_reduce (sp_iter_vars);
1312+ hash_reduce (sp_buffers);
1313+ hash_reduce (name);
1314+ hash_reduce (body);
1315+ hash_reduce (init);
1316+ }
1317+
1318+ static constexpr const char * _type_key = " tir.SparseBlock" ;
1319+ TVM_DECLARE_FINAL_OBJECT_INFO (SparseBlockNode, StmtNode);
1320+ };
1321+
1322+ /* !
1323+ * \brief Managed reference to SparseBufferNode
1324+ * \sa SparseBufferNode
1325+ */
1326+ class SparseBlock : public Stmt {
1327+ public:
1328+ TVM_DLL explicit SparseBlock (Array<SpIterVar> sp_iter_vars, Array<SparseBuffer> sp_buffers,
1329+ String name, Stmt body, Optional<Stmt> init = NullOpt,
1330+ Span span = Span());
1331+
1332+ TVM_DEFINE_OBJECT_REF_METHODS (SparseBlock, Stmt, SparseBlockNode);
1333+ TVM_DEFINE_OBJECT_REF_COW_METHOD (SparseBlockNode);
1334+ };
1335+
13031336/* ! \brief namespace of possible attribute sin AttrStmt.attr_key */
13041337namespace attr {
13051338// The above attr does not pass to ir stage.
0 commit comments