@@ -44,7 +44,10 @@ class AxisNode : public Object {
4444 /* length of current axis. For sparse axis, length refers to the upperbound of
4545 * the current axis. */
4646 PrimExpr length;
47+
4748 static constexpr const char * _type_key = " tir.sparse.Axis" ;
49+ static constexpr const bool _type_has_method_sequal_reduce = true ;
50+ static constexpr const bool _type_has_method_shash_reduce = true ;
4851 TVM_DECLARE_BASE_OBJECT_INFO (AxisNode, Object);
4952};
5053
@@ -98,6 +101,20 @@ class DenseAxis : public Axis {
98101 */
99102class DenseFixedAxisNode : public DenseAxisNode {
100103 public:
104+ void VisitAttrs (AttrVisitor* v) {
105+ v->Visit (" name" , &name);
106+ v->Visit (" length" , &length);
107+ }
108+
109+ bool SEqualReduce (const DenseAxisNode* other, SEqualReducer equal) const {
110+ return equal (name, other->name ) && equal (length, other->length );
111+ }
112+
113+ void SHashReduce (SHashReducer hash_reduce) const {
114+ hash_reduce (name);
115+ hash_reduce (length);
116+ }
117+
101118 static constexpr const char * _type_key = " tir.sparse.DenseFixedAxis" ;
102119 TVM_DECLARE_FINAL_OBJECT_INFO (DenseFixedAxisNode, DenseAxisNode);
103120};
@@ -108,12 +125,31 @@ class DenseFixedAxisNode : public DenseAxisNode {
108125 */
109126class DenseFixedAxis : public DenseAxis {
110127 public:
128+ TVM_DLL explicit DenseFixedAxis (String name, PrimExpr length);
129+
111130 TVM_DEFINE_OBJECT_REF_METHODS (DenseFixedAxis, DenseAxis, DenseFixedAxisNode);
112131};
113132
114133class DenseVariableAxisNode : public DenseAxisNode {
115134 public:
116135 Buffer indptr;
136+
137+ void VisitAttrs (AttrVisitor* v) {
138+ v->Visit (" name" , &name);
139+ v->Visit (" length" , &length);
140+ v->Visit (" indptr" , &indptr);
141+ }
142+
143+ bool SEqualReduce (const DenseVariableAxisNode* other, SEqualReducer equal) const {
144+ return equal (name, other->name ) && equal (length, other->length ) && equal (indptr, other->indptr );
145+ }
146+
147+ void SHashReduce (SHashReducer hash_reduce) const {
148+ hash_reduce (name);
149+ hash_reduce (length);
150+ hash_reduce (indptr);
151+ }
152+
117153 static constexpr const char * _type_key = " tir.sparse.DenseVariableAxis" ;
118154 TVM_DECLARE_FINAL_OBJECT_INFO (DenseVariableAxisNode, DenseAxisNode);
119155};
@@ -124,8 +160,9 @@ class DenseVariableAxisNode : public DenseAxisNode {
124160 */
125161class DenseVariableAxis : public DenseAxis {
126162 public:
127- TVM_DEFINE_OBJECT_REF_METHODS (DenseVariableAxis, DenseAxis,
128- DenseVariableAxisNode);
163+ TVM_DLL explicit DenseVariableAxis (String name, PrimExpr length, Buffer indptr);
164+
165+ TVM_DEFINE_OBJECT_REF_METHODS (DenseVariableAxis, DenseAxis, DenseVariableAxisNode);
129166};
130167
131168/* !
@@ -154,6 +191,26 @@ class SparseFixedAxisNode : public SparseAxisNode {
154191 Buffer indices;
155192 /* fixed number of columns of current sparse axis. */
156193 PrimExpr num_cols;
194+
195+ void VisitAttrs (AttrVisitor* v) {
196+ v->Visit (" name" , &name);
197+ v->Visit (" length" , &length);
198+ v->Visit (" indptr" , &indices);
199+ v->Visit (" num_cols" , &num_cols);
200+ }
201+
202+ bool SEqualReduce (const SparseFixedAxisNode* other, SEqualReducer equal) const {
203+ return equal (name, other->name ) && equal (length, other->length ) &&
204+ equal (indices, other->indices ) && equal (num_cols, other->num_cols );
205+ }
206+
207+ void SHashReduce (SHashReducer hash_reduce) const {
208+ hash_reduce (name);
209+ hash_reduce (length);
210+ hash_reduce (indices);
211+ hash_reduce (num_cols);
212+ }
213+
157214 static constexpr const char * _type_key = " tir.sparse.SparseFixedAxis" ;
158215 TVM_DECLARE_FINAL_OBJECT_INFO (SparseFixedAxisNode, SparseAxisNode);
159216};
@@ -164,17 +221,39 @@ class SparseFixedAxisNode : public SparseAxisNode {
164221 */
165222class SparseFixedAxis : public SparseAxis {
166223 public:
167- TVM_DEFINE_OBJECT_REF_METHODS (SparseFixedAxis, SparseAxis,
168- SparseFixedAxisNode);
224+ TVM_DLL explicit SparseFixedAxis (String name, PrimExpr length, Buffer indices, PrimExpr num_cols);
225+
226+ TVM_DEFINE_OBJECT_REF_METHODS (SparseFixedAxis, SparseAxis, SparseFixedAxisNode);
169227};
170228
171229/* !
172230 * \brief Sparse axis with variable number of non-zero columns per row.
173231 */
174232class SparseVariableAxisNode : public SparseAxisNode {
175233 public:
176- Buffer indptr, indices;
177- static constexpr const char * _type_key = " tir.sparse.SparseVariabledAxis" ;
234+ Buffer indptr;
235+ Buffer indices;
236+
237+ void VisitAttrs (AttrVisitor* v) {
238+ v->Visit (" name" , &name);
239+ v->Visit (" length" , &length);
240+ v->Visit (" indptr" , &indptr);
241+ v->Visit (" indices" , &indices);
242+ }
243+
244+ bool SEqualReduce (const SparseVariableAxisNode* other, SEqualReducer equal) const {
245+ return equal (name, other->name ) && equal (length, other->length ) &&
246+ equal (indptr, other->indptr ) && equal (indices, other->indices );
247+ }
248+
249+ void SHashReduce (SHashReducer hash_reduce) const {
250+ hash_reduce (name);
251+ hash_reduce (length);
252+ hash_reduce (indptr);
253+ hash_reduce (indices);
254+ }
255+
256+ static constexpr const char * _type_key = " tir.sparse.SparseVariableAxis" ;
178257 TVM_DECLARE_FINAL_OBJECT_INFO (SparseVariableAxisNode, SparseAxisNode);
179258};
180259
@@ -184,8 +263,9 @@ class SparseVariableAxisNode : public SparseAxisNode {
184263 */
185264class SparseVariableAxis : public SparseAxis {
186265 public:
187- TVM_DEFINE_OBJECT_REF_METHODS (SparseVariableAxis, SparseAxis,
188- SparseVariableAxisNode);
266+ TVM_DLL explicit SparseVariableAxis (String name, PrimExpr length, Buffer indptr, Buffer indices);
267+
268+ TVM_DEFINE_OBJECT_REF_METHODS (SparseVariableAxis, SparseAxis, SparseVariableAxisNode);
189269};
190270
191271/* !
@@ -223,6 +303,26 @@ class SparseBufferNode : public Object {
223303 int ndim;
224304 /* Buffer corresponding to flattened value */
225305 Buffer data;
306+
307+ void VisitAttrs (AttrVisitor* v) {
308+ v->Visit (" name" , &root);
309+ v->Visit (" length" , &axes);
310+ v->Visit (" indptr" , &ndim);
311+ v->Visit (" num_cols" , &data);
312+ }
313+
314+ bool SEqualReduce (const SparseBufferNode* other, SEqualReducer equal) const {
315+ return equal (root, other->root ) && equal (axes, other->axes ) && equal (ndim, other->ndim ) &&
316+ equal (data, other->data );
317+ }
318+
319+ void SHashReduce (SHashReducer hash_reduce) const {
320+ hash_reduce (root);
321+ hash_reduce (axes);
322+ hash_reduce (ndim);
323+ hash_reduce (data);
324+ }
325+
226326 static constexpr const char * _type_key = " tir.sparse.SparseBufferNode" ;
227327 TVM_DECLARE_FINAL_OBJECT_INFO (SparseBufferNode, Object);
228328};
@@ -233,11 +333,13 @@ class SparseBufferNode : public Object {
233333 */
234334class SparseBuffer : public ObjectRef {
235335 public:
336+ TVM_DLL explicit SparseBuffer (AxisTree root, Array<Axis> axes, int ndim, Buffer data);
337+
236338 TVM_DEFINE_OBJECT_REF_METHODS (SparseBuffer, ObjectRef, SparseBufferNode);
237339};
238340
239341} // namespace sparse
240342} // namespace tir
241343} // namespace tvm
242344
243- #endif // TVM_TIR_BUFFER_H_
345+ #endif // TVM_TIR_SPARSE_H_
0 commit comments