@@ -68,19 +68,20 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
6868/* ******* DenseVariableAxis ********/
6969
7070/* ! \brief Default constuctor of DenseVariableAxis */
71- DenseVariableAxis::DenseVariableAxis (String name, PrimExpr length, Buffer indptr) {
71+ DenseVariableAxis::DenseVariableAxis (String name, PrimExpr length, PrimExpr nnz, Buffer indptr) {
7272 ObjectPtr<DenseVariableAxisNode> node = make_object<DenseVariableAxisNode>();
7373 node->name = std::move (name);
7474 node->length = std::move (length);
75+ node->nnz_ = std::move (nnz);
7576 node->indptr = std::move (indptr);
7677 data_ = std::move (node);
7778}
7879
7980TVM_REGISTER_NODE_TYPE (DenseVariableAxisNode);
8081
8182TVM_REGISTER_GLOBAL (" tir.sparse.DenseVariableAxis" )
82- .set_body_typed([](String name, PrimExpr length, Buffer indptr) {
83- return DenseVariableAxis (name, length, indptr);
83+ .set_body_typed([](String name, PrimExpr length, PrimExpr nnz, Buffer indptr) {
84+ return DenseVariableAxis (std::move ( name), std::move ( length), std::move (nnz), std::move ( indptr) );
8485 });
8586
8687TVM_STATIC_IR_FUNCTOR (ReprPrinter, vtable)
@@ -128,17 +129,7 @@ FusedAxis::FusedAxis(Array<Axis> group, int index) {
128129 fused_name += group[i]->name ;
129130 }
130131 node->name = " fused_" + fused_name + " _" + group[index]->name ;
131-
132- if (const auto * df_axis = group[index].as <DenseFixedAxisNode>()) {
133- node->length = df_axis->length ;
134- } else if (const auto * sf_axis = group[index].as <SparseFixedAxisNode>()) {
135- // TODO(zihao): accumulate previous dimensions.
136- } else if (const auto * dv_axis = group[index].as <DenseVariableAxisNode>()) {
137- node->length = dv_axis->nnz ();
138- } else if (const auto * sv_axis = group[index].as <SparseVariableAxisNode>()) {
139- node->length = sv_axis->nnz ();
140- }
141-
132+ node->length = group[index]->nnz ();
142133 node->is_derived_axis = true ;
143134 node->group = std::move (group);
144135 node->index = index;
@@ -183,7 +174,7 @@ TVM_REGISTER_NODE_TYPE(SparseFixedAxisNode);
183174
184175TVM_REGISTER_GLOBAL (" tir.sparse.SparseFixedAxis" )
185176 .set_body_typed([](String name, PrimExpr length, Buffer indices, PrimExpr nnz_cols) {
186- return SparseFixedAxis (name, length, indices, nnz_cols);
177+ return SparseFixedAxis (std::move ( name), std::move ( length), std::move ( indices), std::move ( nnz_cols) );
187178 });
188179
189180TVM_STATIC_IR_FUNCTOR (ReprPrinter, vtable)
@@ -210,7 +201,7 @@ TVM_REGISTER_NODE_TYPE(SparseVariableAxisNode);
210201
211202TVM_REGISTER_GLOBAL (" tir.sparse.SparseVariableAxis" )
212203 .set_body_typed([](String name, PrimExpr length, Buffer indptr, Buffer indices) {
213- return SparseVariableAxis (name, length, indptr, indices);
204+ return SparseVariableAxis (std::move ( name), std::move ( length), std::move ( indptr), std::move ( indices) );
214205 });
215206
216207TVM_STATIC_IR_FUNCTOR (ReprPrinter, vtable)
@@ -259,7 +250,7 @@ TVM_REGISTER_NODE_TYPE(AxisTreeNode);
259250
260251TVM_REGISTER_GLOBAL (" tir.sparse.AxisTree" )
261252 .set_body_typed([](Array<String> axis_names, Array<Optional<String>> axis_parent_names) {
262- return AxisTree (axis_names, axis_parent_names);
253+ return AxisTree (std::move ( axis_names), std::move ( axis_parent_names) );
263254 });
264255
265256/* ******* SparseBuffer ********/
@@ -279,7 +270,7 @@ TVM_REGISTER_NODE_TYPE(SparseBufferNode);
279270
280271TVM_REGISTER_GLOBAL (" tir.sparse.SparseBuffer" )
281272 .set_body_typed([](Array<Axis> axes, Buffer data, String name) {
282- return SparseBuffer (axes, data, name);
273+ return SparseBuffer (std::move ( axes), std::move ( data), std::move ( name) );
283274 });
284275
285276TVM_STATIC_IR_FUNCTOR (ReprPrinter, vtable)
@@ -338,7 +329,7 @@ TVM_REGISTER_NODE_TYPE(SpIterVarNode);
338329
339330TVM_REGISTER_GLOBAL (" tir.sparse.SpIterVar" )
340331 .set_body_typed([](Var var, PrimExpr max_extent, bool is_reduction, Axis axis) {
341- return SpIterVar (var, max_extent, is_reduction, axis);
332+ return SpIterVar (std::move ( var), std::move ( max_extent) , is_reduction, std::move ( axis) );
342333 });
343334
344335TVM_STATIC_IR_FUNCTOR (ReprPrinter, vtable)
0 commit comments