Skip to content

Commit bd7b6f0

Browse files
committed
[SparseTIR] SparseTIR Lowering (apache#20)
* Fix a previous bug of sparse-fixed SpIterVar creation * Fix a previous bug in `GetDenseValue` * Refactor Collector and IndexTransformer * Construct block and loops * Fix a previous bug which rejects DV iters in collector * Update buffer map * Create root block * Fix bug of sparse-fixed SpIterVar creation * Fix bug on SpIterVar conversion (with refactor) * Fix bug when getting dependent SpIterVars * Fix bug on dependency map and index lowering * Full block read/write region * Test version 1 * Fix bug of loop order * Fix bug of batch-mm iterator ordering * Update PrimFunc args to use symbolic params * Fix bug of test "csr_element_wise" * Fix bug of index accumulation for sparse-fixed axis * Update correctness test * Test structural equality * Refactor and use Array
1 parent aa4e8f0 commit bd7b6f0

File tree

14 files changed

+843
-84
lines changed

14 files changed

+843
-84
lines changed

include/tvm/tir/sparse.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -131,13 +131,11 @@ class DenseFixedAxisNode : public DenseAxisNode {
131131
}
132132

133133
bool SEqualReduce(const DenseFixedAxisNode* other, SEqualReducer equal) const {
134-
equal->MarkGraphNode();
135134
return equal(name, other->name) && equal(length, other->length) &&
136135
equal(from_sparse, other->from_sparse);
137136
}
138137

139138
void SHashReduce(SHashReducer hash_reduce) const {
140-
hash_reduce->MarkGraphNode();
141139
hash_reduce(name);
142140
hash_reduce(length);
143141
hash_reduce(from_sparse);
@@ -170,12 +168,10 @@ class DenseVariableAxisNode : public DenseAxisNode {
170168
}
171169

172170
bool SEqualReduce(const DenseVariableAxisNode* other, SEqualReducer equal) const {
173-
equal->MarkGraphNode();
174171
return equal(name, other->name) && equal(length, other->length) && equal(indptr, other->indptr);
175172
}
176173

177174
void SHashReduce(SHashReducer hash_reduce) const {
178-
hash_reduce->MarkGraphNode();
179175
hash_reduce(name);
180176
hash_reduce(length);
181177
hash_reduce(indptr);
@@ -213,13 +209,11 @@ class SparseFixedAxisNode : public SparseAxisNode {
213209
}
214210

215211
bool SEqualReduce(const SparseFixedAxisNode* other, SEqualReducer equal) const {
216-
equal->MarkGraphNode();
217212
return equal(name, other->name) && equal(length, other->length) &&
218213
equal(indices, other->indices) && equal(num_cols, other->num_cols);
219214
}
220215

221216
void SHashReduce(SHashReducer hash_reduce) const {
222-
hash_reduce->MarkGraphNode();
223217
hash_reduce(name);
224218
hash_reduce(length);
225219
hash_reduce(indices);
@@ -257,13 +251,11 @@ class SparseVariableAxisNode : public SparseAxisNode {
257251
}
258252

259253
bool SEqualReduce(const SparseVariableAxisNode* other, SEqualReducer equal) const {
260-
equal->MarkGraphNode();
261254
return equal(name, other->name) && equal(length, other->length) &&
262255
equal(indptr, other->indptr) && equal(indices, other->indices);
263256
}
264257

265258
void SHashReduce(SHashReducer hash_reduce) const {
266-
hash_reduce->MarkGraphNode();
267259
hash_reduce(name);
268260
hash_reduce(length);
269261
hash_reduce(indptr);
@@ -347,12 +339,10 @@ class SparseBufferNode : public Object {
347339
}
348340

349341
bool SEqualReduce(const SparseBufferNode* other, SEqualReducer equal) const {
350-
equal->MarkGraphNode();
351342
return equal(axes, other->axes) && equal(data, other->data) && equal(name, other->name);
352343
}
353344

354345
void SHashReduce(SHashReducer hash_reduce) const {
355-
hash_reduce->MarkGraphNode();
356346
hash_reduce(axes);
357347
hash_reduce(data);
358348
hash_reduce(name);

include/tvm/tir/stmt.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1285,6 +1285,8 @@ class SparseBlockNode : public StmtNode {
12851285
public:
12861286
/*! \brief The sparse iteration variables of the block. */
12871287
Array<SpIterVar> sp_iter_vars;
1288+
/*! \brief The sparse data structures */
1289+
Array<ObjectRef> sp_structs;
12881290
/*! \brief The mapping from sparse data structures to the PrimFunc parameters */
12891291
Map<ObjectRef, Array<Var>> sp_struct2param_map;
12901292
/*! \brief The name of the block */
@@ -1296,6 +1298,7 @@ class SparseBlockNode : public StmtNode {
12961298

12971299
void VisitAttrs(AttrVisitor* v) {
12981300
v->Visit("sp_iter_vars", &sp_iter_vars);
1301+
v->Visit("sp_structs", &sp_structs);
12991302
v->Visit("sp_struct2param_map", &sp_struct2param_map);
13001303
v->Visit("name", &name);
13011304
v->Visit("body", &body);
@@ -1305,15 +1308,15 @@ class SparseBlockNode : public StmtNode {
13051308
bool SEqualReduce(const SparseBlockNode* other, SEqualReducer equal) const {
13061309
return equal(sp_iter_vars, other->sp_iter_vars) && equal(name, other->name) &&
13071310
equal(body, other->body) && equal(init, other->init) &&
1308-
equal(sp_struct2param_map, other->sp_struct2param_map);
1311+
equal(sp_structs, other->sp_structs);
13091312
}
13101313

13111314
void SHashReduce(SHashReducer hash_reduce) const {
13121315
hash_reduce(sp_iter_vars);
13131316
hash_reduce(name);
13141317
hash_reduce(body);
13151318
hash_reduce(init);
1316-
hash_reduce(sp_struct2param_map);
1319+
hash_reduce(sp_structs);
13171320
}
13181321

13191322
static constexpr const char* _type_key = "tir.SparseBlock";
@@ -1326,9 +1329,9 @@ class SparseBlockNode : public StmtNode {
13261329
*/
13271330
class SparseBlock : public Stmt {
13281331
public:
1329-
TVM_DLL explicit SparseBlock(Array<SpIterVar> sp_iter_vars,
1330-
Map<ObjectRef, Array<Var>> sp_struct2param_map, String name,
1331-
Stmt body, Optional<Stmt> init = NullOpt, Span span = Span());
1332+
TVM_DLL explicit SparseBlock(Array<SpIterVar> sp_iter_vars, Array<ObjectRef> sp_structs,
1333+
Array<Array<Var>> sp_struct_params, String name, Stmt body,
1334+
Optional<Stmt> init = NullOpt, Span span = Span());
13321335

13331336
TVM_DEFINE_OBJECT_REF_METHODS(SparseBlock, Stmt, SparseBlockNode);
13341337
TVM_DEFINE_OBJECT_REF_COW_METHOD(SparseBlockNode);

python/tvm/script/context_maintainer.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,10 @@ class ContextMaintainer:
134134
"""Mapping[Var, str]: The map from var to env thread"""
135135

136136
# sparse block context
137-
sp_struct2param_map: Mapping[Object, List[Var]] = {}
138-
"""Mapping[Object, List[Var]]: The mapping from sparse data structures to the func parameters"""
137+
sp_struct: List[Object] = []
138+
"""List[Object]: The sparse data structures"""
139+
sp_struct_params: List[List[Var]] = []
140+
"""List[List[Var]]: The function parameters that corresponding to each sparse data structures"""
139141

140142
# parser and analyzer
141143
analyzer: tvm.arith.Analyzer = tvm.arith.Analyzer()
@@ -159,7 +161,8 @@ def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], No
159161
self.func_dict_attr = {}
160162
self.func_var_env_dict = {}
161163
# sparse block context
162-
self.sp_struct2param_map = {}
164+
self.sp_struct = []
165+
self.sp_struct_params = []
163166
# parser and analyzer
164167
self._report_error = _report_error
165168
self.analyzer = tvm.arith.Analyzer()

python/tvm/script/tir/intrin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,6 @@ def pos(axis: Axis, span: Optional[Span] = None):
282282
elif isinstance(axis, DenseVariableAxis):
283283
return SpIterVar(var_temp, axis.length, SpIterVar.DenseVariable, False, axis)
284284
elif isinstance(axis, SparseFixedAxis):
285-
return SpIterVar(var_temp, axis.length, SpIterVar.SparseFixed, False, axis)
285+
return SpIterVar(var_temp, axis.num_cols, SpIterVar.SparseFixed, False, axis)
286286
else:
287287
return SpIterVar(var_temp, axis.length, SpIterVar.SparseVariable, False, axis)

python/tvm/script/tir/scope_handler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,8 @@ def iter(iters: List, iter_types: str, name: str = "", span: Optional[Span] = No
366366

367367
block = tvm.tir.SparseBlock(
368368
sp_iters,
369-
self.context.sp_struct2param_map,
369+
self.context.sp_struct,
370+
self.context.sp_struct_params,
370371
name,
371372
self.body,
372373
block_info.init,

python/tvm/script/tir/special_stmt.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -907,7 +907,8 @@ def dense_fixed(length: PrimExpr, span: Optional[Span] = None):
907907
)
908908

909909
axis = DenseFixedAxis(names[0], length)
910-
self.context.sp_struct2param_map[axis] = []
910+
self.context.sp_struct.append(axis)
911+
self.context.sp_struct_params.append([])
911912
self.context.update_symbol(names[0], axis, self.node)
912913

913914
super().__init__(dense_fixed, def_symbol=True)
@@ -935,7 +936,8 @@ def dense_variable(
935936
(indptr_len,), dtype=idtype, name=names[0] + "_indptr", span=span
936937
)
937938
axis = DenseVariableAxis(names[0], length, indptr_buf)
938-
self.context.sp_struct2param_map[axis] = [indptr_var]
939+
self.context.sp_struct.append(axis)
940+
self.context.sp_struct_params.append([indptr_var])
939941
self.context.update_symbol(names[0], axis, self.node)
940942
self.context.update_symbol(names[0] + "_indptr", indptr_buf, self.node)
941943

@@ -964,7 +966,8 @@ def sparse_fixed(
964966
(nnz,), dtype=idtype, name=names[0] + "_indices", span=span
965967
)
966968
axis = SparseFixedAxis(names[0], length, indices_buf, nnz_cols)
967-
self.context.sp_struct2param_map[axis] = [indices_var]
969+
self.context.sp_struct.append(axis)
970+
self.context.sp_struct_params.append([indices_var])
968971
self.context.update_symbol(names[0], axis, self.node)
969972
self.context.update_symbol(names[0] + "_indices", indices_buf, self.node)
970973

@@ -997,7 +1000,8 @@ def sparse_variable(
9971000
(nnz,), dtype=idtype, name=names[0] + "_indices", span=span
9981001
)
9991002
axis = SparseVariableAxis(names[0], length, indptr_buf, indices_buf)
1000-
self.context.sp_struct2param_map[axis] = [indptr_var, indices_var]
1003+
self.context.sp_struct.append(axis)
1004+
self.context.sp_struct_params.append([indptr_var, indices_var])
10011005
self.context.update_symbol(names[0], axis, self.node)
10021006
self.context.update_symbol(names[0] + "_indptr", indptr_buf, self.node)
10031007
self.context.update_symbol(names[0] + "_indices", indices_buf, self.node)
@@ -1033,7 +1037,8 @@ def match_sparse_buffer(
10331037
if param in self.context.func_params:
10341038
data = tvm.tir.decl_buffer(nnz, dtype, buffer_name + "_data", span=span)
10351039
buffer = tvm.tir.sparse.SparseBuffer(axes, data, buffer_name)
1036-
self.context.sp_struct2param_map[buffer] = [param]
1040+
self.context.sp_struct.append(buffer)
1041+
self.context.sp_struct_params.append([param])
10371042
self.context.update_symbol(buffer_name + "_data", data, self.node)
10381043
self.context.update_symbol(buffer_name, buffer, self.node)
10391044
else:

python/tvm/tir/sparse.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from .buffer import Buffer
2727

2828

29+
@tvm._ffi.register_object("tir.sparse.Axis")
2930
class Axis(Object):
3031
"""Base class of all the sparse axes."""
3132

@@ -42,10 +43,12 @@ def idtype(self):
4243
return _ffi_api.GetAxisIndexType(self)
4344

4445

46+
@tvm._ffi.register_object("tir.sparse.DenseAxis")
4547
class DenseAxis(Axis):
4648
pass
4749

4850

51+
@tvm._ffi.register_object("tir.sparse.SparseAxis")
4952
class SparseAxis(Axis):
5053
pass
5154

python/tvm/tir/stmt.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,12 @@ class SparseBlock(Stmt):
649649
sp_iter_vars : List[SpIterVar]
650650
The sparse iteration variables of the block.
651651
652+
sp_struct : List[Object]
653+
The sparse data structures
654+
655+
sp_struct_params : List[List[Var]]
656+
The function parameters that corresponding to each sparse data structures
657+
652658
sp_struct2param_map : Mapping[Object, List[Var]]
653659
The mapping from sparse data structures to the PrimFunc parameters.
654660
@@ -666,6 +672,7 @@ class SparseBlock(Stmt):
666672
"""
667673

668674
sp_iter_vars: List[SpIterVar]
675+
sp_struct: List[Object]
669676
sp_struct2param_map: Mapping[Object, List[Var]]
670677
name: str
671678
body: Stmt
@@ -675,7 +682,8 @@ class SparseBlock(Stmt):
675682
def __init__(
676683
self,
677684
sp_iter_vars: List[SpIterVar],
678-
sp_struct2param_map: Mapping[Object, List[Var]],
685+
sp_struct: List[Object],
686+
sp_struct_params: List[List[Var]],
679687
name: str,
680688
body: Stmt,
681689
init: Optional[Stmt] = None,
@@ -684,7 +692,8 @@ def __init__(
684692
self.__init_handle_by_constructor__(
685693
_ffi_api.SparseBlock, # type: ignore
686694
sp_iter_vars,
687-
sp_struct2param_map,
695+
sp_struct,
696+
sp_struct_params,
688697
name,
689698
body,
690699
init,

src/printer/tvmscript_printer.cc

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1358,12 +1358,14 @@ Doc TVMScriptPrinter::PrintSparseStructDefinitions(const SparseBlockNode* sp_blo
13581358
std::vector<Doc> axis_docs;
13591359
std::vector<Doc> sp_buf_docs;
13601360

1361-
for (auto it : sp_block->sp_struct2param_map) {
1361+
for (const ObjectRef& obj : sp_block->sp_structs) {
1362+
Array<Var> params = sp_block->sp_struct2param_map.Get(obj).value();
1363+
13621364
Doc doc;
1363-
doc << Print(it.first) << " = " << tir_prefix_ << ".";
1365+
doc << Print(obj) << " = " << tir_prefix_ << ".";
13641366

1365-
if (const auto* sp_buffer = it.first.as<SparseBufferNode>()) {
1366-
ICHECK_EQ(it.second.size(), 1);
1367+
if (const auto* sp_buffer = obj.as<SparseBufferNode>()) {
1368+
ICHECK_EQ(params.size(), 1);
13671369
Doc axes_doc;
13681370
if (sp_buffer->axes.size() != 1) {
13691371
std::vector<Doc> axes_docs;
@@ -1376,30 +1378,30 @@ Doc TVMScriptPrinter::PrintSparseStructDefinitions(const SparseBlockNode* sp_blo
13761378
axes_doc << Print(sp_buffer->axes[0]) << ",";
13771379
}
13781380

1379-
doc << "match_sparse_buffer(" << Print(it.second[0]) << ", (" << axes_doc << "), "
1381+
doc << "match_sparse_buffer(" << Print(params[0]) << ", (" << axes_doc << "), "
13801382
<< Print(sp_buffer->data->shape[0]) << ", " << PrintDType(sp_buffer->data->dtype) << ")";
13811383
sp_buf_docs.push_back(doc);
13821384
continue;
13831385
}
13841386

1385-
if (const auto* df_axis = it.first.as<DenseFixedAxisNode>()) {
1386-
ICHECK_EQ(it.second.size(), 0);
1387+
if (const auto* df_axis = obj.as<DenseFixedAxisNode>()) {
1388+
ICHECK_EQ(params.size(), 0);
13871389
doc << "dense_fixed(" << Print(df_axis->length) << ")";
1388-
} else if (const auto* dv_axis = it.first.as<DenseVariableAxisNode>()) {
1389-
ICHECK_EQ(it.second.size(), 1);
1390+
} else if (const auto* dv_axis = obj.as<DenseVariableAxisNode>()) {
1391+
ICHECK_EQ(params.size(), 1);
13901392
doc << "dense_variable((" << Print(dv_axis->length) << ", "
1391-
<< Print(dv_axis->indptr->shape[0]) << "), " << Print(it.second[0]) << ", "
1393+
<< Print(dv_axis->indptr->shape[0]) << "), " << Print(params[0]) << ", "
13921394
<< PrintDType(dv_axis->indptr->dtype) << ")";
1393-
} else if (const auto* sf_axis = it.first.as<SparseFixedAxisNode>()) {
1394-
ICHECK_EQ(it.second.size(), 1);
1395+
} else if (const auto* sf_axis = obj.as<SparseFixedAxisNode>()) {
1396+
ICHECK_EQ(params.size(), 1);
13951397
doc << "sparse_fixed((" << Print(sf_axis->length) << ", " << Print(sf_axis->indices->shape[0])
1396-
<< ", " << Print(sf_axis->num_cols) << "), " << Print(it.second[0]) << ", "
1398+
<< ", " << Print(sf_axis->num_cols) << "), " << Print(params[0]) << ", "
13971399
<< PrintDType(sf_axis->indices->dtype) << ")";
1398-
} else if (const auto* sv_axis = it.first.as<SparseVariableAxisNode>()) {
1399-
ICHECK_EQ(it.second.size(), 2);
1400+
} else if (const auto* sv_axis = obj.as<SparseVariableAxisNode>()) {
1401+
ICHECK_EQ(params.size(), 2);
14001402
doc << "sparse_variable((" << Print(sv_axis->length) << ", "
14011403
<< Print(sv_axis->indptr->shape[0]) << ", " << Print(sv_axis->indices->shape[0]) << "), ("
1402-
<< Print(it.second[0]) << ", " << Print(it.second[1]) << "), "
1404+
<< Print(params[0]) << ", " << Print(params[1]) << "), "
14031405
<< PrintDType(sv_axis->indptr->dtype) << ")";
14041406
} else {
14051407
ICHECK(false) << "Cannot reach here";

src/tir/ir/sparse.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,6 @@ SpIterVar::SpIterVar(Var var, PrimExpr max_extent, SpIterKind kind, bool is_redu
234234
ObjectPtr<SpIterVarNode> node = make_object<SpIterVarNode>();
235235

236236
arith::Analyzer ana;
237-
CHECK(ana.CanProveEqual(axis->length, max_extent));
238237
const char* err_str = "ValueError: The given kind doesn't match the type of the given axis";
239238
if (kind == SpIterKind::kDenseFixed) {
240239
CHECK(!axis->IsInstance<DenseVariableAxisNode>()) << err_str;

0 commit comments

Comments
 (0)