Skip to content

Commit e4e093e

Browse files
MasterJH5574yzh119
authored andcommitted
[BugFix][SparseTIR] TVMScript Parser for Axis & SpIterVar (#12)
* Update `cord` and `pos` * Fix `idtype` * Formatting.. * Bug fix 1 * Move new special stmts * Parser for Axis and SpIterVar * Fix context_maintainer.py
1 parent 227c6b4 commit e4e093e

File tree

7 files changed

+218
-282
lines changed

7 files changed

+218
-282
lines changed

include/tvm/tir/sparse.h

Lines changed: 18 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,8 @@ class DenseVariableAxisNode : public DenseAxisNode {
143143
v->Visit("indptr", &indptr);
144144
}
145145

146-
bool SEqualReduce(const DenseVariableAxisNode* other,
147-
SEqualReducer equal) const {
148-
return equal(name, other->name) && equal(length, other->length) &&
149-
equal(indptr, other->indptr);
146+
bool SEqualReduce(const DenseVariableAxisNode* other, SEqualReducer equal) const {
147+
return equal(name, other->name) && equal(length, other->length) && equal(indptr, other->indptr);
150148
}
151149

152150
void SHashReduce(SHashReducer hash_reduce) const {
@@ -165,11 +163,9 @@ class DenseVariableAxisNode : public DenseAxisNode {
165163
*/
166164
class DenseVariableAxis : public DenseAxis {
167165
public:
168-
TVM_DLL explicit DenseVariableAxis(String name, PrimExpr length,
169-
Buffer indptr);
166+
TVM_DLL explicit DenseVariableAxis(String name, PrimExpr length, Buffer indptr);
170167

171-
TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis,
172-
DenseVariableAxisNode);
168+
TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis, DenseVariableAxisNode);
173169
};
174170

175171
/*!
@@ -206,8 +202,7 @@ class SparseFixedAxisNode : public SparseAxisNode {
206202
v->Visit("num_cols", &num_cols);
207203
}
208204

209-
bool SEqualReduce(const SparseFixedAxisNode* other,
210-
SEqualReducer equal) const {
205+
bool SEqualReduce(const SparseFixedAxisNode* other, SEqualReducer equal) const {
211206
return equal(name, other->name) && equal(length, other->length) &&
212207
equal(indices, other->indices) && equal(num_cols, other->num_cols);
213208
}
@@ -229,11 +224,9 @@ class SparseFixedAxisNode : public SparseAxisNode {
229224
*/
230225
class SparseFixedAxis : public SparseAxis {
231226
public:
232-
TVM_DLL explicit SparseFixedAxis(String name, PrimExpr length, Buffer indices,
233-
PrimExpr num_cols);
227+
TVM_DLL explicit SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr num_cols);
234228

235-
TVM_DEFINE_OBJECT_REF_METHODS(SparseFixedAxis, SparseAxis,
236-
SparseFixedAxisNode);
229+
TVM_DEFINE_OBJECT_REF_METHODS(SparseFixedAxis, SparseAxis, SparseFixedAxisNode);
237230
};
238231

239232
/*!
@@ -251,8 +244,7 @@ class SparseVariableAxisNode : public SparseAxisNode {
251244
v->Visit("indices", &indices);
252245
}
253246

254-
bool SEqualReduce(const SparseVariableAxisNode* other,
255-
SEqualReducer equal) const {
247+
bool SEqualReduce(const SparseVariableAxisNode* other, SEqualReducer equal) const {
256248
return equal(name, other->name) && equal(length, other->length) &&
257249
equal(indptr, other->indptr) && equal(indices, other->indices);
258250
}
@@ -274,11 +266,9 @@ class SparseVariableAxisNode : public SparseAxisNode {
274266
*/
275267
class SparseVariableAxis : public SparseAxis {
276268
public:
277-
TVM_DLL explicit SparseVariableAxis(String name, PrimExpr length,
278-
Buffer indptr, Buffer indices);
269+
TVM_DLL explicit SparseVariableAxis(String name, PrimExpr length, Buffer indptr, Buffer indices);
279270

280-
TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis,
281-
SparseVariableAxisNode);
271+
TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis, SparseVariableAxisNode);
282272
};
283273

