Skip to content

Commit 47d90a6

Browse files
committed
Axis Dependency Tree aware code-gen and bmm example (#28)
* upd * upd * upd * upd * upd * upd * upd * upd * remove redundancy * fix * upd * upd
1 parent 0851d74 commit 47d90a6

File tree

7 files changed

+373
-186
lines changed

7 files changed

+373
-186
lines changed

include/tvm/tir/sparse.h

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ class AxisNode : public Object {
4747
String GetName() const { return name; }
4848
PrimExpr GetLength() const { return length; }
4949
DataType GetIndexType() const { return length->dtype; }
50+
51+
virtual bool is_fixed() const = 0;
5052

5153
static constexpr const char* _type_key = "tir.sparse.Axis";
5254
static constexpr const bool _type_has_method_sequal_reduce = true;
@@ -141,6 +143,10 @@ class DenseFixedAxisNode : public DenseAxisNode {
141143
hash_reduce(from_sparse);
142144
}
143145

146+
bool is_fixed() const {
147+
return true;
148+
}
149+
144150
static constexpr const char* _type_key = "tir.sparse.DenseFixedAxis";
145151
TVM_DECLARE_FINAL_OBJECT_INFO(DenseFixedAxisNode, DenseAxisNode);
146152
};
@@ -177,6 +183,10 @@ class DenseVariableAxisNode : public DenseAxisNode {
177183
hash_reduce(indptr);
178184
}
179185

186+
bool is_fixed() const {
187+
return false;
188+
}
189+
180190
static constexpr const char* _type_key = "tir.sparse.DenseVariableAxis";
181191
TVM_DECLARE_FINAL_OBJECT_INFO(DenseVariableAxisNode, DenseAxisNode);
182192
};
@@ -220,6 +230,10 @@ class SparseFixedAxisNode : public SparseAxisNode {
220230
hash_reduce(nnz_cols);
221231
}
222232

233+
bool is_fixed() const {
234+
return true;
235+
}
236+
223237
static constexpr const char* _type_key = "tir.sparse.SparseFixedAxis";
224238
TVM_DECLARE_FINAL_OBJECT_INFO(SparseFixedAxisNode, SparseAxisNode);
225239
};
@@ -262,6 +276,10 @@ class SparseVariableAxisNode : public SparseAxisNode {
262276
hash_reduce(indices);
263277
}
264278

