Skip to content

Commit 8620909

Browse files
authored
Syntax simplification (apache#34)
1 parent 37bdbfb commit 8620909

File tree

13 files changed

+279
-373
lines changed

13 files changed

+279
-373
lines changed

include/tvm/tir/sparse.h

Lines changed: 20 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ enum class AxisKind : int {
4040
kSparseVariable = 3
4141
};
4242

43+
class Axis;
44+
4345
/*!
4446
* \brief Base type for axis in sparse formats.
4547
*/
@@ -73,6 +75,7 @@ class AxisNode : public Object {
7375
String GetName() const { return name; }
7476
PrimExpr GetLength() const { return length; }
7577
DataType GetIndexType() const { return length->dtype; }
78+
virtual Optional<Axis> GetParentAxis() const = 0;
7679

7780
virtual AxisKind kind() const = 0;
7881
virtual PrimExpr nnz() const = 0;
@@ -137,6 +140,8 @@ class DenseFixedAxisNode : public DenseAxisNode {
137140

138141
PrimExpr nnz() const final { return length; }
139142

143+
Optional<Axis> GetParentAxis() const final { return NullOpt; }
144+
140145
static constexpr const char* _type_key = "tir.sparse.DenseFixedAxis";
141146
TVM_DECLARE_BASE_OBJECT_INFO(DenseFixedAxisNode, DenseAxisNode);
142147
};
@@ -238,6 +243,7 @@ class DenseVariableAxisNode : public DenseAxisNode {
238243
public:
239244
Buffer indptr;
240245
PrimExpr nnz_;
246+
Axis parent_;
241247

242248
void VisitAttrs(AttrVisitor* v) {
243249
DenseAxisNode::VisitAttrs(v);
@@ -257,6 +263,8 @@ class DenseVariableAxisNode : public DenseAxisNode {
257263

258264
PrimExpr nnz() const final { return nnz_; }
259265

266+
Optional<Axis> GetParentAxis() const final { return parent_; }
267+
260268
static constexpr const char* _type_key = "tir.sparse.DenseVariableAxis";
261269
TVM_DECLARE_FINAL_OBJECT_INFO(DenseVariableAxisNode, DenseAxisNode);
262270
};
@@ -267,7 +275,8 @@ class DenseVariableAxisNode : public DenseAxisNode {
267275
*/
268276
class DenseVariableAxis : public DenseAxis {
269277
public:
270-
TVM_DLL explicit DenseVariableAxis(String name, PrimExpr length, PrimExpr nnz, Buffer indptr);
278+
TVM_DLL explicit DenseVariableAxis(String name, Axis parent, PrimExpr length, PrimExpr nnz,
279+
Buffer indptr);
271280

272281
TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis, DenseVariableAxisNode);
273282
};
@@ -280,6 +289,7 @@ class SparseFixedAxisNode : public SparseAxisNode {
280289
Buffer indices;
281290
/* fixed number of non-zero columns of current sparse axis. */
282291
PrimExpr nnz_cols;
292+
Axis parent_;
283293

284294
void VisitAttrs(AttrVisitor* v) {
285295
SparseAxisNode::VisitAttrs(v);
@@ -302,6 +312,8 @@ class SparseFixedAxisNode : public SparseAxisNode {
302312

303313
AxisKind kind() const final { return AxisKind::kSparseFixed; }
304314

315+
Optional<Axis> GetParentAxis() const final { return parent_; }
316+
305317
static constexpr const char* _type_key = "tir.sparse.SparseFixedAxis";
306318
TVM_DECLARE_FINAL_OBJECT_INFO(SparseFixedAxisNode, SparseAxisNode);
307319
};
@@ -312,7 +324,8 @@ class SparseFixedAxisNode : public SparseAxisNode {
312324
*/
313325
class SparseFixedAxis : public SparseAxis {
314326
public:
315-
TVM_DLL explicit SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr nnz_cols);
327+
TVM_DLL explicit SparseFixedAxis(String name, Axis parent, PrimExpr length, Buffer indices,
328+
PrimExpr nnz_cols);
316329

317330
TVM_DEFINE_OBJECT_REF_METHODS(SparseFixedAxis, SparseAxis, SparseFixedAxisNode);
318331
};
@@ -324,6 +337,7 @@ class SparseVariableAxisNode : public SparseAxisNode {
324337
public:
325338
Buffer indptr;
326339
Buffer indices;
340+
Axis parent_;
327341

328342
void VisitAttrs(AttrVisitor* v) {
329343
SparseAxisNode::VisitAttrs(v);
@@ -346,6 +360,8 @@ class SparseVariableAxisNode : public SparseAxisNode {
346360

347361
AxisKind kind() const final { return AxisKind::kSparseVariable; }
348362

363+
Optional<Axis> GetParentAxis() const final { return parent_; }
364+
349365
static constexpr const char* _type_key = "tir.sparse.SparseVariableAxis";
350366
TVM_DECLARE_FINAL_OBJECT_INFO(SparseVariableAxisNode, SparseAxisNode);
351367
};
@@ -356,52 +372,12 @@ class SparseVariableAxisNode : public SparseAxisNode {
356372
*/
357373
class SparseVariableAxis : public SparseAxis {
358374
public:
359-
TVM_DLL explicit SparseVariableAxis(String name, PrimExpr length, Buffer indptr, Buffer indices);
375+
TVM_DLL explicit SparseVariableAxis(String name, Axis parent, PrimExpr length, Buffer indptr,
376+
Buffer indices);
360377

361378
TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis, SparseVariableAxisNode);
362379
};
363380

364-
/*!
365-
* \brief Axis Dependency Tree.
366-
*/
367-
class AxisTreeNode : public Object {
368-
public:
369-
// unordered map that stores the parent relationship between axes.
370-
Map<String, String> parent;
371-
// unordered map that stores the children relationship between axes.
372-
Map<String, Array<String>> children;
373-
374-
void VisitAttrs(AttrVisitor* v) {
375-
v->Visit("parent", &parent);
376-
v->Visit("children", &children);
377-
}
378-
379-
bool SEqualReduce(const AxisTreeNode* other, SEqualReducer equal) const {
380-
return equal(parent, other->parent) && equal(children, other->children);
381-
}
382-
383-
void SHashReduce(SHashReducer hash_reduce) const {
384-
hash_reduce(parent);
385-
hash_reduce(children);
386-
}
387-
388-
static constexpr const char* _type_key = "tir.sparse.AxisTree";
389-
static constexpr const bool _type_has_method_sequal_reduce = true;
390-
static constexpr const bool _type_has_method_shash_reduce = true;
391-
TVM_DECLARE_FINAL_OBJECT_INFO(AxisTreeNode, Object);
392-
};
393-
394-
/*!
395-
* \brief Managed reference to AxisRefNode.
396-
* \sa AxisTreeNode
397-
*/
398-
class AxisTree : public ObjectRef {
399-
public:
400-
TVM_DLL AxisTree(Array<String> axis_names, Array<Optional<String>> axis_parent_names);
401-
402-
TVM_DEFINE_OBJECT_REF_METHODS(AxisTree, ObjectRef, AxisTreeNode);
403-
};
404-
405381
/*!
406382
* \brief Class of sparse buffer.
407383
*/

include/tvm/tir/transform.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ TVM_DLL Pass ConvertForLoopsToSerial();
484484
* \param axis_tree The axis dependency tree.
485485
* \return The pass.
486486
*/
487-
TVM_DLL Pass LowerSparseTIR(AxisTree axis_tree);
487+
TVM_DLL Pass LowerSparseTIR();
488488

489489
} // namespace transform
490490
} // namespace tir

python/tvm/script/tir/special_stmt.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,7 @@ class DenseVariable(SpecialStmt):
874874

875875
def __init__(self):
876876
def dense_variable(
877+
parent_axis: Axis,
877878
shape: Tuple[PrimExpr, PrimExpr],
878879
indptr_var: tvm.tir.Var,
879880
idtype: str = "int32",
@@ -885,11 +886,12 @@ def dense_variable(
885886
f"`dense_variable` expected assign to only one var, but got {names}", span
886887
)
887888

888-
length, indptr_len, nnz = shape
889+
length, nnz = shape
890+
indptr_len = parent_axis.length + 1
889891
indptr_buf = tvm.tir.decl_buffer(
890892
(indptr_len,), dtype=idtype, name=names[0] + "_indptr", span=span
891893
)
892-
axis = DenseVariableAxis(names[0], length, nnz, indptr_buf)
894+
axis = DenseVariableAxis(names[0], parent_axis, length, nnz, indptr_buf)
893895
self.context.sp_struct.append(axis)
894896
self.context.sp_struct_params.append([indptr_var])
895897
self.context.update_symbol(names[0], axis, self.node)
@@ -904,7 +906,8 @@ class SparseFixed(SpecialStmt):
904906

905907
def __init__(self):
906908
def sparse_fixed(
907-
shape: Tuple[PrimExpr, PrimExpr, PrimExpr],
909+
parent_axis: Axis,
910+
shape: Tuple[PrimExpr, PrimExpr],
908911
indices_var: tvm.tir.Var,
909912
idtype: str = "int32",
910913
span: Optional[Span] = None,
@@ -915,11 +918,12 @@ def sparse_fixed(
915918
f"`sparse_fixed` expected assign to only one var, but got {names}", span
916919
)
917920

918-
length, nnz, nnz_cols = shape
921+
length, nnz_cols = shape
922+
nnz = parent_axis.nnz * nnz_cols
919923
indices_buf = tvm.tir.decl_buffer(
920924
(nnz,), dtype=idtype, name=names[0] + "_indices", span=span
921925
)
922-
axis = SparseFixedAxis(names[0], length, indices_buf, nnz_cols)
926+
axis = SparseFixedAxis(names[0], parent_axis, length, indices_buf, nnz_cols)
923927
self.context.sp_struct.append(axis)
924928
self.context.sp_struct_params.append([indices_var])
925929
self.context.update_symbol(names[0], axis, self.node)
@@ -934,7 +938,8 @@ class SparseVariable(SpecialStmt):
934938

935939
def __init__(self):
936940
def sparse_variable(
937-
shape: Tuple[PrimExpr, PrimExpr, PrimExpr],
941+
parent_axis: Axis,
942+
shape: Tuple[PrimExpr, PrimExpr],
938943
data: Tuple[tvm.tir.Var, tvm.tir.Var],
939944
idtype: str = "int32",
940945
span: Optional[Span] = None,
@@ -945,15 +950,16 @@ def sparse_variable(
945950
f"`sparse_variable` expected assign to only one var, but got {names}", span
946951
)
947952

948-
length, indptr_len, nnz = shape
953+
length, nnz = shape
954+
indptr_len = parent_axis.nnz + 1
949955
indptr_var, indices_var = data
950956
indptr_buf = tvm.tir.decl_buffer(
951957
(indptr_len,), dtype=idtype, name=names[0] + "_indptr", span=span
952958
)
953959
indices_buf = tvm.tir.decl_buffer(
954960
(nnz,), dtype=idtype, name=names[0] + "_indices", span=span
955961
)
956-
axis = SparseVariableAxis(names[0], length, indptr_buf, indices_buf)
962+
axis = SparseVariableAxis(names[0], parent_axis, length, indptr_buf, indices_buf)
957963
self.context.sp_struct.append(axis)
958964
self.context.sp_struct_params.append([indptr_var, indices_var])
959965
self.context.update_symbol(names[0], axis, self.node)
@@ -971,10 +977,19 @@ def __init__(self):
971977
def match_sparse_buffer(
972978
param: tvm.tir.Var,
973979
axes: List[Axis],
974-
nnz: PrimExpr,
975980
dtype: str = "float32",
976981
span: Optional[Span] = None,
977982
):
983+
def infer_nnz(axes: List[Axis]) -> PrimExpr:
984+
"""Inference the number of non-zero elements in a sparse buffer."""
985+
ret = axes[0].nnz
986+
for axis in axes[1:]:
987+
if isinstance(axis, DenseFixedAxis):
988+
ret = ret * axis.nnz
989+
else:
990+
ret = axis.nnz
991+
return ret
992+
978993
if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
979994
self.context.report_error(
980995
"`match_sparse_buffer` must be assigned to a single sparse buffer, "
@@ -989,7 +1004,7 @@ def match_sparse_buffer(
9891004
)
9901005

9911006
if param in self.context.func_params:
992-
data = tvm.tir.decl_buffer(nnz, dtype, buffer_name + "_data", span=span)
1007+
data = tvm.tir.decl_buffer(infer_nnz(axes), dtype, buffer_name + "_data", span=span)
9931008
buffer = tvm.tir.sparse.SparseBuffer(axes, data, buffer_name)
9941009
self.context.sp_struct.append(buffer)
9951010
self.context.sp_struct_params.append([param])

python/tvm/tir/sparse.py

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ def length(self):
4242
@property
4343
def idtype(self):
4444
return _ffi_api.GetAxisIndexType(self)
45+
46+
@property
47+
def nnz(self):
48+
return _ffi_api.GetNNZ(self)
4549

4650

4751
@tvm._ffi.register_object("tir.sparse.DenseAxis")
@@ -117,6 +121,9 @@ class DenseVariableAxis(DenseAxis):
117121
----------
118122
name : str
119123
The name of the axis
124+
125+
parent : Axis
126+
The parent axis
120127
121128
length : PrimExpr
122129
The length of the axis
@@ -126,13 +133,14 @@ class DenseVariableAxis(DenseAxis):
126133
"""
127134

128135
name: str
136+
parent: Axis
129137
length: PrimExpr
130138
nnz: PrimExpr
131139
indptr: Buffer
132140

133-
def __init__(self, name, length, nnz, indptr):
141+
def __init__(self, name, parent, length, nnz, indptr):
134142
self.__init_handle_by_constructor__(
135-
_ffi_api.DenseVariableAxis, name, length, nnz, indptr # type: ignore
143+
_ffi_api.DenseVariableAxis, name, parent, length, nnz, indptr # type: ignore
136144
)
137145

138146

@@ -145,6 +153,9 @@ class SparseFixedAxis(DenseAxis):
145153
name : str
146154
The name of the axis
147155
156+
parent : Axis
157+
The parent axis
158+
148159
length : PrimExpr
149160
The length of the axis
150161
@@ -156,13 +167,14 @@ class SparseFixedAxis(DenseAxis):
156167
"""
157168

158169
name: str
170+
parent: Axis
159171
length: PrimExpr
160172
indices: Buffer
161173
nnz_cols: PrimExpr
162174

163-
def __init__(self, name, length, indices, nnz_cols):
175+
def __init__(self, name, parent, length, indices, nnz_cols):
164176
self.__init_handle_by_constructor__(
165-
_ffi_api.SparseFixedAxis, name, length, indices, nnz_cols # type: ignore
177+
_ffi_api.SparseFixedAxis, name, parent, length, indices, nnz_cols # type: ignore
166178
)
167179

168180

@@ -174,6 +186,9 @@ class SparseVariableAxis(DenseAxis):
174186
----------
175187
name : str
176188
The name of the axis
189+
190+
parent : Axis
191+
The parent axis
177192
178193
length : PrimExpr
179194
The length of the axis
@@ -186,33 +201,14 @@ class SparseVariableAxis(DenseAxis):
186201
"""
187202

188203
name: str
204+
parent: Axis
189205
length: PrimExpr
190206
indptr: Buffer
191207
indices: Buffer
192208

193-
def __init__(self, name, length, indptr, indices):
194-
self.__init_handle_by_constructor__(
195-
_ffi_api.SparseVariableAxis, name, length, indptr, indices # type: ignore
196-
)
197-
198-
199-
@tvm._ffi.register_object("tir.sparse.AxisTree")
200-
class AxisTree(Object):
201-
"""AxisTree node
202-
203-
Parameters
204-
----------
205-
axis_parent_map: Dict
206-
A dictionary that maps axis name to parent axis name, value is None if there is not parent axis.
207-
"""
208-
209-
axis_parent_map: Dict[str, Optional[str]]
210-
211-
def __init__(self, axis_parent_map) -> None:
212-
keys = list(axis_parent_map.keys())
213-
values = list(axis_parent_map.values())
209+
def __init__(self, name, parent, length, indptr, indices):
214210
self.__init_handle_by_constructor__(
215-
_ffi_api.AxisTree, keys, values # type:ignore
211+
_ffi_api.SparseVariableAxis, name, parent, length, indptr, indices # type: ignore
216212
)
217213

218214

0 commit comments

Comments
 (0)