284274
/*!
@@ -287,12 +277,9 @@ class SparseVariableAxis : public SparseAxis {
287277
class AxisTreeNode : public Object {
288278
public:
289279
// unordered map that stores the parent relationship between axes.
290-
std::unordered_map<String, Optional<String>, ObjectPtrHash, ObjectPtrEqual>
291-
parent;
280+
std::unordered_map<String, Optional<String>, ObjectPtrHash, ObjectPtrEqual> parent;
292281
// unordered map that stores the children relationship between axes.
293-
std::unordered_map<Optional<String>, Array<String>, ObjectPtrHash,
294-
ObjectPtrEqual>
295-
children;
282+
std::unordered_map<Optional<String>, Array<String>, ObjectPtrHash, ObjectPtrEqual> children;
296283

297284
void VisitAttrs(AttrVisitor* v) {}
298285

@@ -306,8 +293,7 @@ class AxisTreeNode : public Object {
306293
*/
307294
class AxisTree : public ObjectRef {
308295
public:
309-
TVM_DLL AxisTree(Array<String> axis_names,
310-
Array<Optional<String>> axis_parent_names);
296+
TVM_DLL AxisTree(Array<String> axis_names, Array<Optional<String>> axis_parent_names);
311297

312298
TVM_DEFINE_OBJECT_REF_METHODS(AxisTree, ObjectRef, AxisTreeNode);
313299
};
@@ -333,8 +319,7 @@ class SparseBufferNode : public Object {
333319
}
334320

335321
bool SEqualReduce(const SparseBufferNode* other, SEqualReducer equal) const {
336-
return equal(axes, other->axes) && equal(data, other->data) &&
337-
equal(name, other->name);
322+
return equal(axes, other->axes) && equal(data, other->data) && equal(name, other->name);
338323
}
339324

340325
void SHashReduce(SHashReducer hash_reduce) const {
@@ -386,8 +371,8 @@ class SpIterVarNode : public Object {
386371

387372
bool SEqualReduce(const SpIterVarNode* other, SEqualReducer equal) const {
388373
return equal(var, other->var) && equal(max_extent, other->max_extent) &&
389-
equal(axis, other->axis) &&
390-
equal(is_reduction, other->is_reduction) && equal(kind, other->kind);
374+
equal(axis, other->axis) && equal(is_reduction, other->is_reduction) &&
375+
equal(kind, other->kind);
391376
}
392377

393378
void SHashReduce(SHashReducer hash_reduce) const {
@@ -406,8 +391,8 @@ class SpIterVarNode : public Object {
406391

407392
class SpIterVar : public ObjectRef {
408393
public:
409-
TVM_DLL explicit SpIterVar(String name, PrimExpr max_extent, SpIterKind kind,
410-
bool is_reduction, Optional<Axis> axis = NullOpt);
394+
TVM_DLL explicit SpIterVar(Var var, PrimExpr max_extent, SpIterKind kind, bool is_reduction,
395+
Optional<Axis> axis = NullOpt);
411396

412397
/*!
413398
* \return the corresponding var in the IterVar.

python/tvm/script/context_maintainer.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import tvm
2424
from tvm.ir import Span
2525
from tvm.ir.expr import Range
26-
from tvm.script.tir.sparse import MatchSparseBuffer
2726
from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
2827
from tvm.runtime import Object
2928
from tvm.tir.expr import IterVar
@@ -76,10 +75,6 @@ def example_func(a: T.handle, b: T.handle, c: T.handle) -> None:
7675
"""List[Buffer]: list of T.alloc_buffer statements in the block signature"""
7776
match_buffers: List[MatchBufferRegion] = []
7877
"""List[MatchBufferRegion]: list of T.match_buffer statements in the block signature"""
79-
axes: List[Axis] = []
80-
"""List[Axis]: list of sparse axis created in the block signature."""
81-
match_sparse_buffers: List[MatchSparseBuffer]
82-
"""List[MatchSparseBuffer]: list of T.match_sparse_buffer statements in the block signature."""
8378
iter_values: List[PrimExpr] = []
8479
"""List[PrimExpr]: list of binding values for iter vars"""
8580
iter_vars: List[IterVar] = []
@@ -211,7 +206,9 @@ def exit_block_scope(self):
211206
# Pop block_info
212207
self.block_info_stack.pop()
213208

214-
def update_symbol(self, name: str, symbol: Union[Buffer, Var, SparseBuffer, Axis], node: synr.ast.Node):
209+
def update_symbol(
210+
self, name: str, symbol: Union[Buffer, Var, SparseBuffer, Axis], node: synr.ast.Node
211+
):
215212
"""Append a symbol into current scope"""
216213
if isinstance(symbol, (Buffer, Var, SparseBuffer, Axis)):
217214
if name in self.symbols[0]:

python/tvm/script/tir/sparse.py

Lines changed: 0 additions & 207 deletions
This file was deleted.

0 commit comments

Comments
 (0)