@@ -99,23 +99,48 @@ class DenseAxis : public Axis {
9999 TVM_DEFINE_OBJECT_REF_METHODS (DenseAxis, Axis, DenseAxisNode);
100100};
101101
102+ /* !
103+ * \brief Sparse axis whose column indices is not consecutive.
104+ */
105+ class SparseAxisNode : public AxisNode {
106+ public:
107+ static constexpr const char * _type_key = " tir.sparse.SparseAxis" ;
108+ TVM_DECLARE_BASE_OBJECT_INFO (SparseAxisNode, AxisNode);
109+ };
110+
111+ /* !
112+ * \brief Managed reference to SparseAxisNode.
113+ * \sa SparseAxisNode
114+ */
115+ class SparseAxis : public Axis {
116+ public:
117+ TVM_DEFINE_OBJECT_REF_METHODS (SparseAxis, Axis, SparseAxisNode);
118+ };
119+
102120/* !
103121 * \brief Dense axis with fixed length per row.
104122 */
105123class DenseFixedAxisNode : public DenseAxisNode {
106124 public:
125+ Optional<SparseAxis> from_sparse;
126+
107127 void VisitAttrs (AttrVisitor* v) {
108128 v->Visit (" name" , &name);
109129 v->Visit (" length" , &length);
130+ v->Visit (" from_sparse" , &from_sparse);
110131 }
111132
112- bool SEqualReduce (const DenseAxisNode* other, SEqualReducer equal) const {
113- return equal (name, other->name ) && equal (length, other->length );
133+ bool SEqualReduce (const DenseFixedAxisNode* other, SEqualReducer equal) const {
134+ equal->MarkGraphNode ();
135+ return equal (name, other->name ) && equal (length, other->length ) &&
136+ equal (from_sparse, other->from_sparse );
114137 }
115138
116139 void SHashReduce (SHashReducer hash_reduce) const {
140+ hash_reduce->MarkGraphNode ();
117141 hash_reduce (name);
118142 hash_reduce (length);
143+ hash_reduce (from_sparse);
119144 }
120145
121146 static constexpr const char * _type_key = " tir.sparse.DenseFixedAxis" ;
@@ -128,7 +153,8 @@ class DenseFixedAxisNode : public DenseAxisNode {
128153 */
129154class DenseFixedAxis : public DenseAxis {
130155 public:
131- TVM_DLL explicit DenseFixedAxis (String name, PrimExpr length);
156+ TVM_DLL explicit DenseFixedAxis (String name, PrimExpr length,
157+ Optional<SparseAxis> from_sparse = NullOpt);
132158
133159 TVM_DEFINE_OBJECT_REF_METHODS (DenseFixedAxis, DenseAxis, DenseFixedAxisNode);
134160};
@@ -144,10 +170,12 @@ class DenseVariableAxisNode : public DenseAxisNode {
144170 }
145171
146172 bool SEqualReduce (const DenseVariableAxisNode* other, SEqualReducer equal) const {
173+ equal->MarkGraphNode ();
147174 return equal (name, other->name ) && equal (length, other->length ) && equal (indptr, other->indptr );
148175 }
149176
150177 void SHashReduce (SHashReducer hash_reduce) const {
178+ hash_reduce->MarkGraphNode ();
151179 hash_reduce (name);
152180 hash_reduce (length);
153181 hash_reduce (indptr);
@@ -168,24 +196,6 @@ class DenseVariableAxis : public DenseAxis {
168196 TVM_DEFINE_OBJECT_REF_METHODS (DenseVariableAxis, DenseAxis, DenseVariableAxisNode);
169197};
170198
171- /* !
172- * \brief Sparse axis whose column indices is not consecutive.
173- */
174- class SparseAxisNode : public AxisNode {
175- public:
176- static constexpr const char * _type_key = " tir.sparse.SparseAxis" ;
177- TVM_DECLARE_BASE_OBJECT_INFO (SparseAxisNode, AxisNode);
178- };
179-
180- /* !
181- * \brief Managed reference to SparseAxisNode.
182- * \sa SparseAxisNode
183- */
184- class SparseAxis : public Axis {
185- public:
186- TVM_DEFINE_OBJECT_REF_METHODS (SparseAxis, Axis, SparseAxisNode);
187- };
188-
189199/* !
190200 * \brief Sparse axis with fixed number of non-zero columns per row.
191201 */
@@ -203,11 +213,13 @@ class SparseFixedAxisNode : public SparseAxisNode {
203213 }
204214
205215 bool SEqualReduce (const SparseFixedAxisNode* other, SEqualReducer equal) const {
216+ equal->MarkGraphNode ();
206217 return equal (name, other->name ) && equal (length, other->length ) &&
207218 equal (indices, other->indices ) && equal (num_cols, other->num_cols );
208219 }
209220
210221 void SHashReduce (SHashReducer hash_reduce) const {
222+ hash_reduce->MarkGraphNode ();
211223 hash_reduce (name);
212224 hash_reduce (length);
213225 hash_reduce (indices);
@@ -245,11 +257,13 @@ class SparseVariableAxisNode : public SparseAxisNode {
245257 }
246258
247259 bool SEqualReduce (const SparseVariableAxisNode* other, SEqualReducer equal) const {
260+ equal->MarkGraphNode ();
248261 return equal (name, other->name ) && equal (length, other->length ) &&
249262 equal (indptr, other->indptr ) && equal (indices, other->indices );
250263 }
251264
252265 void SHashReduce (SHashReducer hash_reduce) const {
266+ hash_reduce->MarkGraphNode ();
253267 hash_reduce (name);
254268 hash_reduce (length);
255269 hash_reduce (indptr);
@@ -277,13 +291,27 @@ class SparseVariableAxis : public SparseAxis {
277291class AxisTreeNode : public Object {
278292 public:
279293 // unordered map that stores the parent relationship between axes.
280- std::unordered_map <String, Optional<String>, ObjectPtrHash, ObjectPtrEqual > parent;
294+ Map <String, Optional<String>> parent;
281295 // unordered map that stores the children relationship between axes.
282- std::unordered_map<Optional<String>, Array<String>, ObjectPtrHash, ObjectPtrEqual> children;
296+ Map<Optional<String>, Array<String>> children;
297+
298+ void VisitAttrs (AttrVisitor* v) {
299+ v->Visit (" parent" , &parent);
300+ v->Visit (" children" , &children);
301+ }
302+
303+ bool SEqualReduce (const AxisTreeNode* other, SEqualReducer equal) const {
304+ return equal (parent, other->parent ) && equal (children, other->children );
305+ }
283306
284- void VisitAttrs (AttrVisitor* v) {}
307+ void SHashReduce (SHashReducer hash_reduce) const {
308+ hash_reduce (parent);
309+ hash_reduce (children);
310+ }
285311
286312 static constexpr const char * _type_key = " tir.sparse.AxisTree" ;
313+ static constexpr const bool _type_has_method_sequal_reduce = true ;
314+ static constexpr const bool _type_has_method_shash_reduce = true ;
287315 TVM_DECLARE_FINAL_OBJECT_INFO (AxisTreeNode, Object);
288316};
289317
@@ -313,22 +341,26 @@ class SparseBufferNode : public Object {
313341 inline int ndim () const { return static_cast <int >(axes.size ()); }
314342
315343 void VisitAttrs (AttrVisitor* v) {
316- v->Visit (" length " , &axes);
317- v->Visit (" num_cols " , &data);
344+ v->Visit (" axes " , &axes);
345+ v->Visit (" data " , &data);
318346 v->Visit (" name" , &name);
319347 }
320348
321349 bool SEqualReduce (const SparseBufferNode* other, SEqualReducer equal) const {
350+ equal->MarkGraphNode ();
322351 return equal (axes, other->axes ) && equal (data, other->data ) && equal (name, other->name );
323352 }
324353
325354 void SHashReduce (SHashReducer hash_reduce) const {
355+ hash_reduce->MarkGraphNode ();
326356 hash_reduce (axes);
327357 hash_reduce (data);
328358 hash_reduce (name);
329359 }
330360
331361 static constexpr const char * _type_key = " tir.sparse.SparseBuffer" ;
362+ static constexpr const bool _type_has_method_sequal_reduce = true ;
363+ static constexpr const bool _type_has_method_shash_reduce = true ;
332364 TVM_DECLARE_FINAL_OBJECT_INFO (SparseBufferNode, Object);
333365};
334366
@@ -359,7 +391,7 @@ class SpIterVarNode : public Object {
359391 PrimExpr max_extent;
360392 SpIterKind kind;
361393 bool is_reduction;
362- Optional< Axis> axis;
394+ Axis axis;
363395
364396 void VisitAttrs (AttrVisitor* v) {
365397 v->Visit (" var" , &var);
@@ -392,7 +424,7 @@ class SpIterVarNode : public Object {
392424class SpIterVar : public ObjectRef {
393425 public:
394426 TVM_DLL explicit SpIterVar (Var var, PrimExpr max_extent, SpIterKind kind, bool is_reduction,
395- Optional< Axis> axis = NullOpt );
427+ Axis axis);
396428
397429 /* !
398430 * \return the corresponding var in the IterVar.
0 commit comments