Skip to content

Commit 9544e17

Browse files
committed
Fatal bugfix and change the signature of DenseVariableAxis. (#33)
1 parent e4ed6ad commit 9544e17

File tree

4 files changed

+26
-28
lines changed

4 files changed

+26
-28
lines changed

include/tvm/tir/sparse.h

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ class AxisNode : public Object {
7575
DataType GetIndexType() const { return length->dtype; }
7676

7777
virtual AxisKind kind() const = 0;
78+
virtual PrimExpr nnz() const = 0;
7879

7980
static constexpr const char* _type_key = "tir.sparse.Axis";
8081
static constexpr const bool _type_has_method_sequal_reduce = true;
@@ -134,6 +135,8 @@ class DenseFixedAxisNode : public DenseAxisNode {
134135
public:
135136
AxisKind kind() const final { return AxisKind::kDenseFixed; }
136137

138+
PrimExpr nnz() const final { return length; }
139+
137140
static constexpr const char* _type_key = "tir.sparse.DenseFixedAxis";
138141
TVM_DECLARE_BASE_OBJECT_INFO(DenseFixedAxisNode, DenseAxisNode);
139142
};
@@ -234,6 +237,7 @@ class FusedAxis : public DenseFixedAxis {
234237
class DenseVariableAxisNode : public DenseAxisNode {
235238
public:
236239
Buffer indptr;
240+
PrimExpr nnz_;
237241

238242
void VisitAttrs(AttrVisitor* v) {
239243
DenseAxisNode::VisitAttrs(v);
@@ -249,10 +253,10 @@ class DenseVariableAxisNode : public DenseAxisNode {
249253
hash_reduce(indptr);
250254
}
251255

252-
PrimExpr nnz() const { return indptr->shape[0]; }
253-
254256
AxisKind kind() const final { return AxisKind::kDenseVariable; }
255257

258+
PrimExpr nnz() const final { return nnz_; }
259+
256260
static constexpr const char* _type_key = "tir.sparse.DenseVariableAxis";
257261
TVM_DECLARE_FINAL_OBJECT_INFO(DenseVariableAxisNode, DenseAxisNode);
258262
};
@@ -263,7 +267,7 @@ class DenseVariableAxisNode : public DenseAxisNode {
263267
*/
264268
class DenseVariableAxis : public DenseAxis {
265269
public:
266-
TVM_DLL explicit DenseVariableAxis(String name, PrimExpr length, Buffer indptr);
270+
TVM_DLL explicit DenseVariableAxis(String name, PrimExpr length, PrimExpr nnz, Buffer indptr);
267271

268272
TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis, DenseVariableAxisNode);
269273
};
@@ -289,11 +293,13 @@ class SparseFixedAxisNode : public SparseAxisNode {
289293
}
290294

291295
void SHashReduce(SHashReducer hash_reduce) const {
292-
SparseFixedAxisNode::SHashReduce(hash_reduce);
296+
SparseAxisNode::SHashReduce(hash_reduce);
293297
hash_reduce(indices);
294298
hash_reduce(nnz_cols);
295299
}
296300

301+
PrimExpr nnz() const { return indices->shape[0]; }
302+
297303
AxisKind kind() const final { return AxisKind::kSparseFixed; }
298304

299305
static constexpr const char* _type_key = "tir.sparse.SparseFixedAxis";
@@ -336,7 +342,7 @@ class SparseVariableAxisNode : public SparseAxisNode {
336342
hash_reduce(indices);
337343
}
338344

339-
PrimExpr nnz() const { return indptr->shape[0]; }
345+
PrimExpr nnz() const { return indices->shape[0]; }
340346

341347
AxisKind kind() const final { return AxisKind::kSparseVariable; }
342348

python/tvm/script/tir/special_stmt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -931,11 +931,11 @@ def dense_variable(
931931
f"`dense_variable` expected assign to only one var, but got {names}", span
932932
)
933933

934-
length, indptr_len = shape
934+
length, indptr_len, nnz = shape
935935
indptr_buf = tvm.tir.decl_buffer(
936936
(indptr_len,), dtype=idtype, name=names[0] + "_indptr", span=span
937937
)
938-
axis = DenseVariableAxis(names[0], length, indptr_buf)
938+
axis = DenseVariableAxis(names[0], length, nnz, indptr_buf)
939939
self.context.sp_struct.append(axis)
940940
self.context.sp_struct_params.append([indptr_var])
941941
self.context.update_symbol(names[0], axis, self.node)

python/tvm/tir/sparse.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,11 +127,12 @@ class DenseVariableAxis(DenseAxis):
127127

128128
name: str
129129
length: PrimExpr
130+
nnz: PrimExpr
130131
indptr: Buffer
131132

132-
def __init__(self, name, length, indptr):
133+
def __init__(self, name, length, nnz, indptr):
133134
self.__init_handle_by_constructor__(
134-
_ffi_api.DenseVariableAxis, name, length, indptr # type: ignore
135+
_ffi_api.DenseVariableAxis, name, length, nnz, indptr # type: ignore
135136
)
136137

137138

src/tir/ir/sparse.cc

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -68,19 +68,20 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
6868
/******** DenseVariableAxis ********/
6969

7070
/*! \brief Default constuctor of DenseVariableAxis */
71-
DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length, Buffer indptr) {
71+
DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length, PrimExpr nnz, Buffer indptr) {
7272
ObjectPtr<DenseVariableAxisNode> node = make_object<DenseVariableAxisNode>();
7373
node->name = std::move(name);
7474
node->length = std::move(length);
75+
node->nnz_ = std::move(nnz);
7576
node->indptr = std::move(indptr);
7677
data_ = std::move(node);
7778
}
7879

7980
TVM_REGISTER_NODE_TYPE(DenseVariableAxisNode);
8081

8182
TVM_REGISTER_GLOBAL("tir.sparse.DenseVariableAxis")
82-
.set_body_typed([](String name, PrimExpr length, Buffer indptr) {
83-
return DenseVariableAxis(name, length, indptr);
83+
.set_body_typed([](String name, PrimExpr length, PrimExpr nnz, Buffer indptr) {
84+
return DenseVariableAxis(std::move(name), std::move(length), std::move(nnz), std::move(indptr));
8485
});
8586

8687
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
@@ -128,17 +129,7 @@ FusedAxis::FusedAxis(Array<Axis> group, int index) {
128129
fused_name += group[i]->name;
129130
}
130131
node->name = "fused_" + fused_name + "_" + group[index]->name;
131-
132-
if (const auto* df_axis = group[index].as<DenseFixedAxisNode>()) {
133-
node->length = df_axis->length;
134-
} else if (const auto* sf_axis = group[index].as<SparseFixedAxisNode>()) {
135-
// TODO(zihao): accumulate previous dimensions.
136-
} else if (const auto* dv_axis = group[index].as<DenseVariableAxisNode>()) {
137-
node->length = dv_axis->nnz();
138-
} else if (const auto* sv_axis = group[index].as<SparseVariableAxisNode>()) {
139-
node->length = sv_axis->nnz();
140-
}
141-
132+
node->length = group[index]->nnz();
142133
node->is_derived_axis = true;
143134
node->group = std::move(group);
144135
node->index = index;
@@ -183,7 +174,7 @@ TVM_REGISTER_NODE_TYPE(SparseFixedAxisNode);
183174

184175
TVM_REGISTER_GLOBAL("tir.sparse.SparseFixedAxis")
185176
.set_body_typed([](String name, PrimExpr length, Buffer indices, PrimExpr nnz_cols) {
186-
return SparseFixedAxis(name, length, indices, nnz_cols);
177+
return SparseFixedAxis(std::move(name), std::move(length), std::move(indices), std::move(nnz_cols));
187178
});
188179

189180
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
@@ -210,7 +201,7 @@ TVM_REGISTER_NODE_TYPE(SparseVariableAxisNode);
210201

211202
TVM_REGISTER_GLOBAL("tir.sparse.SparseVariableAxis")
212203
.set_body_typed([](String name, PrimExpr length, Buffer indptr, Buffer indices) {
213-
return SparseVariableAxis(name, length, indptr, indices);
204+
return SparseVariableAxis(std::move(name), std::move(length), std::move(indptr), std::move(indices));
214205
});
215206

216207
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
@@ -259,7 +250,7 @@ TVM_REGISTER_NODE_TYPE(AxisTreeNode);
259250

260251
TVM_REGISTER_GLOBAL("tir.sparse.AxisTree")
261252
.set_body_typed([](Array<String> axis_names, Array<Optional<String>> axis_parent_names) {
262-
return AxisTree(axis_names, axis_parent_names);
253+
return AxisTree(std::move(axis_names), std::move(axis_parent_names));
263254
});
264255

265256
/******** SparseBuffer ********/
@@ -279,7 +270,7 @@ TVM_REGISTER_NODE_TYPE(SparseBufferNode);
279270

280271
TVM_REGISTER_GLOBAL("tir.sparse.SparseBuffer")
281272
.set_body_typed([](Array<Axis> axes, Buffer data, String name) {
282-
return SparseBuffer(axes, data, name);
273+
return SparseBuffer(std::move(axes), std::move(data), std::move(name));
283274
});
284275

285276
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
@@ -338,7 +329,7 @@ TVM_REGISTER_NODE_TYPE(SpIterVarNode);
338329

339330
TVM_REGISTER_GLOBAL("tir.sparse.SpIterVar")
340331
.set_body_typed([](Var var, PrimExpr max_extent, bool is_reduction, Axis axis) {
341-
return SpIterVar(var, max_extent, is_reduction, axis);
332+
return SpIterVar(std::move(var), std::move(max_extent), is_reduction, std::move(axis));
342333
});
343334

344335
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)

0 commit comments

Comments
 (0)