@@ -38,14 +38,12 @@ DenseFixedAxis::DenseFixedAxis(String name, PrimExpr length) {
3838
3939TVM_REGISTER_NODE_TYPE (DenseFixedAxisNode);
4040
41- TVM_REGISTER_GLOBAL (" tir.sparse.DenseFixedAxis" )
42- .set_body_typed([](String name, PrimExpr length) {
43- return DenseFixedAxis (name, length);
44- });
41+ TVM_REGISTER_GLOBAL (" tir.sparse.DenseFixedAxis" ).set_body_typed([](String name, PrimExpr length) {
42+ return DenseFixedAxis (name, length);
43+ });
4544
4645// DenseVariableAxis
47- DenseVariableAxis::DenseVariableAxis (String name, PrimExpr length,
48- Buffer indptr) {
46+ DenseVariableAxis::DenseVariableAxis (String name, PrimExpr length, Buffer indptr) {
4947 ObjectPtr<DenseVariableAxisNode> node = make_object<DenseVariableAxisNode>();
5048 node->name = std::move (name);
5149 node->length = std::move (length);
@@ -61,8 +59,7 @@ TVM_REGISTER_GLOBAL("tir.sparse.DenseVariableAxis")
6159 });
6260
6361// SparseFixedAxis
64- SparseFixedAxis::SparseFixedAxis (String name, PrimExpr length, Buffer indices,
65- PrimExpr num_cols) {
62+ SparseFixedAxis::SparseFixedAxis (String name, PrimExpr length, Buffer indices, PrimExpr num_cols) {
6663 ObjectPtr<SparseFixedAxisNode> node = make_object<SparseFixedAxisNode>();
6764 node->name = std::move (name);
6865 node->length = std::move (length);
@@ -74,16 +71,14 @@ SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices,
7471TVM_REGISTER_NODE_TYPE (SparseFixedAxisNode);
7572
7673TVM_REGISTER_GLOBAL (" tir.sparse.SparseFixedAxis" )
77- .set_body_typed([](String name, PrimExpr length, Buffer indices,
78- PrimExpr num_cols) {
74+ .set_body_typed([](String name, PrimExpr length, Buffer indices, PrimExpr num_cols) {
7975 return SparseFixedAxis (name, length, indices, num_cols);
8076 });
8177
8278// SparseVariableAxis
83- SparseVariableAxis::SparseVariableAxis (String name, PrimExpr length,
84- Buffer indptr, Buffer indices) {
85- ObjectPtr<SparseVariableAxisNode> node =
86- make_object<SparseVariableAxisNode>();
79+ SparseVariableAxis::SparseVariableAxis (String name, PrimExpr length, Buffer indptr,
80+ Buffer indices) {
81+ ObjectPtr<SparseVariableAxisNode> node = make_object<SparseVariableAxisNode>();
8782 node->name = std::move (name);
8883 node->length = std::move (length);
8984 node->indptr = std::move (indptr);
@@ -94,14 +89,12 @@ SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length,
9489TVM_REGISTER_NODE_TYPE (SparseVariableAxisNode);
9590
9691TVM_REGISTER_GLOBAL (" tir.sparse.SparseVariableAxis" )
97- .set_body_typed([](String name, PrimExpr length, Buffer indptr,
98- Buffer indices) {
92+ .set_body_typed([](String name, PrimExpr length, Buffer indptr, Buffer indices) {
9993 return SparseVariableAxis (name, length, indptr, indices);
10094 });
10195
10296// AxisTree
103- AxisTree::AxisTree (Array<Axis> axes,
104- Array<Optional<String>> axis_parent_names) {
97+ AxisTree::AxisTree (Array<Axis> axes, Array<Optional<String>> axis_parent_names) {
10598 CHECK_EQ (axes.size (), axis_parent_names.size ())
10699 << " ValueError: The axes array should have the same length as axis_parent_names "
107100 " array." ;
@@ -121,9 +114,7 @@ AxisTree::AxisTree(Array<Axis> axes,
121114 CHECK (node->axis_map .find (parent_name.value ()) != node->axis_map .end ())
122115 << " ValueError: Parent axis name doesn't exist." ;
123116 }
124- Axis parent_axis = (parent_name.get () != nullptr )
125- ? node->axis_map [parent_name.value ()]
126- : root;
117+ Axis parent_axis = (parent_name.get () != nullptr ) ? node->axis_map [parent_name.value ()] : root;
127118 node->parent [axis] = parent_axis;
128119 if (node->children .find (parent_axis) != node->children .end ()) {
129120 node->children [parent_axis].push_back (axis);
@@ -139,8 +130,7 @@ AxisTree::AxisTree(Array<Axis> axes,
139130TVM_REGISTER_NODE_TYPE (AxisTreeNode);
140131
141132TVM_REGISTER_GLOBAL (" tir.sparse.AxisTree" )
142- .set_body_typed([](Array<Axis> axes,
143- Array<Optional<String>> axis_parent_names) {
133+ .set_body_typed([](Array<Axis> axes, Array<Optional<String>> axis_parent_names) {
144134 return AxisTree (axes, axis_parent_names);
145135 });
146136
@@ -164,7 +154,8 @@ TVM_REGISTER_GLOBAL("tir.sparse.SparseBuffer")
164154 });
165155
166156// SpIterVar
167- SpIterVar::SpIterVar (String name, PrimExpr max_extent, SpIterKind kind, Optional<Axis> axis) {
157+ SpIterVar::SpIterVar (String name, PrimExpr max_extent, SpIterKind kind, bool is_reduction,
158+ Optional<Axis> axis) {
168159 ObjectPtr<SpIterVarNode> node = make_object<SpIterVarNode>();
169160
170161 if (kind != SpIterKind::kDenseFixed ) {
@@ -175,15 +166,17 @@ SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, Optional
175166 node->var = Var (std::move (name));
176167 node->max_extent = std::move (max_extent);
177168 node->kind = kind;
169+ node->is_reduction = is_reduction;
178170 node->axis = std::move (axis);
179171 data_ = std::move (node);
180172}
181173
182174TVM_REGISTER_NODE_TYPE (SpIterVarNode);
183175
184176TVM_REGISTER_GLOBAL (" tir.sparse.SpIterVar" )
185- .set_body_typed([](String name, PrimExpr max_extent, SpIterKind kind, Optional<Axis> axis) {
186- return SpIterVar (name, max_extent, kind, axis);
177+ .set_body_typed([](String name, PrimExpr max_extent, SpIterKind kind, bool is_reduction,
178+ Optional<Axis> axis) {
179+ return SpIterVar (name, max_extent, kind, is_reduction, axis);
187180 });
188181
189182} // namespace tir
0 commit comments