Skip to content

Commit ae9306b

Browse files
committed
Add docstring for sparse tir lowering (apache#21)
* add docstring * upd
1 parent 189a469 commit ae9306b

File tree

4 files changed

+234
-102
lines changed

4 files changed

+234
-102
lines changed

include/tvm/tir/stmt.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,7 +1288,7 @@ class SparseBlockNode : public StmtNode {
12881288
/*! \brief The sparse data structures */
12891289
Array<ObjectRef> sp_structs;
12901290
/*! \brief The mapping from sparse data structures to the PrimFunc parameters */
1291-
Map<ObjectRef, Array<Var>> sp_struct2param_map;
1291+
Map<ObjectRef, Array<Var>> sp_struct_param_map;
12921292
/*! \brief The name of the block */
12931293
String name;
12941294
/*! \brief The body of the block */
@@ -1299,7 +1299,7 @@ class SparseBlockNode : public StmtNode {
12991299
void VisitAttrs(AttrVisitor* v) {
13001300
v->Visit("sp_iter_vars", &sp_iter_vars);
13011301
v->Visit("sp_structs", &sp_structs);
1302-
v->Visit("sp_struct2param_map", &sp_struct2param_map);
1302+
v->Visit("sp_struct_param_map", &sp_struct_param_map);
13031303
v->Visit("name", &name);
13041304
v->Visit("body", &body);
13051305
v->Visit("init", &init);

src/printer/tvmscript_printer.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1278,7 +1278,7 @@ Doc TVMScriptPrinter::PrintSparseStructDefinitions(const SparseBlockNode* sp_blo
12781278
std::vector<Doc> sp_buf_docs;
12791279

12801280
for (const ObjectRef& obj : sp_block->sp_structs) {
1281-
Array<Var> params = sp_block->sp_struct2param_map.Get(obj).value();
1281+
Array<Var> params = sp_block->sp_struct_param_map.Get(obj).value();
12821282

12831283
Doc doc;
12841284
doc << Print(obj) << " = " << tir_prefix_ << ".";

src/tir/ir/stmt.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -974,7 +974,7 @@ SparseBlock::SparseBlock(Array<SpIterVar> sp_iter_vars, Array<ObjectRef> sp_stru
974974
CHECK_EQ(sp_structs.size(), sp_struct_params.size())
975975
<< "ValueError: The length of `sp_struct_params` is expected to be equal to the length "
976976
"`sp_structs`, which is the number of sparse data structures";
977-
Map<ObjectRef, Array<Var>> sp_struct2param_map;
977+
Map<ObjectRef, Array<Var>> sp_struct_param_map;
978978
for (int i = 0; i < static_cast<int>(sp_structs.size()); ++i) {
979979
ObjectRef obj = sp_structs[i];
980980
Array<Var> params = sp_struct_params[i];
@@ -998,13 +998,13 @@ SparseBlock::SparseBlock(Array<SpIterVar> sp_iter_vars, Array<ObjectRef> sp_stru
998998
LOG(FATAL) << "ValueError: " << obj->_type_key << " is not a sparse data structure";
999999
}
10001000

1001-
sp_struct2param_map.Set(obj, params);
1001+
sp_struct_param_map.Set(obj, params);
10021002
}
10031003

10041004
ObjectPtr<SparseBlockNode> node = make_object<SparseBlockNode>();
10051005
node->sp_iter_vars = std::move(sp_iter_vars);
10061006
node->sp_structs = std::move(sp_structs);
1007-
node->sp_struct2param_map = std::move(sp_struct2param_map);
1007+
node->sp_struct_param_map = std::move(sp_struct_param_map);
10081008
node->name = std::move(name);
10091009
node->body = std::move(body);
10101010
node->init = std::move(init);

0 commit comments

Comments
 (0)