Skip to content

Commit 94dfbe9

Browse files
MasterJH5574yzh119
authored andcommitted
[BugFix] Add field is_reduction for SpIterVar (#9)
* [BugFix] Add field `is_reduction` for SpIterVar * Formatting
1 parent 06ce7e2 commit 94dfbe9

File tree

3 files changed

+32
-34
lines changed

3 files changed

+32
-34
lines changed

include/tvm/tir/sparse.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ class SparseAxis : public Axis {
187187
*/
188188
class SparseFixedAxisNode : public SparseAxisNode {
189189
public:
190-
Buffer indices;
190+
Buffer indices;
191191
/* fixed number of columns of current sparse axis. */
192192
PrimExpr num_cols;
193193

@@ -267,7 +267,6 @@ class SparseVariableAxis : public SparseAxis {
267267
TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis, SparseVariableAxisNode);
268268
};
269269

270-
271270
/*!
272271
* \brief Axis Dependency Tree.
273272
*/
@@ -314,9 +313,7 @@ class SparseBufferNode : public Object {
314313
/* Data type */
315314
runtime::DataType dtype;
316315

317-
inline int ndim() const {
318-
return static_cast<int>(axes.size());
319-
}
316+
inline int ndim() const { return static_cast<int>(axes.size()); }
320317

321318
void VisitAttrs(AttrVisitor* v) {
322319
v->Visit("name", &tree);
@@ -370,24 +367,28 @@ class SpIterVarNode : public Object {
370367
Var var;
371368
PrimExpr max_extent;
372369
SpIterKind kind;
370+
bool is_reduction;
373371
Optional<Axis> axis;
374372

375373
void VisitAttrs(AttrVisitor* v) {
376374
v->Visit("var", &var);
377375
v->Visit("max_extent", &max_extent);
378376
v->Visit("axis", &axis);
377+
v->Visit("is_reduction", &is_reduction);
379378
v->Visit("kind", &kind);
380379
}
381380

382381
bool SEqualReduce(const SpIterVarNode* other, SEqualReducer equal) const {
383382
return equal(var, other->var) && equal(max_extent, other->max_extent) &&
384-
equal(axis, other->axis) && equal(kind, other->kind);
383+
equal(axis, other->axis) && equal(is_reduction, other->is_reduction) &&
384+
equal(kind, other->kind);
385385
}
386386

387387
void SHashReduce(SHashReducer hash_reduce) const {
388388
hash_reduce(var);
389389
hash_reduce(max_extent);
390390
hash_reduce(axis);
391+
hash_reduce(is_reduction);
391392
hash_reduce(kind);
392393
}
393394

@@ -399,7 +400,7 @@ class SpIterVarNode : public Object {
399400

400401
class SpIterVar : public ObjectRef {
401402
public:
402-
TVM_DLL explicit SpIterVar(String name, PrimExpr max_extent, SpIterKind kind,
403+
TVM_DLL explicit SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, bool is_reduction,
403404
Optional<Axis> axis = NullOpt);
404405

405406
/*!

python/tvm/tir/sparse.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,9 @@ class SpIterVar(Object):
214214
215215
kind : int
216216
The kind of the SpIterVar
217+
218+
is_reduction : bool
219+
Whether the SpIterVar is a reduction iterator
217220
218221
axis : Optional[Axis]
219222
The axis over which the SpIterVar iterates. Required to be defined
@@ -222,6 +225,7 @@ class SpIterVar(Object):
222225
var: Var
223226
max_extent: PrimExpr
224227
kind: int
228+
is_reduction: bool
225229
axis: Optional[Axis]
226230

227231
DenseFixed = 0
@@ -231,6 +235,6 @@ class SpIterVar(Object):
231235

232236
def __init__(self, var, max_extent, kind, axis=None):
233237
self.__init_handle_by_constructor__(
234-
_ffi_api.SpIterVar, var, max_extent, kind, axis # type: ignore
238+
_ffi_api.SpIterVar, var, max_extent, kind, is_reduction, axis # type: ignore
235239
)
236240

src/tir/ir/sparse.cc

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,12 @@ DenseFixedAxis::DenseFixedAxis(String name, PrimExpr length) {
3838

3939
TVM_REGISTER_NODE_TYPE(DenseFixedAxisNode);
4040

41-
TVM_REGISTER_GLOBAL("tir.sparse.DenseFixedAxis")
42-
.set_body_typed([](String name, PrimExpr length) {
43-
return DenseFixedAxis(name, length);
44-
});
41+
TVM_REGISTER_GLOBAL("tir.sparse.DenseFixedAxis").set_body_typed([](String name, PrimExpr length) {
42+
return DenseFixedAxis(name, length);
43+
});
4544

4645
// DenseVariableAxis
47-
DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length,
48-
Buffer indptr) {
46+
DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length, Buffer indptr) {
4947
ObjectPtr<DenseVariableAxisNode> node = make_object<DenseVariableAxisNode>();
5048
node->name = std::move(name);
5149
node->length = std::move(length);
@@ -61,8 +59,7 @@ TVM_REGISTER_GLOBAL("tir.sparse.DenseVariableAxis")
6159
});
6260

6361
// SparseFixedAxis
64-
SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices,
65-
PrimExpr num_cols) {
62+
SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr num_cols) {
6663
ObjectPtr<SparseFixedAxisNode> node = make_object<SparseFixedAxisNode>();
6764
node->name = std::move(name);
6865
node->length = std::move(length);
@@ -74,16 +71,14 @@ SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices,
7471
TVM_REGISTER_NODE_TYPE(SparseFixedAxisNode);
7572

7673
TVM_REGISTER_GLOBAL("tir.sparse.SparseFixedAxis")
77-
.set_body_typed([](String name, PrimExpr length, Buffer indices,
78-
PrimExpr num_cols) {
74+
.set_body_typed([](String name, PrimExpr length, Buffer indices, PrimExpr num_cols) {
7975
return SparseFixedAxis(name, length, indices, num_cols);
8076
});
8177

8278
// SparseVariableAxis
83-
SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length,
84-
Buffer indptr, Buffer indices) {
85-
ObjectPtr<SparseVariableAxisNode> node =
86-
make_object<SparseVariableAxisNode>();
79+
SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length, Buffer indptr,
80+
Buffer indices) {
81+
ObjectPtr<SparseVariableAxisNode> node = make_object<SparseVariableAxisNode>();
8782
node->name = std::move(name);
8883
node->length = std::move(length);
8984
node->indptr = std::move(indptr);
@@ -94,14 +89,12 @@ SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length,
9489
TVM_REGISTER_NODE_TYPE(SparseVariableAxisNode);
9590

9691
TVM_REGISTER_GLOBAL("tir.sparse.SparseVariableAxis")
97-
.set_body_typed([](String name, PrimExpr length, Buffer indptr,
98-
Buffer indices) {
92+
.set_body_typed([](String name, PrimExpr length, Buffer indptr, Buffer indices) {
9993
return SparseVariableAxis(name, length, indptr, indices);
10094
});
10195

10296
// AxisTree
103-
AxisTree::AxisTree(Array<Axis> axes,
104-
Array<Optional<String>> axis_parent_names) {
97+
AxisTree::AxisTree(Array<Axis> axes, Array<Optional<String>> axis_parent_names) {
10598
CHECK_EQ(axes.size(), axis_parent_names.size())
10699
<< "ValueError: The axes array should have the same length as axis_parent_names "
107100
"array.";
@@ -121,9 +114,7 @@ AxisTree::AxisTree(Array<Axis> axes,
121114
CHECK(node->axis_map.find(parent_name.value()) != node->axis_map.end())
122115
<< "ValueError: Parent axis name doesn't exist.";
123116
}
124-
Axis parent_axis = (parent_name.get() != nullptr)
125-
? node->axis_map[parent_name.value()]
126-
: root;
117+
Axis parent_axis = (parent_name.get() != nullptr) ? node->axis_map[parent_name.value()] : root;
127118
node->parent[axis] = parent_axis;
128119
if (node->children.find(parent_axis) != node->children.end()) {
129120
node->children[parent_axis].push_back(axis);
@@ -139,8 +130,7 @@ AxisTree::AxisTree(Array<Axis> axes,
139130
TVM_REGISTER_NODE_TYPE(AxisTreeNode);
140131

141132
TVM_REGISTER_GLOBAL("tir.sparse.AxisTree")
142-
.set_body_typed([](Array<Axis> axes,
143-
Array<Optional<String>> axis_parent_names) {
133+
.set_body_typed([](Array<Axis> axes, Array<Optional<String>> axis_parent_names) {
144134
return AxisTree(axes, axis_parent_names);
145135
});
146136

@@ -164,7 +154,8 @@ TVM_REGISTER_GLOBAL("tir.sparse.SparseBuffer")
164154
});
165155

166156
// SpIterVar
167-
SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, Optional<Axis> axis) {
157+
SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, bool is_reduction,
158+
Optional<Axis> axis) {
168159
ObjectPtr<SpIterVarNode> node = make_object<SpIterVarNode>();
169160

170161
if (kind != SpIterKind::kDenseFixed) {
@@ -175,15 +166,17 @@ SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, Optional
175166
node->var = Var(std::move(name));
176167
node->max_extent = std::move(max_extent);
177168
node->kind = kind;
169+
node->is_reduction = is_reduction;
178170
node->axis = std::move(axis);
179171
data_ = std::move(node);
180172
}
181173

182174
TVM_REGISTER_NODE_TYPE(SpIterVarNode);
183175

184176
TVM_REGISTER_GLOBAL("tir.sparse.SpIterVar")
185-
.set_body_typed([](String name, PrimExpr max_extent, SpIterKind kind, Optional<Axis> axis) {
186-
return SpIterVar(name, max_extent, kind, axis);
177+
.set_body_typed([](String name, PrimExpr max_extent, SpIterKind kind, bool is_reduction,
178+
Optional<Axis> axis) {
179+
return SpIterVar(name, max_extent, kind, is_reduction, axis);
187180
});
188181

189182
} // namespace tir

0 commit comments

Comments
 (0)