Skip to content

Commit fffe67f

Browse files
MasterJH5574yzh119
authored andcommitted
[SparseTIR] Enhance SparseBlock to contain enough PrimFunc information (#13)
* Enhance SparseBlock to have enough PrimFunc info * Remove `func_sparse_buffer_map_` * Don't print the map uh-huh
1 parent f737522 commit fffe67f

File tree

5 files changed

+35
-30
lines changed

5 files changed

+35
-30
lines changed

include/tvm/tir/stmt.h

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1285,8 +1285,8 @@ class SparseBlockNode : public StmtNode {
12851285
public:
12861286
/*! \brief The sparse iteration variables of the block. */
12871287
Array<SpIterVar> sp_iter_vars;
1288-
/*! \brief The sparse buffers defined in the block. */
1289-
Array<SparseBuffer> sp_buffers;
1288+
/*! \brief The mapping from sparse data structures to the PrimFunc parameters */
1289+
Map<ObjectRef, Array<Var>> sp_struct2param_map;
12901290
/*! \brief The name of the block */
12911291
String name;
12921292
/*! \brief The body of the block */
@@ -1296,20 +1296,21 @@ class SparseBlockNode : public StmtNode {
12961296

12971297
void VisitAttrs(AttrVisitor* v) {
12981298
v->Visit("sp_iter_vars", &sp_iter_vars);
1299-
v->Visit("sp_buffers", &sp_buffers);
1299+
v->Visit("sp_struct2param_map", &sp_struct2param_map);
13001300
v->Visit("name", &name);
13011301
v->Visit("body", &body);
13021302
v->Visit("init", &init);
13031303
}
13041304

13051305
bool SEqualReduce(const SparseBlockNode* other, SEqualReducer equal) const {
1306-
return equal(sp_iter_vars, other->sp_iter_vars) && equal(sp_buffers, other->sp_buffers) &&
1307-
equal(name, other->name) && equal(body, other->body) && equal(init, other->init);
1306+
return equal(sp_iter_vars, other->sp_iter_vars) &&
1307+
equal(sp_struct2param_map, other->sp_struct2param_map) && equal(name, other->name) &&
1308+
equal(body, other->body) && equal(init, other->init);
13081309
}
13091310

13101311
void SHashReduce(SHashReducer hash_reduce) const {
13111312
hash_reduce(sp_iter_vars);
1312-
hash_reduce(sp_buffers);
1313+
hash_reduce(sp_struct2param_map);
13131314
hash_reduce(name);
13141315
hash_reduce(body);
13151316
hash_reduce(init);
@@ -1325,9 +1326,9 @@ class SparseBlockNode : public StmtNode {
13251326
*/
13261327
class SparseBlock : public Stmt {
13271328
public:
1328-
TVM_DLL explicit SparseBlock(Array<SpIterVar> sp_iter_vars, Array<SparseBuffer> sp_buffers,
1329-
String name, Stmt body, Optional<Stmt> init = NullOpt,
1330-
Span span = Span());
1329+
TVM_DLL explicit SparseBlock(Array<SpIterVar> sp_iter_vars,
1330+
Map<ObjectRef, Array<Var>> sp_struct2param_map, String name,
1331+
Stmt body, Optional<Stmt> init = NullOpt, Span span = Span());
13311332

13321333
TVM_DEFINE_OBJECT_REF_METHODS(SparseBlock, Stmt, SparseBlockNode);
13331334
TVM_DEFINE_OBJECT_REF_COW_METHOD(SparseBlockNode);

python/tvm/script/context_maintainer.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,15 @@ class ContextMaintainer:
128128
"""List[Var]: The function parameters"""
129129
func_buffer_map: Mapping[Var, Buffer] = {}
130130
"""Mapping[Var, Buffer]: The function buffer map"""
131-
func_sparse_buffer_map: Mapping[Var, SparseBuffer] = {}
132-
"""Mapping[Var, SparseBuffer]: The function sparse buffer map"""
133131
func_dict_attr: Mapping[str, Object] = {}
134132
"""Mapping[str, Object]: The function attrs"""
135133
func_var_env_dict: Mapping[Var, str] = {}
136134
"""Mapping[Var, str]: The map from var to env thread"""
137135

136+
# sparse block context
137+
sp_struct2param_map: Mapping[Object, List[Var]] = {}
138+
"""Mapping[Object, List[Var]]: The mapping from sparse data structures to the func parameters"""
139+
138140
# parser and analyzer
139141
analyzer: tvm.arith.Analyzer = tvm.arith.Analyzer()
140142
"""tvm.arith.Analyzer: The analyzer for simplifying"""
@@ -154,9 +156,10 @@ def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], No
154156
# function context
155157
self.func_params = []
156158
self.func_buffer_map = {}
157-
self.func_sparse_buffer_map = {}
158159
self.func_dict_attr = {}
159160
self.func_var_env_dict = {}
161+
# sparse block context
162+
self.sp_struct2param_map = {}
160163
# parser and analyzer
161164
self._report_error = _report_error
162165
self.analyzer = tvm.arith.Analyzer()

python/tvm/script/tir/special_stmt.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -857,6 +857,7 @@ def __init__(self):
857857
def dense_fixed(name: str, length: PrimExpr, span: Optional[Span] = None):
858858
var_name = self.node.lhs[0].id.name
859859
axis = DenseFixedAxis(name, length)
860+
self.context.sp_struct2param_map[axis] = []
860861
self.context.update_symbol(var_name, axis, self.node)
861862

862863
super().__init__(dense_fixed, def_symbol=True)
@@ -880,7 +881,7 @@ def dense_variable(
880881
(indptr_len,), dtype=idtype, name=name + "_indptr", span=span
881882
)
882883
axis = DenseVariableAxis(name, length, indptr_buf)
883-
self.context.func_buffer_map[indptr_var] = indptr_buf
884+
self.context.sp_struct2param_map[axis] = indptr_var
884885
self.context.update_symbol(var_name, axis, self.node)
885886
self.context.update_symbol(name + "_indptr", indptr_buf, self.node)
886887

@@ -905,7 +906,7 @@ def sparse_fixed(
905906
(nnz,), dtype=idtype, name=name + "_indices", span=span
906907
)
907908
axis = SparseFixedAxis(name, length, indices_buf, nnz_cols)
908-
self.context.func_buffer_map[indices_var] = indices_buf
909+
self.context.sp_struct2param_map[axis] = [indices_var]
909910
self.context.update_symbol(var_name, axis, self.node)
910911
self.context.update_symbol(name + "_indices", indices_buf, self.node)
911912

@@ -934,8 +935,7 @@ def sparse_variable(
934935
(nnz,), dtype=idtype, name=name + "_indices", span=span
935936
)
936937
axis = SparseVariableAxis(name, length, indptr_buf, indices_buf)
937-
self.context.func_buffer_map[indices_var] = indices_buf
938-
self.context.func_buffer_map[indptr_var] = indptr_buf
938+
self.context.sp_struct2param_map[axis] = [indptr_var, indices_var]
939939
self.context.update_symbol(var_name, axis, self.node)
940940
self.context.update_symbol(name + "_indptr", indptr_buf, self.node)
941941
self.context.update_symbol(name + "_indices", indices_buf, self.node)
@@ -971,8 +971,7 @@ def match_sparse_buffer(
971971
if param in self.context.func_params:
972972
data = tvm.tir.decl_buffer(nnz, dtype, buffer_name + "_data", span=span)
973973
buffer = tvm.tir.sparse.SparseBuffer(axes, data, buffer_name)
974-
self.context.func_buffer_map[param] = data
975-
self.context.func_sparse_buffer_map[param] = buffer
974+
self.context.sp_struct2param_map[buffer] = [param]
976975
self.context.update_symbol(buffer_name + "_data", data, self.node)
977976
self.context.update_symbol(buffer_name, buffer, self.node)
978977
else:

python/tvm/tir/stmt.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636
from . import _ffi_api
3737
from .buffer import Buffer
38-
from .expr import IterVar
38+
from .expr import Var, IterVar
3939
from .sparse import SpIterVar, SparseBuffer
4040

4141

@@ -624,8 +624,8 @@ class SparseBlock(Stmt):
624624
sp_iter_vars : List[SpIterVar]
625625
The sparse iteration variables of the block.
626626
627-
sp_buffers : List[SparseBuffer]
628-
The sparse buffers defined in the block.
627+
sp_struct2param_map : Mapping[Object, List[Var]]
628+
The mapping from sparse data structures to the PrimFunc parameters.
629629
630630
name : str
631631
The name of the block.
@@ -641,7 +641,7 @@ class SparseBlock(Stmt):
641641
"""
642642

643643
sp_iter_vars: List[SpIterVar]
644-
sp_buffers: List[SparseBuffer]
644+
sp_struct2param_map: Mapping[Object, List[Var]]
645645
name: str
646646
body: Stmt
647647
init: Optional[Stmt]
@@ -650,7 +650,7 @@ class SparseBlock(Stmt):
650650
def __init__(
651651
self,
652652
sp_iter_vars: List[SpIterVar],
653-
sp_buffers: List[SparseBuffer],
653+
sp_struct2param_map: Mapping[Object, List[Var]],
654654
name: str,
655655
body: Stmt,
656656
init: Optional[Stmt] = None,
@@ -659,7 +659,7 @@ def __init__(
659659
self.__init_handle_by_constructor__(
660660
_ffi_api.SparseBlock, # type: ignore
661661
sp_iter_vars,
662-
sp_buffers,
662+
sp_struct2param_map,
663663
name,
664664
body,
665665
init,

src/tir/ir/stmt.cc

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -968,11 +968,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
968968
p->stream << "}\n";
969969
});
970970

971-
SparseBlock::SparseBlock(Array<SpIterVar> sp_iter_vars, Array<SparseBuffer> sp_buffers, String name,
972-
Stmt body, Optional<Stmt> init, Span span) {
971+
SparseBlock::SparseBlock(Array<SpIterVar> sp_iter_vars,
972+
Map<ObjectRef, Array<Var>> sp_struct2param_map, String name, Stmt body,
973+
Optional<Stmt> init, Span span) {
973974
ObjectPtr<SparseBlockNode> node = make_object<SparseBlockNode>();
974975
node->sp_iter_vars = std::move(sp_iter_vars);
975-
node->sp_buffers = std::move(sp_buffers);
976+
node->sp_struct2param_map = std::move(sp_struct2param_map);
976977
node->name = std::move(name);
977978
node->body = std::move(body);
978979
node->init = std::move(init);
@@ -981,9 +982,10 @@ SparseBlock::SparseBlock(Array<SpIterVar> sp_iter_vars, Array<SparseBuffer> sp_b
981982
}
982983

983984
TVM_REGISTER_GLOBAL("tir.SparseBlock")
984-
.set_body_typed([](Array<SpIterVar> sp_iter_vars, Array<SparseBuffer> sp_buffers, String name,
985-
Stmt body, Optional<Stmt> init, Span span) {
986-
return SparseBlock(sp_iter_vars, sp_buffers, name, body, init, span);
985+
.set_body_typed([](Array<SpIterVar> sp_iter_vars,
986+
Map<ObjectRef, Array<Var>> sp_struct2param_map, String name, Stmt body,
987+
Optional<Stmt> init, Span span) {
988+
return SparseBlock(sp_iter_vars, sp_struct2param_map, name, body, init, span);
987989
});
988990

989991
TVM_REGISTER_NODE_TYPE(SparseBlockNode);

0 commit comments

Comments
 (0)