@@ -58,6 +58,16 @@ TVM_REGISTER_GLOBAL("tir.sparse.DenseFixedAxis")
5858 return DenseFixedAxis (name, length, from_sparse);
5959 });
6060
61+ TVM_STATIC_IR_FUNCTOR (ReprPrinter, vtable)
62+ .set_dispatch<DenseFixedAxisNode>([](const ObjectRef& node, ReprPrinter* p) {
63+ auto * op = static_cast <const DenseFixedAxisNode*>(node.get ());
64+ p->stream << " dense_fixed(" << op->name << " , " << op->length ;
65+ if (op->from_sparse .defined ()) {
66+ p->stream << " , from_sparse=" << op->from_sparse .value ();
67+ }
68+ p->stream << " )" ;
69+ });
70+
6171// DenseVariableAxis
6272DenseVariableAxis::DenseVariableAxis (String name, PrimExpr length, Buffer indptr) {
6373 ObjectPtr<DenseVariableAxisNode> node = make_object<DenseVariableAxisNode>();
@@ -74,6 +84,12 @@ TVM_REGISTER_GLOBAL("tir.sparse.DenseVariableAxis")
7484 return DenseVariableAxis (name, length, indptr);
7585 });
7686
87+ TVM_STATIC_IR_FUNCTOR (ReprPrinter, vtable)
88+ .set_dispatch<DenseVariableAxisNode>([](const ObjectRef& node, ReprPrinter* p) {
89+ auto * op = static_cast <const DenseVariableAxisNode*>(node.get ());
90+ p->stream << " dense_variable(" << op->name << " , " << op->length << " , " << op->indptr ->name ;
91+ });
92+
7793// SparseFixedAxis
7894SparseFixedAxis::SparseFixedAxis (String name, PrimExpr length, Buffer indices, PrimExpr num_cols) {
7995 ObjectPtr<SparseFixedAxisNode> node = make_object<SparseFixedAxisNode>();
@@ -91,6 +107,13 @@ TVM_REGISTER_GLOBAL("tir.sparse.SparseFixedAxis")
91107 return SparseFixedAxis (name, length, indices, num_cols);
92108 });
93109
110+ TVM_STATIC_IR_FUNCTOR (ReprPrinter, vtable)
111+ .set_dispatch<SparseFixedAxisNode>([](const ObjectRef& node, ReprPrinter* p) {
112+ auto * op = static_cast <const SparseFixedAxisNode*>(node.get ());
113+ p->stream << " sparse_fixed(" << op->name << " , " << op->length << " , " << op->num_cols << " , "
114+ << op->indices ->name << " )" ;
115+ });
116+
94117// SparseVariableAxis
95118SparseVariableAxis::SparseVariableAxis (String name, PrimExpr length, Buffer indptr,
96119 Buffer indices) {
@@ -109,6 +132,13 @@ TVM_REGISTER_GLOBAL("tir.sparse.SparseVariableAxis")
109132 return SparseVariableAxis (name, length, indptr, indices);
110133 });
111134
135+ TVM_STATIC_IR_FUNCTOR (ReprPrinter, vtable)
136+ .set_dispatch<SparseVariableAxisNode>([](const ObjectRef& node, ReprPrinter* p) {
137+ auto * op = static_cast <const SparseVariableAxisNode*>(node.get ());
138+ p->stream << " sparse_variable(" << op->name << " , " << op->length << " , " << op->indptr ->name
139+ << " , " << op->indices ->name << " )" ;
140+ });
141+
112142// AxisTree
113143AxisTree::AxisTree (Array<String> axis_names, Array<Optional<String>> axis_parent_names) {
114144 CHECK_EQ (axis_names.size (), axis_parent_names.size ())
@@ -178,6 +208,27 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
178208 p->stream << " ], " << op->data << " )" ;
179209 });
180210
211+ // SpIterKind
212+ std::ostream& operator <<(std::ostream& out, SpIterKind type) {
213+ switch (type) {
214+ case SpIterKind::kDenseFixed :
215+ out << " dense-fixed" ;
216+ break ;
217+ case SpIterKind::kDenseVariable :
218+ out << " dense-variable" ;
219+ break ;
220+ case SpIterKind::kSparseFixed :
221+ out << " sparse-fixed" ;
222+ break ;
223+ case SpIterKind::kSparseVariable :
224+ out << " sparse-variable" ;
225+ break ;
226+ default :
227+ LOG (FATAL) << " Cannot reach here" ;
228+ }
229+ return out;
230+ }
231+
181232// SpIterVar
182233SpIterVar::SpIterVar (Var var, PrimExpr max_extent, SpIterKind kind, bool is_reduction, Axis axis) {
183234 ObjectPtr<SpIterVarNode> node = make_object<SpIterVarNode>();
@@ -210,5 +261,13 @@ TVM_REGISTER_GLOBAL("tir.sparse.SpIterVar")
210261 return SpIterVar (var, max_extent, SpIterKind (kind), is_reduction, axis);
211262 });
212263
264+ TVM_STATIC_IR_FUNCTOR (ReprPrinter, vtable)
265+ .set_dispatch<SpIterVarNode>([](const ObjectRef& node, ReprPrinter* p) {
266+ auto * op = static_cast <const SpIterVarNode*>(node.get ());
267+ p->stream << " sp_iter_var(" << op->var ->name_hint << " , " << op->max_extent << " , "
268+ << op->kind << " , " << (op->is_reduction ? " reduction" : " spatial" ) << " , "
269+ << op->axis ->name << " )" ;
270+ });
271+
213272} // namespace tir
214273} // namespace tvm
0 commit comments