Skip to content

Commit d12dfdf

Browse files
MasterJH5574yzh119
authored andcommitted
[SparseTIR] Parser, Printer, Roundtrip (#14)
* SparseBlock scope handler (part 1) * SparseBlock scope handler (part 2) * SparseBlock scope handler (part 3) * SparseBlock scope handler (fix 1) * Add SparseBufferLoad/Store on Python side * Parser for SparseBufferLoad/Store * Add SparseBlock to Python __init__ * StmtFunctor for SparseBlock * Ensure at least one dimension for SparseBuffer * Make `axis` field of SpIterVar mandatory * SparseBlock scope handler (fix 2) * Update Axis syntax by removing `name` parameter * Move to intrin.py * Add filed `from_sparse` to DenseFixedAxis * SparseTIR script printer * Roundtrip test * `update_symbol` bug fix * Fix attr visit in SparseBuffer * Define then compare in SparseBlock * Fix printer bug for SparseBuffer * Enable graph match for Axis and SparseBuffer * Complete HashReduce and EqualReduce for AxisTree and SparseBuffer * Fix typo * Rename test * Bug fix 1 * Bug fix 2 * Add more tests
1 parent fffe67f commit d12dfdf

File tree

20 files changed

+783
-234
lines changed

20 files changed

+783
-234
lines changed

include/tvm/tir/expr.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -656,7 +656,7 @@ class BufferLoad : public PrimExpr {
656656
*/
657657
class SparseBufferLoadNode : public PrimExprNode {
658658
public:
659-
/*! \brief The buffer variable. */
659+
/*! \brief The buffer to be loaded. */
660660
SparseBuffer buffer;
661661
/*! \brief The indices location to be loaded. */
662662
Array<PrimExpr> indices;

include/tvm/tir/sparse.h

Lines changed: 60 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -99,23 +99,48 @@ class DenseAxis : public Axis {
9999
TVM_DEFINE_OBJECT_REF_METHODS(DenseAxis, Axis, DenseAxisNode);
100100
};
101101

102+
/*!
103+
* \brief Sparse axis whose column indices is not consecutive.
104+
*/
105+
class SparseAxisNode : public AxisNode {
106+
public:
107+
static constexpr const char* _type_key = "tir.sparse.SparseAxis";
108+
TVM_DECLARE_BASE_OBJECT_INFO(SparseAxisNode, AxisNode);
109+
};
110+
111+
/*!
112+
* \brief Managed reference to SparseAxisNode.
113+
* \sa SparseAxisNode
114+
*/
115+
class SparseAxis : public Axis {
116+
public:
117+
TVM_DEFINE_OBJECT_REF_METHODS(SparseAxis, Axis, SparseAxisNode);
118+
};
119+
102120
/*!
103121
* \brief Dense axis with fixed length per row.
104122
*/
105123
class DenseFixedAxisNode : public DenseAxisNode {
106124
public:
125+
Optional<SparseAxis> from_sparse;
126+
107127
void VisitAttrs(AttrVisitor* v) {
108128
v->Visit("name", &name);
109129
v->Visit("length", &length);
130+
v->Visit("from_sparse", &from_sparse);
110131
}
111132

112-
bool SEqualReduce(const DenseAxisNode* other, SEqualReducer equal) const {
113-
return equal(name, other->name) && equal(length, other->length);
133+
bool SEqualReduce(const DenseFixedAxisNode* other, SEqualReducer equal) const {
134+
equal->MarkGraphNode();
135+
return equal(name, other->name) && equal(length, other->length) &&
136+
equal(from_sparse, other->from_sparse);
114137
}
115138

116139
void SHashReduce(SHashReducer hash_reduce) const {
140+
hash_reduce->MarkGraphNode();
117141
hash_reduce(name);
118142
hash_reduce(length);
143+
hash_reduce(from_sparse);
119144
}
120145

121146
static constexpr const char* _type_key = "tir.sparse.DenseFixedAxis";
@@ -128,7 +153,8 @@ class DenseFixedAxisNode : public DenseAxisNode {
128153
*/
129154
class DenseFixedAxis : public DenseAxis {
130155
public:
131-
TVM_DLL explicit DenseFixedAxis(String name, PrimExpr length);
156+
TVM_DLL explicit DenseFixedAxis(String name, PrimExpr length,
157+
Optional<SparseAxis> from_sparse = NullOpt);
132158

133159
TVM_DEFINE_OBJECT_REF_METHODS(DenseFixedAxis, DenseAxis, DenseFixedAxisNode);
134160
};
@@ -144,10 +170,12 @@ class DenseVariableAxisNode : public DenseAxisNode {
144170
}
145171

146172
bool SEqualReduce(const DenseVariableAxisNode* other, SEqualReducer equal) const {
173+
equal->MarkGraphNode();
147174
return equal(name, other->name) && equal(length, other->length) && equal(indptr, other->indptr);
148175
}
149176

150177
void SHashReduce(SHashReducer hash_reduce) const {
178+
hash_reduce->MarkGraphNode();
151179
hash_reduce(name);
152180
hash_reduce(length);
153181
hash_reduce(indptr);
@@ -168,24 +196,6 @@ class DenseVariableAxis : public DenseAxis {
168196
TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis, DenseVariableAxisNode);
169197
};
170198

171-
/*!
172-
* \brief Sparse axis whose column indices is not consecutive.
173-
*/
174-
class SparseAxisNode : public AxisNode {
175-
public:
176-
static constexpr const char* _type_key = "tir.sparse.SparseAxis";
177-
TVM_DECLARE_BASE_OBJECT_INFO(SparseAxisNode, AxisNode);
178-
};
179-
180-
/*!
181-
* \brief Managed reference to SparseAxisNode.
182-
* \sa SparseAxisNode
183-
*/
184-
class SparseAxis : public Axis {
185-
public:
186-
TVM_DEFINE_OBJECT_REF_METHODS(SparseAxis, Axis, SparseAxisNode);
187-
};
188-
189199
/*!
190200
* \brief Sparse axis with fixed number of non-zero columns per row.
191201
*/
@@ -203,11 +213,13 @@ class SparseFixedAxisNode : public SparseAxisNode {
203213
}
204214

205215
bool SEqualReduce(const SparseFixedAxisNode* other, SEqualReducer equal) const {
216+
equal->MarkGraphNode();
206217
return equal(name, other->name) && equal(length, other->length) &&
207218
equal(indices, other->indices) && equal(num_cols, other->num_cols);
208219
}
209220

210221
void SHashReduce(SHashReducer hash_reduce) const {
222+
hash_reduce->MarkGraphNode();
211223
hash_reduce(name);
212224
hash_reduce(length);
213225
hash_reduce(indices);
@@ -245,11 +257,13 @@ class SparseVariableAxisNode : public SparseAxisNode {
245257
}
246258

247259
bool SEqualReduce(const SparseVariableAxisNode* other, SEqualReducer equal) const {
260+
equal->MarkGraphNode();
248261
return equal(name, other->name) && equal(length, other->length) &&
249262
equal(indptr, other->indptr) && equal(indices, other->indices);
250263
}
251264

252265
void SHashReduce(SHashReducer hash_reduce) const {
266+
hash_reduce->MarkGraphNode();
253267
hash_reduce(name);
254268
hash_reduce(length);
255269
hash_reduce(indptr);
@@ -277,13 +291,27 @@ class SparseVariableAxis : public SparseAxis {
277291
class AxisTreeNode : public Object {
278292
public:
279293
// unordered map that stores the parent relationship between axes.
280-
std::unordered_map<String, Optional<String>, ObjectPtrHash, ObjectPtrEqual> parent;
294+
Map<String, Optional<String>> parent;
281295
// unordered map that stores the children relationship between axes.
282-
std::unordered_map<Optional<String>, Array<String>, ObjectPtrHash, ObjectPtrEqual> children;
296+
Map<Optional<String>, Array<String>> children;
297+
298+
void VisitAttrs(AttrVisitor* v) {
299+
v->Visit("parent", &parent);
300+
v->Visit("children", &children);
301+
}
302+
303+
bool SEqualReduce(const AxisTreeNode* other, SEqualReducer equal) const {
304+
return equal(parent, other->parent) && equal(children, other->children);
305+
}
283306

284-
void VisitAttrs(AttrVisitor* v) {}
307+
void SHashReduce(SHashReducer hash_reduce) const {
308+
hash_reduce(parent);
309+
hash_reduce(children);
310+
}
285311

286312
static constexpr const char* _type_key = "tir.sparse.AxisTree";
313+
static constexpr const bool _type_has_method_sequal_reduce = true;
314+
static constexpr const bool _type_has_method_shash_reduce = true;
287315
TVM_DECLARE_FINAL_OBJECT_INFO(AxisTreeNode, Object);
288316
};
289317

@@ -313,22 +341,26 @@ class SparseBufferNode : public Object {
313341
inline int ndim() const { return static_cast<int>(axes.size()); }
314342

315343
void VisitAttrs(AttrVisitor* v) {
316-
v->Visit("length", &axes);
317-
v->Visit("num_cols", &data);
344+
v->Visit("axes", &axes);
345+
v->Visit("data", &data);
318346
v->Visit("name", &name);
319347
}
320348

321349
bool SEqualReduce(const SparseBufferNode* other, SEqualReducer equal) const {
350+
equal->MarkGraphNode();
322351
return equal(axes, other->axes) && equal(data, other->data) && equal(name, other->name);
323352
}
324353

325354
void SHashReduce(SHashReducer hash_reduce) const {
355+
hash_reduce->MarkGraphNode();
326356
hash_reduce(axes);
327357
hash_reduce(data);
328358
hash_reduce(name);
329359
}
330360

331361
static constexpr const char* _type_key = "tir.sparse.SparseBuffer";
362+
static constexpr const bool _type_has_method_sequal_reduce = true;
363+
static constexpr const bool _type_has_method_shash_reduce = true;
332364
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBufferNode, Object);
333365
};
334366

@@ -359,7 +391,7 @@ class SpIterVarNode : public Object {
359391
PrimExpr max_extent;
360392
SpIterKind kind;
361393
bool is_reduction;
362-
Optional<Axis> axis;
394+
Axis axis;
363395

364396
void VisitAttrs(AttrVisitor* v) {
365397
v->Visit("var", &var);
@@ -392,7 +424,7 @@ class SpIterVarNode : public Object {
392424
class SpIterVar : public ObjectRef {
393425
public:
394426
TVM_DLL explicit SpIterVar(Var var, PrimExpr max_extent, SpIterKind kind, bool is_reduction,
395-
Optional<Axis> axis = NullOpt);
427+
Axis axis);
396428

397429
/*!
398430
* \return the corresponding var in the IterVar.

include/tvm/tir/stmt.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -335,11 +335,11 @@ class BufferStore : public Stmt {
335335
* buffer[i, j] = value;
336336
*
337337
* \endcode
338-
* \sa SparseBufferLoad
338+
* \sa SparseBufferStore
339339
*/
340340
class SparseBufferStoreNode : public StmtNode {
341341
public:
342-
/*! \brief The buffer variable. */
342+
/*! \brief The sparse buffer to be accessed. */
343343
SparseBuffer buffer;
344344
/*! \brief The value to be stored. */
345345
PrimExpr value;
@@ -1303,17 +1303,17 @@ class SparseBlockNode : public StmtNode {
13031303
}
13041304

13051305
bool SEqualReduce(const SparseBlockNode* other, SEqualReducer equal) const {
1306-
return equal(sp_iter_vars, other->sp_iter_vars) &&
1307-
equal(sp_struct2param_map, other->sp_struct2param_map) && equal(name, other->name) &&
1308-
equal(body, other->body) && equal(init, other->init);
1306+
return equal(sp_iter_vars, other->sp_iter_vars) && equal(name, other->name) &&
1307+
equal(body, other->body) && equal(init, other->init) &&
1308+
equal(sp_struct2param_map, other->sp_struct2param_map);
13091309
}
13101310

13111311
void SHashReduce(SHashReducer hash_reduce) const {
13121312
hash_reduce(sp_iter_vars);
1313-
hash_reduce(sp_struct2param_map);
13141313
hash_reduce(name);
13151314
hash_reduce(body);
13161315
hash_reduce(init);
1316+
hash_reduce(sp_struct2param_map);
13171317
}
13181318

13191319
static constexpr const char* _type_key = "tir.SparseBlock";

include/tvm/tir/stmt_functor.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
9999
virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
100100
virtual R VisitStmt_(const BlockNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
101101
virtual R VisitStmt_(const BlockRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
102+
virtual R VisitStmt_(const SparseBlockNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
102103
virtual R VisitStmtDefault_(const Object* op, Args...) {
103104
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
104105
return R();
@@ -126,6 +127,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
126127
IR_STMT_FUNCTOR_DISPATCH(BufferRealizeNode);
127128
IR_STMT_FUNCTOR_DISPATCH(BlockNode);
128129
IR_STMT_FUNCTOR_DISPATCH(BlockRealizeNode);
130+
IR_STMT_FUNCTOR_DISPATCH(SparseBlockNode);
129131
return vtable;
130132
}
131133
};
@@ -169,6 +171,7 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
169171
void VisitStmt_(const EvaluateNode* op) override;
170172
void VisitStmt_(const BlockNode* op) override;
171173
void VisitStmt_(const BlockRealizeNode* op) override;
174+
void VisitStmt_(const SparseBlockNode* op) override;
172175
};
173176

174177
/*!
@@ -270,6 +273,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
270273
Stmt VisitStmt_(const EvaluateNode* op) override;
271274
Stmt VisitStmt_(const BlockNode* op) override;
272275
Stmt VisitStmt_(const BlockRealizeNode* op) override;
276+
Stmt VisitStmt_(const SparseBlockNode* op) override;
273277
/*!
274278
* \brief Alternative advance method for SeqStmtNode.
275279
*

python/tvm/script/context_maintainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def update_symbol(
219219
self, name: str, symbol: Union[Buffer, Var, SparseBuffer, Axis], node: synr.ast.Node
220220
):
221221
"""Append a symbol into current scope"""
222-
if isinstance(symbol, (Buffer, Var, SparseBuffer, Axis)):
222+
if isinstance(symbol, (Buffer, SparseBuffer, Axis)):
223223
if name in self.symbols[0]:
224224
self.report_error("Duplicate Buffer name: " + symbol.name, node.span)
225225
self.symbols[0][name] = symbol

python/tvm/script/parser.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,14 @@ def transform_SubscriptAssign(self, node):
582582
indexes,
583583
span=tvm_span_from_synr(node.span),
584584
)
585+
elif isinstance(symbol, tvm.tir.sparse.SparseBuffer):
586+
# SparseBufferStore
587+
return tvm.tir.SparseBufferStore(
588+
symbol,
589+
tvm.runtime.convert(rhs, span=rhs_span),
590+
indexes,
591+
span=tvm_span_from_synr(node.span),
592+
)
585593
else:
586594
if len(indexes) != 1:
587595
self.report_error(
@@ -876,6 +884,8 @@ def transform_Subscript(self, node):
876884
return BufferSlice(
877885
symbol, indexes, self.report_error, span=tvm_span_from_synr(node.span)
878886
)
887+
elif isinstance(symbol, tvm.tir.sparse.SparseBuffer):
888+
return tvm.tir.SparseBufferLoad(symbol, indexes, span=tvm_span_from_synr(node.span))
879889
elif isinstance(symbol, tvm.container.Array):
880890
if len(indexes) > 1:
881891
self.report_error(

python/tvm/script/tir/intrin.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,18 @@
1717
"""TVM Script Parser Intrinsic Classes"""
1818
# pylint: disable=redefined-builtin, relative-beyond-top-level
1919
import builtins
20-
from typing import List, Any
20+
from typing import List, Optional, Any
2121

2222
import tvm.tir
23+
from tvm.ir import Span
24+
from tvm.tir.sparse import (
25+
Axis,
26+
DenseFixedAxis,
27+
DenseVariableAxis,
28+
SpIterVar,
29+
SparseFixedAxis,
30+
SparseVariableAxis,
31+
)
2332
from ..registry import register
2433
from ..utils import get_param_list, tvm_span_from_synr
2534

@@ -244,3 +253,35 @@ def comm_reducer(lambda_io, identities, span):
244253
lambda_output = (lambda_output,)
245254

246255
return tvm.tir.CommReducer(x, y, lambda_output, identities, span)
256+
257+
258+
@register
259+
def to_dense(axis: Axis, span: Optional[Span] = None):
260+
if isinstance(axis, (SparseFixedAxis, SparseVariableAxis)):
261+
return DenseFixedAxis(axis.name + "_dense", axis.length, axis)
262+
else:
263+
return axis
264+
265+
266+
@register
267+
def cord(axis: Axis, span: Optional[Span] = None):
268+
# The field `var` and `is_reduction` will be updated in SparseBlock scope handler
269+
var_temp = tvm.te.var()
270+
if isinstance(axis, DenseVariableAxis):
271+
return SpIterVar(var_temp, axis.length, SpIterVar.DenseVariable, False, axis)
272+
else:
273+
return SpIterVar(var_temp, axis.length, SpIterVar.DenseFixed, False, axis)
274+
275+
276+
@register
277+
def pos(axis: Axis, span: Optional[Span] = None):
278+
# The field `var` and `is_reduction` will be updated in SparseBlock scope handler
279+
var_temp = tvm.te.var()
280+
if isinstance(axis, DenseFixedAxis):
281+
return SpIterVar(var_temp, axis.length, SpIterVar.DenseFixed, False, axis)
282+
elif isinstance(axis, DenseVariableAxis):
283+
return SpIterVar(var_temp, axis.length, SpIterVar.DenseVariable, False, axis)
284+
elif isinstance(axis, SparseFixedAxis):
285+
return SpIterVar(var_temp, axis.length, SpIterVar.SparseFixed, False, axis)
286+
else:
287+
return SpIterVar(var_temp, axis.length, SpIterVar.SparseVariable, False, axis)

0 commit comments

Comments
 (0)