Skip to content

Commit e9a1771

Browse files
yzh119MasterJH5574
andcommitted
Frontend update, demo scripts. (#10)
* Format and Buffer data structure (#1) * [SparseTIR] Constructors and Python Interface for `Axis` and `SparseBuffer` (#2) * add methods for Object * axis constructors * methods for SparseBuffer * put into registry * python interface * [CherryPick][Intrinsic] lower_bound and upper_bound for binary search in Sparse TIR. (apache#483) (#4) * upd * upd * fix * upd * upd * upd * upd * upd * fix * upd * upd * upd * upd * upd * upd * upd * codegen-rule * upd * upd * test * upd * fix * two arguments Co-authored-by: Zihao Ye <expye@outlook.com> * Fix AxisTree (#3) * fix axis tree * upd * Format and Buffer data structure (#1) * [SparseTIR] Constructors and Python Interface for `Axis` and `SparseBuffer` (#2) * add methods for Object * axis constructors * methods for SparseBuffer * put into registry * python interface * fix axis tree * upd * Format and Buffer data structure (#1) * [SparseTIR] Constructors and Python Interface for `Axis` and `SparseBuffer` (#2) * add methods for Object * axis constructors * methods for SparseBuffer * put into registry * python interface * [CherryPick][Intrinsic] lower_bound and upper_bound for binary search in Sparse TIR. (apache#483) (#4) * upd * upd * fix * upd * upd * upd * upd * upd * fix * upd * upd * upd * upd * upd * upd * upd * codegen-rule * upd * upd * test * upd * fix * two arguments Co-authored-by: Zihao Ye <expye@outlook.com> * Fix AxisTree (#3) * fix axis tree * upd * [SparseTIR] Add SparseBufferLoad/SparseBufferStore (#5) * Add dtype for SparseBuffer * Add name for SparseBuffer. Remove `ndim` * Remove namespace sparse * Add SparseBufferLoad/Store * Add method `ndim()` * Format and Buffer data structure (#1) * [SparseTIR] Constructors and Python Interface for `Axis` and `SparseBuffer` (#2) * add methods for Object * axis constructors * methods for SparseBuffer * put into registry * python interface * [CherryPick][Intrinsic] lower_bound and upper_bound for binary search in Sparse TIR. (apache#483) (#4) * upd * upd * fix * upd * upd * upd * upd * upd * fix * upd * upd * upd * upd * upd * upd * upd * codegen-rule * upd * upd * test * upd * fix * two arguments Co-authored-by: Zihao Ye <expye@outlook.com> * Fix AxisTree (#3) * fix axis tree * upd * [SparseTIR] Add SparseBufferLoad/SparseBufferStore (#5) * Add dtype for SparseBuffer * Add name for SparseBuffer. Remove `ndim` * Remove namespace sparse * Add SparseBufferLoad/Store * Add method `ndim()` * [SparseTIR] Introduce SpIterVar (#6) * [SparseTIR] Introduce SpIterVar * Add conversion to PrimExpr * [BugFix] Fix binary search & SpIterVar (#7) * [BugFix] Add field `is_reduction` for SpIterVar (#9) * [BugFix] Add field `is_reduction` for SpIterVar * Formatting * upd * upd Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
1 parent affb1af commit e9a1771

File tree

8 files changed

+451
-99
lines changed

8 files changed

+451
-99
lines changed

include/tvm/tir/sparse.h

Lines changed: 39 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ class AxisNode : public Object {
4444
* the current axis. */
4545
PrimExpr length;
4646

47+
String GetName() const { return name; }
48+
PrimExpr GetLength() const { return length; }
49+
DataType GetIndexType() const { return length->dtype; }
50+
4751
static constexpr const char* _type_key = "tir.sparse.Axis";
4852
static constexpr const bool _type_has_method_sequal_reduce = true;
4953
static constexpr const bool _type_has_method_shash_reduce = true;
@@ -139,8 +143,10 @@ class DenseVariableAxisNode : public DenseAxisNode {
139143
v->Visit("indptr", &indptr);
140144
}
141145

142-
bool SEqualReduce(const DenseVariableAxisNode* other, SEqualReducer equal) const {
143-
return equal(name, other->name) && equal(length, other->length) && equal(indptr, other->indptr);
146+
bool SEqualReduce(const DenseVariableAxisNode* other,
147+
SEqualReducer equal) const {
148+
return equal(name, other->name) && equal(length, other->length) &&
149+
equal(indptr, other->indptr);
144150
}
145151

146152
void SHashReduce(SHashReducer hash_reduce) const {
@@ -159,9 +165,11 @@ class DenseVariableAxisNode : public DenseAxisNode {
159165
*/
160166
class DenseVariableAxis : public DenseAxis {
161167
public:
162-
TVM_DLL explicit DenseVariableAxis(String name, PrimExpr length, Buffer indptr);
168+
TVM_DLL explicit DenseVariableAxis(String name, PrimExpr length,
169+
Buffer indptr);
163170

164-
TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis, DenseVariableAxisNode);
171+
TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis,
172+
DenseVariableAxisNode);
165173
};
166174

167175
/*!
@@ -198,7 +206,8 @@ class SparseFixedAxisNode : public SparseAxisNode {
198206
v->Visit("num_cols", &num_cols);
199207
}
200208

201-
bool SEqualReduce(const SparseFixedAxisNode* other, SEqualReducer equal) const {
209+
bool SEqualReduce(const SparseFixedAxisNode* other,
210+
SEqualReducer equal) const {
202211
return equal(name, other->name) && equal(length, other->length) &&
203212
equal(indices, other->indices) && equal(num_cols, other->num_cols);
204213
}
@@ -220,9 +229,11 @@ class SparseFixedAxisNode : public SparseAxisNode {
220229
*/
221230
class SparseFixedAxis : public SparseAxis {
222231
public:
223-
TVM_DLL explicit SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr num_cols);
232+
TVM_DLL explicit SparseFixedAxis(String name, PrimExpr length, Buffer indices,
233+
PrimExpr num_cols);
224234

225-
TVM_DEFINE_OBJECT_REF_METHODS(SparseFixedAxis, SparseAxis, SparseFixedAxisNode);
235+
TVM_DEFINE_OBJECT_REF_METHODS(SparseFixedAxis, SparseAxis,
236+
SparseFixedAxisNode);
226237
};
227238

228239
/*!
@@ -240,7 +251,8 @@ class SparseVariableAxisNode : public SparseAxisNode {
240251
v->Visit("indices", &indices);
241252
}
242253

243-
bool SEqualReduce(const SparseVariableAxisNode* other, SEqualReducer equal) const {
254+
bool SEqualReduce(const SparseVariableAxisNode* other,
255+
SEqualReducer equal) const {
244256
return equal(name, other->name) && equal(length, other->length) &&
245257
equal(indptr, other->indptr) && equal(indices, other->indices);
246258
}
@@ -262,24 +274,25 @@ class SparseVariableAxisNode : public SparseAxisNode {
262274
*/
263275
class SparseVariableAxis : public SparseAxis {
264276
public:
265-
TVM_DLL explicit SparseVariableAxis(String name, PrimExpr length, Buffer indptr, Buffer indices);
277+
TVM_DLL explicit SparseVariableAxis(String name, PrimExpr length,
278+
Buffer indptr, Buffer indices);
266279

267-
TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis, SparseVariableAxisNode);
280+
TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis,
281+
SparseVariableAxisNode);
268282
};
269283

270284
/*!
271285
* \brief Axis Dependency Tree.
272286
*/
273287
class AxisTreeNode : public Object {
274288
public:
275-
// mapping from names to axes.
276-
std::unordered_map<String, Axis> axis_map;
277289
// unordered map that stores the parent relationship between axes.
278-
std::unordered_map<Axis, Axis, ObjectPtrHash, ObjectPtrEqual> parent;
290+
std::unordered_map<String, Optional<String>, ObjectPtrHash, ObjectPtrEqual>
291+
parent;
279292
// unordered map that stores the children relationship between axes.
280-
std::unordered_map<Axis, Array<Axis>, ObjectPtrHash, ObjectPtrEqual> children;
281-
// The root axis.
282-
Axis root;
293+
std::unordered_map<Optional<String>, Array<String>, ObjectPtrHash,
294+
ObjectPtrEqual>
295+
children;
283296

284297
void VisitAttrs(AttrVisitor* v) {}
285298

@@ -293,7 +306,9 @@ class AxisTreeNode : public Object {
293306
*/
294307
class AxisTree : public ObjectRef {
295308
public:
296-
TVM_DLL AxisTree(Array<Axis> axes, Array<Optional<String>> axis_parent_names);
309+
TVM_DLL AxisTree(Array<String> axis_names,
310+
Array<Optional<String>> axis_parent_names);
311+
297312
TVM_DEFINE_OBJECT_REF_METHODS(AxisTree, ObjectRef, AxisTreeNode);
298313
};
299314

@@ -302,38 +317,30 @@ class AxisTree : public ObjectRef {
302317
*/
303318
class SparseBufferNode : public Object {
304319
public:
305-
/* Root of Axis Dependency Tree. */
306-
AxisTree tree;
307320
/* Axes */
308321
Array<Axis> axes;
309322
/* Buffer corresponding to flattened value */
310323
Buffer data;
311324
/* Buffer Name */
312325
String name;
313-
/* Data type */
314-
runtime::DataType dtype;
315326

316327
inline int ndim() const { return static_cast<int>(axes.size()); }
317328

318329
void VisitAttrs(AttrVisitor* v) {
319-
v->Visit("name", &tree);
320330
v->Visit("length", &axes);
321331
v->Visit("num_cols", &data);
322332
v->Visit("name", &name);
323-
v->Visit("dtype", &dtype);
324333
}
325334

326335
bool SEqualReduce(const SparseBufferNode* other, SEqualReducer equal) const {
327-
return equal(tree, other->tree) && equal(axes, other->axes) && equal(data, other->data) &&
328-
equal(name, other->name) && equal(dtype, other->dtype);
336+
return equal(axes, other->axes) && equal(data, other->data) &&
337+
equal(name, other->name);
329338
}
330339

331340
void SHashReduce(SHashReducer hash_reduce) const {
332-
hash_reduce(tree);
333341
hash_reduce(axes);
334342
hash_reduce(data);
335343
hash_reduce(name);
336-
hash_reduce(dtype);
337344
}
338345

339346
static constexpr const char* _type_key = "tir.sparse.SparseBuffer";
@@ -346,8 +353,7 @@ class SparseBufferNode : public Object {
346353
*/
347354
class SparseBuffer : public ObjectRef {
348355
public:
349-
TVM_DLL explicit SparseBuffer(AxisTree tree, Array<Axis> axes, Buffer data, String name,
350-
DataType dtype);
356+
TVM_DLL explicit SparseBuffer(Array<Axis> axes, Buffer data, String name);
351357

352358
TVM_DEFINE_OBJECT_REF_METHODS(SparseBuffer, ObjectRef, SparseBufferNode);
353359
};
@@ -380,8 +386,8 @@ class SpIterVarNode : public Object {
380386

381387
bool SEqualReduce(const SpIterVarNode* other, SEqualReducer equal) const {
382388
return equal(var, other->var) && equal(max_extent, other->max_extent) &&
383-
equal(axis, other->axis) && equal(is_reduction, other->is_reduction) &&
384-
equal(kind, other->kind);
389+
equal(axis, other->axis) &&
390+
equal(is_reduction, other->is_reduction) && equal(kind, other->kind);
385391
}
386392

387393
void SHashReduce(SHashReducer hash_reduce) const {
@@ -400,8 +406,8 @@ class SpIterVarNode : public Object {
400406

401407
class SpIterVar : public ObjectRef {
402408
public:
403-
TVM_DLL explicit SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, bool is_reduction,
404-
Optional<Axis> axis = NullOpt);
409+
TVM_DLL explicit SpIterVar(String name, PrimExpr max_extent, SpIterKind kind,
410+
bool is_reduction, Optional<Axis> axis = NullOpt);
405411

406412
/*!
407413
* \return the corresponding var in the IterVar.

include/tvm/tir/stmt.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,28 @@ class BufferStore : public Stmt {
327327
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode);
328328
};
329329

330+
/*!
331+
* \brief Sparse Block node.
332+
*/
333+
class SparseBlockNode : public StmtNode {
334+
public:
335+
/*! \brief The sparse iteration variables of the block. */
336+
Array<SpIterVar> sp_iter_vars;
337+
/*! \brief The sparse buffers defined in the block. */
338+
Array<SparseBuffer> sp_buffers;
339+
/*! \brief The body of the block */
340+
Stmt body;
341+
342+
static constexpr const char* _type_key = "tir.SparseBlock";
343+
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBlockNode, StmtNode);
344+
};
345+
346+
class SparseBlock : public Stmt {
347+
public:
348+
TVM_DEFINE_OBJECT_REF_METHODS(SparseBlock, Stmt, SparseBlockNode);
349+
};
350+
351+
330352
/*!
331353
* \brief Store value to the high dimension sparse buffer.
332354
*

python/tvm/script/context_maintainer.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@
2323
import tvm
2424
from tvm.ir import Span
2525
from tvm.ir.expr import Range
26+
from tvm.script.tir.sparse import MatchSparseBuffer
2627
from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
2728
from tvm.runtime import Object
2829
from tvm.tir.expr import IterVar
30+
from tvm.tir.sparse import Axis, SparseBuffer
2931
from .tir.node import BufferSlice
3032

3133

@@ -74,6 +76,10 @@ def example_func(a: T.handle, b: T.handle, c: T.handle) -> None:
7476
"""List[Buffer]: list of T.alloc_buffer statements in the block signature"""
7577
match_buffers: List[MatchBufferRegion] = []
7678
"""List[MatchBufferRegion]: list of T.match_buffer statements in the block signature"""
79+
axes: List[Axis] = []
80+
"""List[Axis]: list of sparse axis created in the block signature."""
81+
match_sparse_buffers: List[MatchSparseBuffer]
82+
"""List[MatchSparseBuffer]: list of T.match_sparse_buffer statements in the block signature."""
7783
iter_values: List[PrimExpr] = []
7884
"""List[PrimExpr]: list of binding values for iter vars"""
7985
iter_vars: List[IterVar] = []
@@ -119,14 +125,16 @@ class ContextMaintainer:
119125
"""List[BlockInfo]: The block info for the current block scope"""
120126
loop_stack: Dict[Var, Range] = {}
121127
"""Dict[Var, Range]: The dict from loop var to its domain outside the block"""
122-
symbols: List[Dict[str, Union[Var, Buffer]]] = []
128+
symbols: List[Dict[str, Union[Var, Buffer, SparseBuffer, Axis]]] = []
123129
"""List[Dict[str, Union[Var, Buffer]]]: Symbol map from name to object for the current scope"""
124130

125131
# function context
126132
func_params: List[Var] = []
127133
"""List[Var]: The function parameters"""
128134
func_buffer_map: Mapping[Var, Buffer] = {}
129135
"""Mapping[Var, Buffer]: The function buffer map"""
136+
func_sparse_buffer_map: Mapping[Var, SparseBuffer] = {}
137+
"""Mapping[Var, SparseBuffer]: The function sparse buffer map"""
130138
func_dict_attr: Mapping[str, Object] = {}
131139
"""Mapping[str, Object]: The function attrs"""
132140
func_var_env_dict: Mapping[Var, str] = {}
@@ -151,6 +159,7 @@ def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], No
151159
# function context
152160
self.func_params = []
153161
self.func_buffer_map = {}
162+
self.func_sparse_buffer_map = {}
154163
self.func_dict_attr = {}
155164
self.func_var_env_dict = {}
156165
# parser and analyzer
@@ -208,9 +217,9 @@ def exit_block_scope(self):
208217
# Pop block_info
209218
self.block_info_stack.pop()
210219

211-
def update_symbol(self, name: str, symbol: Union[Buffer, Var], node: synr.ast.Node):
220+
def update_symbol(self, name: str, symbol: Union[Buffer, Var, SparseBuffer, Axis], node: synr.ast.Node):
212221
"""Append a symbol into current scope"""
213-
if isinstance(symbol, Buffer):
222+
if isinstance(symbol, (Buffer, Var, SparseBuffer, Axis)):
214223
if name in self.symbols[0]:
215224
self.report_error("Duplicate Buffer name: " + symbol.name, node.span)
216225
self.symbols[0][name] = symbol
@@ -225,7 +234,7 @@ def remove_symbol(self, name: str):
225234
return
226235
raise RuntimeError("Internal error of tvm script parser: no symbol named " + name)
227236

228-
def lookup_symbol(self, name: str) -> Optional[Union[Buffer, Var]]:
237+
def lookup_symbol(self, name: str) -> Optional[Union[Buffer, Var, SparseBuffer, Axis]]:
229238
"""Look up symbol by name"""
230239
for symbols in reversed(self.symbols):
231240
if name in symbols:

0 commit comments

Comments
 (0)