Skip to content

Commit 3c3bebf

Browse files
MasterJH5574yzh119
authored andcommitted
[SparseTIR] Enhance SparseBlock to contain enough PrimFunc information (#13)
* Enhance SparseBlock to have enough PrimFunc info * Remove `func_sparse_buffer_map_` * Don't print the map uh-huh
1 parent f4b3867 commit 3c3bebf

File tree

5 files changed

+35
-30
lines changed

5 files changed

+35
-30
lines changed

include/tvm/tir/stmt.h

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1285,8 +1285,8 @@ class SparseBlockNode : public StmtNode {
12851285
public:
12861286
/*! \brief The sparse iteration variables of the block. */
12871287
Array<SpIterVar> sp_iter_vars;
1288-
/*! \brief The sparse buffers defined in the block. */
1289-
Array<SparseBuffer> sp_buffers;
1288+
/*! \brief The mapping from sparse data structures to the PrimFunc parameters */
1289+
Map<ObjectRef, Array<Var>> sp_struct2param_map;
12901290
/*! \brief The name of the block */
12911291
String name;
12921292
/*! \brief The body of the block */
@@ -1296,20 +1296,21 @@ class SparseBlockNode : public StmtNode {
12961296

12971297
void VisitAttrs(AttrVisitor* v) {
12981298
v->Visit("sp_iter_vars", &sp_iter_vars);
1299-
v->Visit("sp_buffers", &sp_buffers);
1299+
v->Visit("sp_struct2param_map", &sp_struct2param_map);
13001300
v->Visit("name", &name);
13011301
v->Visit("body", &body);
13021302
v->Visit("init", &init);
13031303
}
13041304

13051305
bool SEqualReduce(const SparseBlockNode* other, SEqualReducer equal) const {
1306-
return equal(sp_iter_vars, other->sp_iter_vars) && equal(sp_buffers, other->sp_buffers) &&
1307-
equal(name, other->name) && equal(body, other->body) && equal(init, other->init);
1306+
return equal(sp_iter_vars, other->sp_iter_vars) &&
1307+
equal(sp_struct2param_map, other->sp_struct2param_map) && equal(name, other->name) &&
1308+
equal(body, other->body) && equal(init, other->init);
13081309
}
13091310

13101311
void SHashReduce(SHashReducer hash_reduce) const {
13111312
hash_reduce(sp_iter_vars);
1312-
hash_reduce(sp_buffers);
1313+
hash_reduce(sp_struct2param_map);
13131314
hash_reduce(name);
13141315
hash_reduce(body);
13151316
hash_reduce(init);
@@ -1325,9 +1326,9 @@ class SparseBlockNode : public StmtNode {
13251326
*/
13261327
class SparseBlock : public Stmt {
13271328
public:
1328-
TVM_DLL explicit SparseBlock(Array<SpIterVar> sp_iter_vars, Array<SparseBuffer> sp_buffers,
1329-
String name, Stmt body, Optional<Stmt> init = NullOpt,
1330-
Span span = Span());
1329+
TVM_DLL explicit SparseBlock(Array<SpIterVar> sp_iter_vars,
1330+
Map<ObjectRef, Array<Var>> sp_struct2param_map, String name,
1331+
Stmt body, Optional<Stmt> init = NullOpt, Span span = Span());
13311332

13321333
TVM_DEFINE_OBJECT_REF_METHODS(SparseBlock, Stmt, SparseBlockNode);
13331334
TVM_DEFINE_OBJECT_REF_COW_METHOD(SparseBlockNode);

python/tvm/script/context_maintainer.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,15 @@ class ContextMaintainer:
128128
"""List[Var]: The function parameters"""
129129
func_buffer_map: Mapping[Var, Buffer] = {}
130130
"""Mapping[Var, Buffer]: The function buffer map"""
131-
func_sparse_buffer_map: Mapping[Var, SparseBuffer] = {}
132-
"""Mapping[Var, SparseBuffer]: The function sparse buffer map"""
133131
func_dict_attr: Mapping[str, Object] = {}
134132
"""Mapping[str, Object]: The function attrs"""
135133
func_var_env_dict: Mapping[Var, str] = {}
136134
"""Mapping[Var, str]: The map from var to env thread"""
137135

136+
# sparse block context
137+
sp_struct2param_map: Mapping[Object, List[Var]] = {}
138+
"""Mapping[Object, List[Var]]: The mapping from sparse data structures to the func parameters"""
139+
138140
# parser and analyzer
139141
analyzer: tvm.arith.Analyzer = tvm.arith.Analyzer()
140142
"""tvm.arith.Analyzer: The analyzer for simplifying"""
@@ -154,9 +156,10 @@ def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], No
154156
# function context
155157
self.func_params = []
156158
self.func_buffer_map = {}
157-
self.func_sparse_buffer_map = {}
158159
self.func_dict_attr = {}
159160
self.func_var_env_dict = {}
161+
# sparse block context
162+
self.sp_struct2param_map = {}
160163
# parser and analyzer
161164
self._report_error = _report_error
162165
self.analyzer = tvm.arith.Analyzer()

python/tvm/script/tir/special_stmt.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -903,6 +903,7 @@ def __init__(self):
903903
def dense_fixed(name: str, length: PrimExpr, span: Optional[Span] = None):
904904
var_name = self.node.lhs[0].id.name
905905
axis = DenseFixedAxis(name, length)
906+
self.context.sp_struct2param_map[axis] = []
906907
self.context.update_symbol(var_name, axis, self.node)
907908

908909
super().__init__(dense_fixed, def_symbol=True)
@@ -926,7 +927,7 @@ def dense_variable(
926927
(indptr_len,), dtype=idtype, name=name + "_indptr", span=span
927928
)
928929
axis = DenseVariableAxis(name, length, indptr_buf)
929-
self.context.func_buffer_map[indptr_var] = indptr_buf
930+
self.context.sp_struct2param_map[axis] = indptr_var
930931
self.context.update_symbol(var_name, axis, self.node)
931932
self.context.update_symbol(name + "_indptr", indptr_buf, self.node)
932933

@@ -951,7 +952,7 @@ def sparse_fixed(
951952
(nnz,), dtype=idtype, name=name + "_indices", span=span
952953
)
953954
axis = SparseFixedAxis(name, length, indices_buf, nnz_cols)
954-
self.context.func_buffer_map[indices_var] = indices_buf
955+
self.context.sp_struct2param_map[axis] = [indices_var]
955956
self.context.update_symbol(var_name, axis, self.node)
956957
self.context.update_symbol(name + "_indices", indices_buf, self.node)
957958

@@ -980,8 +981,7 @@ def sparse_variable(
980981
(nnz,), dtype=idtype, name=name + "_indices", span=span
981982
)
982983
axis = SparseVariableAxis(name, length, indptr_buf, indices_buf)
983-
self.context.func_buffer_map[indices_var] = indices_buf
984-
self.context.func_buffer_map[indptr_var] = indptr_buf
984+
self.context.sp_struct2param_map[axis] = [indptr_var, indices_var]
985985
self.context.update_symbol(var_name, axis, self.node)
986986
self.context.update_symbol(name + "_indptr", indptr_buf, self.node)
987987
self.context.update_symbol(name + "_indices", indices_buf, self.node)
@@ -1017,8 +1017,7 @@ def match_sparse_buffer(
10171017
if param in self.context.func_params:
10181018
data = tvm.tir.decl_buffer(nnz, dtype, buffer_name + "_data", span=span)
10191019
buffer = tvm.tir.sparse.SparseBuffer(axes, data, buffer_name)
1020-
self.context.func_buffer_map[param] = data
1021-
self.context.func_sparse_buffer_map[param] = buffer
1020+
self.context.sp_struct2param_map[buffer] = [param]
10221021
self.context.update_symbol(buffer_name + "_data", data, self.node)
10231022
self.context.update_symbol(buffer_name, buffer, self.node)
10241023
else:

python/tvm/tir/stmt.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636
from . import _ffi_api
3737
from .buffer import Buffer
38-
from .expr import IterVar
38+
from .expr import Var, IterVar
3939
from .sparse import SpIterVar, SparseBuffer
4040

4141

@@ -624,8 +624,8 @@ class SparseBlock(Stmt):
624624
sp_iter_vars : List[SpIterVar]
625625
The sparse iteration variables of the block.
626626
627-
sp_buffers : List[SparseBuffer]
628-
The sparse buffers defined in the block.
627+
sp_struct2param_map : Mapping[Object, List[Var]]
628+
The mapping from sparse data structures to the PrimFunc parameters.
629629
630630
name : str
631631
The name of the block.
@@ -641,7 +641,7 @@ class SparseBlock(Stmt):
641641
"""
642642

643643
sp_iter_vars: List[SpIterVar]
644-
sp_buffers: List[SparseBuffer]
644+
sp_struct2param_map: Mapping[Object, List[Var]]
645645
name: str
646646
body: Stmt
647647
init: Optional[Stmt]
@@ -650,7 +650,7 @@ class SparseBlock(Stmt):
650650
def __init__(
651651
self,
652652
sp_iter_vars: List[SpIterVar],
653-
sp_buffers: List[SparseBuffer],
653+
sp_struct2param_map: Mapping[Object, List[Var]],
654654
name: str,
655655
body: Stmt,
656656
init: Optional[Stmt] = None,
@@ -659,7 +659,7 @@ def __init__(
659659
self.__init_handle_by_constructor__(
660660
_ffi_api.SparseBlock, # type: ignore
661661
sp_iter_vars,
662-
sp_buffers,
662+
sp_struct2param_map,
663663
name,
664664
body,
665665
init,

src/tir/ir/stmt.cc

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -975,11 +975,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
975975
p->stream << "}\n";
976976
});
977977

978-
SparseBlock::SparseBlock(Array<SpIterVar> sp_iter_vars, Array<SparseBuffer> sp_buffers, String name,
979-
Stmt body, Optional<Stmt> init, Span span) {
978+
SparseBlock::SparseBlock(Array<SpIterVar> sp_iter_vars,
979+
Map<ObjectRef, Array<Var>> sp_struct2param_map, String name, Stmt body,
980+
Optional<Stmt> init, Span span) {
980981
ObjectPtr<SparseBlockNode> node = make_object<SparseBlockNode>();
981982
node->sp_iter_vars = std::move(sp_iter_vars);
982-
node->sp_buffers = std::move(sp_buffers);
983+
node->sp_struct2param_map = std::move(sp_struct2param_map);
983984
node->name = std::move(name);
984985
node->body = std::move(body);
985986
node->init = std::move(init);
@@ -988,9 +989,10 @@ SparseBlock::SparseBlock(Array<SpIterVar> sp_iter_vars, Array<SparseBuffer> sp_b
988989
}
989990

990991
TVM_REGISTER_GLOBAL("tir.SparseBlock")
991-
.set_body_typed([](Array<SpIterVar> sp_iter_vars, Array<SparseBuffer> sp_buffers, String name,
992-
Stmt body, Optional<Stmt> init, Span span) {
993-
return SparseBlock(sp_iter_vars, sp_buffers, name, body, init, span);
992+
.set_body_typed([](Array<SpIterVar> sp_iter_vars,
993+
Map<ObjectRef, Array<Var>> sp_struct2param_map, String name, Stmt body,
994+
Optional<Stmt> init, Span span) {
995+
return SparseBlock(sp_iter_vars, sp_struct2param_map, name, body, init, span);
994996
});
995997

996998
TVM_REGISTER_NODE_TYPE(SparseBlockNode);

0 commit comments

Comments
 (0)