Skip to content

Commit 7a260cc

Browse files
yzh119MasterJH5574
authored andcommitted
Add docstring for sparse tir lowering (#21)
* add docstring * upd
1 parent 241af8c commit 7a260cc

File tree

4 files changed

+234
-102
lines changed

4 files changed

+234
-102
lines changed

include/tvm/tir/stmt.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,7 +1288,7 @@ class SparseBlockNode : public StmtNode {
12881288
/*! \brief The sparse data structures */
12891289
Array<ObjectRef> sp_structs;
12901290
/*! \brief The mapping from sparse data structures to the PrimFunc parameters */
1291-
Map<ObjectRef, Array<Var>> sp_struct2param_map;
1291+
Map<ObjectRef, Array<Var>> sp_struct_param_map;
12921292
/*! \brief The name of the block */
12931293
String name;
12941294
/*! \brief The body of the block */
@@ -1299,7 +1299,7 @@ class SparseBlockNode : public StmtNode {
12991299
void VisitAttrs(AttrVisitor* v) {
13001300
v->Visit("sp_iter_vars", &sp_iter_vars);
13011301
v->Visit("sp_structs", &sp_structs);
1302-
v->Visit("sp_struct2param_map", &sp_struct2param_map);
1302+
v->Visit("sp_struct_param_map", &sp_struct_param_map);
13031303
v->Visit("name", &name);
13041304
v->Visit("body", &body);
13051305
v->Visit("init", &init);

src/printer/tvmscript_printer.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1359,7 +1359,7 @@ Doc TVMScriptPrinter::PrintSparseStructDefinitions(const SparseBlockNode* sp_blo
13591359
std::vector<Doc> sp_buf_docs;
13601360

13611361
for (const ObjectRef& obj : sp_block->sp_structs) {
1362-
Array<Var> params = sp_block->sp_struct2param_map.Get(obj).value();
1362+
Array<Var> params = sp_block->sp_struct_param_map.Get(obj).value();
13631363

13641364
Doc doc;
13651365
doc << Print(obj) << " = " << tir_prefix_ << ".";

src/tir/ir/stmt.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -981,7 +981,7 @@ SparseBlock::SparseBlock(Array<SpIterVar> sp_iter_vars, Array<ObjectRef> sp_stru
981981
CHECK_EQ(sp_structs.size(), sp_struct_params.size())
982982
<< "ValueError: The length of `sp_struct_params` is expected to be equal to the length "
983983
"`sp_structs`, which is the number of sparse data structures";
984-
Map<ObjectRef, Array<Var>> sp_struct2param_map;
984+
Map<ObjectRef, Array<Var>> sp_struct_param_map;
985985
for (int i = 0; i < static_cast<int>(sp_structs.size()); ++i) {
986986
ObjectRef obj = sp_structs[i];
987987
Array<Var> params = sp_struct_params[i];
@@ -1005,13 +1005,13 @@ SparseBlock::SparseBlock(Array<SpIterVar> sp_iter_vars, Array<ObjectRef> sp_stru
10051005
LOG(FATAL) << "ValueError: " << obj->_type_key << " is not a sparse data structure";
10061006
}
10071007

1008-
sp_struct2param_map.Set(obj, params);
1008+
sp_struct_param_map.Set(obj, params);
10091009
}
10101010

10111011
ObjectPtr<SparseBlockNode> node = make_object<SparseBlockNode>();
10121012
node->sp_iter_vars = std::move(sp_iter_vars);
10131013
node->sp_structs = std::move(sp_structs);
1014-
node->sp_struct2param_map = std::move(sp_struct2param_map);
1014+
node->sp_struct_param_map = std::move(sp_struct_param_map);
10151015
node->name = std::move(name);
10161016
node->body = std::move(body);
10171017
node->init = std::move(init);

0 commit comments

Comments
 (0)