@@ -355,6 +355,64 @@ class SparseBuffer : public ObjectRef {
355355 TVM_DEFINE_OBJECT_REF_METHODS (SparseBuffer, ObjectRef, SparseBufferNode);
356356};
357357
358+ enum class SpIterKind : int {
359+ kDenseFixed = 0 ,
360+ kDenseVariable = 1 ,
361+ kSparseFixed = 2 ,
362+ kSparseVariable = 3
363+ };
364+
365+ /* !
366+ * \brief Iterator variables in SparseTIR
367+ */
368+ class SpIterVarNode : public Object {
369+ public:
370+ Var var;
371+ PrimExpr max_extent;
372+ SpIterKind kind;
373+ Optional<Axis> axis;
374+
375+ void VisitAttrs (AttrVisitor* v) {
376+ v->Visit (" var" , &var);
377+ v->Visit (" max_extent" , &max_extent);
378+ v->Visit (" axis" , &axis);
379+ v->Visit (" kind" , &kind);
380+ }
381+
382+ bool SEqualReduce (const SpIterVarNode* other, SEqualReducer equal) const {
383+ return equal (var, other->var ) && equal (max_extent, other->max_extent ) &&
384+ equal (axis, other->axis ) && equal (kind, other->kind );
385+ }
386+
387+ void SHashReduce (SHashReducer hash_reduce) const {
388+ hash_reduce (var);
389+ hash_reduce (max_extent);
390+ hash_reduce (axis);
391+ hash_reduce (kind);
392+ }
393+
394+ static constexpr const char * _type_key = " tir.sparse.SpIterVar" ;
395+ static constexpr const bool _type_has_method_sequal_reduce = true ;
396+ static constexpr const bool _type_has_method_shash_reduce = true ;
397+ TVM_DECLARE_FINAL_OBJECT_INFO (SpIterVarNode, Object);
398+ };
399+
400+ class SpIterVar : public ObjectRef {
401+ public:
402+ TVM_DLL explicit SpIterVar (String name, PrimExpr max_extent, SpIterKind kind,
403+ Optional<Axis> axis = NullOpt);
404+
405+ /* !
406+ * \return the corresponding var in the IterVar.
407+ */
408+ inline operator PrimExpr () const ;
409+
410+ TVM_DEFINE_OBJECT_REF_METHODS (SpIterVar, ObjectRef, SpIterVarNode);
411+ };
412+
413+ // inline implementations
414+ inline SpIterVar::operator PrimExpr () const { return (*this )->var ; }
415+
358416} // namespace tir
359417} // namespace tvm
360418
0 commit comments