Skip to content

Commit 7fc66fc

Browse files
MasterJH5574yzh119
authored andcommitted
[SparseTIR] ReprPrinter for Axis and SpIterVar (apache#16)
1 parent 995e2f4 commit 7fc66fc

File tree

2 files changed

+78
-0
lines changed

2 files changed

+78
-0
lines changed

include/tvm/tir/sparse.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,9 @@ enum class SpIterKind : int {
382382
kSparseVariable = 3
383383
};
384384

385+
// overload printing of for type.
386+
TVM_DLL std::ostream& operator<<(std::ostream& os, SpIterKind kind);
387+
385388
/*!
386389
* \brief Iterator variables in SparseTIR
387390
*/
@@ -437,6 +440,22 @@ class SpIterVar : public ObjectRef {
437440
// inline implementations
438441
inline SpIterVar::operator PrimExpr() const { return (*this)->var; }
439442

443+
// inline implementations
444+
inline const char* SpIterKind2String(SpIterKind t) {
445+
switch (t) {
446+
case SpIterKind::kDenseFixed:
447+
return "dense_fixed";
448+
case SpIterKind::kDenseVariable:
449+
return "dense_variable";
450+
case SpIterKind::kSparseFixed:
451+
return "sparse_fixed";
452+
case SpIterKind::kSparseVariable:
453+
return "sparse_variable";
454+
}
455+
LOG(FATAL) << "Unknown SpIterKind" << t;
456+
throw;
457+
}
458+
440459
} // namespace tir
441460
} // namespace tvm
442461

src/tir/ir/sparse.cc

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,16 @@ TVM_REGISTER_GLOBAL("tir.sparse.DenseFixedAxis")
5858
return DenseFixedAxis(name, length, from_sparse);
5959
});
6060

61+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
62+
.set_dispatch<DenseFixedAxisNode>([](const ObjectRef& node, ReprPrinter* p) {
63+
auto* op = static_cast<const DenseFixedAxisNode*>(node.get());
64+
p->stream << "dense_fixed(" << op->name << ", " << op->length;
65+
if (op->from_sparse.defined()) {
66+
p->stream << ", from_sparse=" << op->from_sparse.value();
67+
}
68+
p->stream << ")";
69+
});
70+
6171
// DenseVariableAxis
6272
DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length, Buffer indptr) {
6373
ObjectPtr<DenseVariableAxisNode> node = make_object<DenseVariableAxisNode>();
@@ -74,6 +84,12 @@ TVM_REGISTER_GLOBAL("tir.sparse.DenseVariableAxis")
7484
return DenseVariableAxis(name, length, indptr);
7585
});
7686

87+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
88+
.set_dispatch<DenseVariableAxisNode>([](const ObjectRef& node, ReprPrinter* p) {
89+
auto* op = static_cast<const DenseVariableAxisNode*>(node.get());
90+
p->stream << "dense_variable(" << op->name << ", " << op->length << ", " << op->indptr->name;
91+
});
92+
7793
// SparseFixedAxis
7894
SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr num_cols) {
7995
ObjectPtr<SparseFixedAxisNode> node = make_object<SparseFixedAxisNode>();
@@ -91,6 +107,13 @@ TVM_REGISTER_GLOBAL("tir.sparse.SparseFixedAxis")
91107
return SparseFixedAxis(name, length, indices, num_cols);
92108
});
93109

110+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
111+
.set_dispatch<SparseFixedAxisNode>([](const ObjectRef& node, ReprPrinter* p) {
112+
auto* op = static_cast<const SparseFixedAxisNode*>(node.get());
113+
p->stream << "sparse_fixed(" << op->name << ", " << op->length << ", " << op->num_cols << ", "
114+
<< op->indices->name << ")";
115+
});
116+
94117
// SparseVariableAxis
95118
SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length, Buffer indptr,
96119
Buffer indices) {
@@ -109,6 +132,13 @@ TVM_REGISTER_GLOBAL("tir.sparse.SparseVariableAxis")
109132
return SparseVariableAxis(name, length, indptr, indices);
110133
});
111134

135+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
136+
.set_dispatch<SparseVariableAxisNode>([](const ObjectRef& node, ReprPrinter* p) {
137+
auto* op = static_cast<const SparseVariableAxisNode*>(node.get());
138+
p->stream << "sparse_variable(" << op->name << ", " << op->length << ", " << op->indptr->name
139+
<< ", " << op->indices->name << ")";
140+
});
141+
112142
// AxisTree
113143
AxisTree::AxisTree(Array<String> axis_names, Array<Optional<String>> axis_parent_names) {
114144
CHECK_EQ(axis_names.size(), axis_parent_names.size())
@@ -178,6 +208,27 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
178208
p->stream << "], " << op->data << ")";
179209
});
180210

211+
// SpIterKind
212+
std::ostream& operator<<(std::ostream& out, SpIterKind type) {
213+
switch (type) {
214+
case SpIterKind::kDenseFixed:
215+
out << "dense-fixed";
216+
break;
217+
case SpIterKind::kDenseVariable:
218+
out << "dense-variable";
219+
break;
220+
case SpIterKind::kSparseFixed:
221+
out << "sparse-fixed";
222+
break;
223+
case SpIterKind::kSparseVariable:
224+
out << "sparse-variable";
225+
break;
226+
default:
227+
LOG(FATAL) << "Cannot reach here";
228+
}
229+
return out;
230+
}
231+
181232
// SpIterVar
182233
SpIterVar::SpIterVar(Var var, PrimExpr max_extent, SpIterKind kind, bool is_reduction, Axis axis) {
183234
ObjectPtr<SpIterVarNode> node = make_object<SpIterVarNode>();
@@ -210,5 +261,13 @@ TVM_REGISTER_GLOBAL("tir.sparse.SpIterVar")
210261
return SpIterVar(var, max_extent, SpIterKind(kind), is_reduction, axis);
211262
});
212263

264+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
265+
.set_dispatch<SpIterVarNode>([](const ObjectRef& node, ReprPrinter* p) {
266+
auto* op = static_cast<const SpIterVarNode*>(node.get());
267+
p->stream << "sp_iter_var(" << op->var->name_hint << ", " << op->max_extent << ", "
268+
<< op->kind << ", " << (op->is_reduction ? "reduction" : "spatial") << ", "
269+
<< op->axis->name << ")";
270+
});
271+
213272
} // namespace tir
214273
} // namespace tvm

0 commit comments

Comments
 (0)