279+
bool is_fixed() const {
280+
return false;
281+
}
282+
265283
static constexpr const char* _type_key = "tir.sparse.SparseVariableAxis";
266284
TVM_DECLARE_FINAL_OBJECT_INFO(SparseVariableAxisNode, SparseAxisNode);
267285
};
@@ -283,9 +301,9 @@ class SparseVariableAxis : public SparseAxis {
283301
class AxisTreeNode : public Object {
284302
public:
285303
// unordered map that stores the parent relationship between axes.
286-
Map<String, Optional<String>> parent;
304+
Map<String, String> parent;
287305
// unordered map that stores the children relationship between axes.
288-
Map<Optional<String>, Array<String>> children;
306+
Map<String, Array<String>> children;
289307

290308
void VisitAttrs(AttrVisitor* v) {
291309
v->Visit("parent", &parent);

include/tvm/tir/transform.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,9 +494,10 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner();
494494

495495
/*!
496496
* \brief Lower SparseTIR to TIR.
497+
* \param axis_tree The axis dependency tree.
497498
* \return The pass.
498499
*/
499-
TVM_DLL Pass LowerSparseTIR();
500+
TVM_DLL Pass LowerSparseTIR(AxisTree axis_tree);
500501

501502
} // namespace transform
502503
} // namespace tir

python/tvm/tir/transform/transform.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from typing import Optional
2020
from . import _ffi_api
2121
from . import function_pass as _fpass
22+
from ..sparse import AxisTree
2223

2324

2425
def Apply(ftransform):
@@ -751,12 +752,17 @@ def ConvertForLoopsToSerial():
751752
return _ffi_api.ConvertForLoopsToSerial() # type: ignore
752753

753754

754-
def LowerSparseTIR():
755+
def LowerSparseTIR(axis_tree: AxisTree):
755756
"""Lower SparseTIR to TIR
756757
758+
Parameters
759+
----------
760+
axis_tree : AxisTree
761+
The axis dependency tree.
762+
757763
Returns
758764
-------
759765
fpass : tvm.transform.Pass
760766
The result pass
761767
"""
762-
return _ffi_api.LowerSparseTIR() # type: ignore
768+
return _ffi_api.LowerSparseTIR(axis_tree) # type: ignore

src/tir/ir/sparse.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,12 +146,15 @@ AxisTree::AxisTree(Array<String> axis_names, Array<Optional<String>> axis_parent
146146
"axis_parent_names "
147147
"array.";
148148
ObjectPtr<AxisTreeNode> node = make_object<AxisTreeNode>();
149-
Map<String, Optional<String>> parent;
150-
Map<Optional<String>, Array<String>> children;
149+
Map<String, String> parent;
150+
Map<String, Array<String>> children;
151151
for (size_t i = 0; i < axis_names.size(); i++) {
152152
// update parent map & children map
153153
String axis_name = axis_names[i];
154-
Optional<String> parent_name = axis_parent_names[i];
154+
String parent_name("root");
155+
if (axis_parent_names[i].defined()) {
156+
parent_name = axis_parent_names[i].value();
157+
}
155158
parent.Set(axis_name, parent_name);
156159

157160
auto it = children.find(parent_name);

src/tir/transforms/lower_sparse_tir.cc

Lines changed: 118 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
#include <tvm/tir/stmt_functor.h>
2727
#include <tvm/tir/transform.h>
2828

29+
#include <set>
30+
#include <stack>
2931
#include <utility>
3032

3133
#include "../schedule/analysis.h"
@@ -87,8 +89,8 @@ Map<Var, Buffer> UpdateBufferMap(PrimFunc f) {
8789
*/
8890
class IndexTransformer : public StmtExprMutator {
8991
public:
90-
explicit IndexTransformer(AccessAndDependencyCollector collector)
91-
: collector_(std::move(collector)) {}
92+
explicit IndexTransformer(AccessAndDependencyCollector collector, AxisTree axis_tree)
93+
: collector_(std::move(collector)), axis_tree_(std::move(axis_tree)) {}
9294

9395
private:
9496
/*!
@@ -281,43 +283,124 @@ class IndexTransformer : public StmtExprMutator {
281283
sp_block->init.defined() ? VisitStmt(sp_block->init.value()) : Optional<Stmt>(NullOpt);
282284
Stmt body = VisitStmt(sp_block->body);
283285

284-
// Step 2. Create the new outer loop vars.
285-
Array<Var> loop_vars;
286+
// Step 2. Create the new loop vars.
286287
std::unordered_map<const VarNode*, PrimExpr> var_map;
287-
loop_vars.reserve(n_iter);
288+
Array<Var> all_loop_vars;
288289
var_map.reserve(n_iter);
289290
for (const SpIterVar& sp_iter : sp_block->sp_iter_vars) {
290291
Var loop_var("v_" + sp_iter->var->name_hint);
291-
loop_vars.push_back(loop_var);
292+
all_loop_vars.push_back(loop_var);
292293
var_map[sp_iter->var.get()] = loop_var;
293294
}
294295

295-
// Step 3. Create block iters and iter bindings.
296-
Array<IterVar> block_iters;
297-
Array<PrimExpr> iter_bindings;
298-
block_iters.reserve(n_iter);
299-
iter_bindings.reserve(n_iter);
300-
for (int i = 0; i < n_iter; ++i) {
301-
block_iters.push_back(SpIterVarToIterVar(sp_block->sp_iter_vars[i], var_map));
302-
iter_bindings.push_back(loop_vars[i]);
303-
}
296+
// Step 3. Collect block iters and iter bindings.
297+
std::set<String> in_stack;
298+
in_stack.insert("root");
299+
/* A stack that stores block itervars in each block. */
300+
std::stack<Array<IterVar>> block_iters_st;
301+
/* A stack that stores itervar bindings in each block. */
302+
std::stack<Array<PrimExpr>> iter_bindings_st;
303+
/* A stack that stores generated loop vars in each block. */
304+
std::stack<Array<Var>> loop_vars_st;
305+
/* A stack that stores whether to place init block in each block. */
306+
std::stack<bool> place_init_st;
307+
/* An indicator that records whether init block has been set. */
308+
bool init_set = false;
309+
do {
310+
/* Block itervars of current block. */
311+
Array<IterVar> block_iters;
312+
/* Itervar bindings of current block. */
313+
Array<PrimExpr> iter_bindings;
314+
/* Axis names of current block. */
315+
Array<String> axis_names;
316+
/* Generated loop vars of current block. */
317+
Array<Var> loop_vars;
318+
/* An indicator that records whether there is a reduction axis in the current block. */
319+
bool has_reduction_var = false;
320+
for (int i = 0; i < n_iter; ++i) {
321+
SpIterVar sp_it_var = sp_block->sp_iter_vars[i];
322+
String axis_name = sp_it_var->axis->name;
323+
auto&& parent_axis = axis_tree_->parent.Get(axis_name);
324+
CHECK(parent_axis.defined()) << "Sparse IterVar not defined in Axis Tree.";
325+
String parent_axis_name = parent_axis.value();
326+
bool is_fixed_axis = sp_it_var->axis->is_fixed();
327+
/* Add itervar to current block when
328+
* - it's not used yet (not in stack) and
329+
* - its parent axis was used in outer blocks or
330+
* - it's an iterator to a fixed axis.
331+
*/
332+
if ((is_fixed_axis || in_stack.find(parent_axis_name) != in_stack.end()) &&
333+
in_stack.find(axis_name) == in_stack.end()) {
334+
loop_vars.push_back(all_loop_vars[i]);
335+
axis_names.push_back(std::move(axis_name));
336+
block_iters.push_back(SpIterVarToIterVar(sp_it_var, var_map));
337+
iter_bindings.push_back(all_loop_vars[i]);
338+
has_reduction_var |= sp_it_var->is_reduction;
339+
}
340+
}
341+
342+
/* Tag axes in current block as "in-stack". */
343+
for (const String&& axis_name : axis_names) {
344+
in_stack.insert(std::move(axis_name));
345+
}
346+
347+
/* Update stack. */
348+
if (!block_iters.empty()) {
349+
block_iters_st.push(std::move(block_iters));
350+
iter_bindings_st.push(std::move(iter_bindings));
351+
loop_vars_st.push(std::move(loop_vars));
352+
if (init_set) {
353+
place_init_st.push(false);
354+
} else {
355+
place_init_st.push(has_reduction_var);
356+
init_set |= has_reduction_var;
357+
}
358+
} else {
359+
break;
360+
}
361+
} while (true);
304362

305363
// Step 4. Generate the read-region and write-region of the block.
306364
Array<BufferRegion> reads{nullptr};
307365
Array<BufferRegion> writes{nullptr};
308366
GenerateReadWriteRegions(sp_block, &reads, &writes);
309367

310-
// Step 5. Create the block and block-realize
311-
Map<String, ObjectRef> mapping;
312-
mapping.Set("sparse", Bool(true));
313-
Block block(block_iters, std::move(reads), std::move(writes), sp_block->name, std::move(body),
314-
std::move(init), {}, {}, std::move(mapping));
315-
BlockRealize block_realize(std::move(iter_bindings), const_true(), std::move(block));
316-
317-
// Step 6. Create outer loops and the block binding.
318-
Stmt loop = GenerateLoops(std::move(block_realize), block_iters, loop_vars);
368+
// Step 5. Generate nested blocks and loops from innermost to outermost.
369+
int blk_counter = 0;
370+
while (!block_iters_st.empty()) {
371+
Array<IterVar> block_iters = std::move(block_iters_st.top());
372+
Array<PrimExpr> iter_bindings = std::move(iter_bindings_st.top());
373+
Array<Var> loop_vars = std::move(loop_vars_st.top());
374+
bool place_init = place_init_st.top();
375+
block_iters_st.pop();
376+
iter_bindings_st.pop();
377+
loop_vars_st.pop();
378+
place_init_st.pop();
379+
380+
Map<String, ObjectRef> mapping;
381+
mapping.Set("sparse", Bool(true));
382+
String blk_name_hint = sp_block->name;
383+
if (blk_counter != 0) {
384+
blk_name_hint = blk_name_hint + "_" + std::to_string(blk_counter);
385+
}
386+
Block block(/*iter_vars=*/block_iters,
387+
/*reads=*/reads,
388+
/*writes=*/writes,
389+
/*name_hint=*/blk_name_hint,
390+
/*body=*/std::move(body),
391+
/*init=*/place_init ? std::move(init) : NullOpt,
392+
/*alloc_buffers=*/{},
393+
/*match_buffers=*/{},
394+
/*annotations=*/std::move(mapping),
395+
/*span=*/sp_block->span);
396+
BlockRealize block_realize(std::move(iter_bindings), const_true(), std::move(block));
397+
// Generate outer loop and the block binding.
398+
Stmt loop = GenerateLoops(std::move(block_realize), block_iters, loop_vars);
399+
body = loop;
400+
blk_counter += 1;
401+
}
319402

320-
return loop;
403+
return body;
321404
}
322405

323406
/*!
@@ -380,9 +463,10 @@ class IndexTransformer : public StmtExprMutator {
380463
}
381464

382465
/*!
383-
* \brief generated nested for loops for sparse block.
466+
* \brief Generate nested for-loops for a sparse block.
384467
* \param block_iters The iterators defined in sparse blocks.
385468
* \param loop_vars The loop variables bound to the block iterators.
469+
* \return The outermost loop.
386470
*/
387471
Stmt GenerateLoops(Stmt body, const Array<IterVar>& block_iters, const Array<Var>& loop_vars) {
388472
int n_iter = static_cast<int>(block_iters.size());
@@ -394,6 +478,7 @@ class IndexTransformer : public StmtExprMutator {
394478
}
395479

396480
AccessAndDependencyCollector collector_;
481+
AxisTree axis_tree_;
397482
arith::Analyzer ana_;
398483
std::unordered_set<const SparseBufferNode*> buffer_read_;
399484
std::unordered_set<const SparseBufferNode*> buffer_write_;
@@ -411,11 +496,12 @@ Stmt WrapWithRootBlock(Stmt body) {
411496
}
412497

413498
/*!
414-
* \brief Rewrite the given primitive function
499+
* \brief Rewrite the given primitive function.
500+
* \param axis_tree The axis dependency tree.
415501
* \param f The Sparse-TIR primitive function to lower.
416502
* \return lowered primitive function in TIR.
417503
*/
418-
PrimFunc LowerSparseTIR(PrimFunc f) {
504+
PrimFunc LowerSparseTIR(AxisTree axis_tree, PrimFunc f) {
419505
// Only apply this pass to TIR that is not from TE schedules
420506
if (!IsFromLegacyTESchedule(f)) {
421507
PrimFuncNode* fptr = f.CopyOnWrite();
@@ -425,7 +511,7 @@ PrimFunc LowerSparseTIR(PrimFunc f) {
425511
AccessAndDependencyCollector collector;
426512
collector.Collect(f->body);
427513
// Step 3. Lower indices.
428-
fptr->body = IndexTransformer(collector)(std::move(f->body));
514+
fptr->body = IndexTransformer(collector, axis_tree)(std::move(f->body));
429515
// Step 4. Wrap the function body with a root block.
430516
fptr->body = WrapWithRootBlock(std::move(fptr->body));
431517
return f;
@@ -438,10 +524,11 @@ namespace transform {
438524

439525
/*!
440526
* \brief The lowering pass from Sparse TIR to TIR.
527+
* \param axis_tree The axis dependency tree.
441528
*/
442-
Pass LowerSparseTIR() {
529+
Pass LowerSparseTIR(AxisTree axis_tree) {
443530
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
444-
return LowerSparseTIR(std::move(f));
531+
return LowerSparseTIR(std::move(axis_tree), std::move(f));
445532
};
446533
return CreatePrimFuncPass(pass_func, 0, "tir.LowerSparseTIR", {});
447534
}

0 commit comments

Comments
 (0)