Skip to content

Commit cf36aa6

Browse files
authored
[TIR] Add TIR While node (#7425)
* add while node * update visitors * binary search lowering works * llvm codegen working * cuda codegen working * nms updated to use while loop * add missing upper bound check too * add mandelbrot test * add gpu mandel commit ee2363b Author: Masahiro Masuda <masahi129@gmail.com> Date: Fri Jan 29 11:44:02 2021 +0900 enable extern lib offload for nvptx * rename test * run black * add doc * add collatz test * add while + vectorize test * simplify bin search * Add special case visit method to storage_access.cc * disallow while loop inside vectorized loop * disallow trivial condition since we do not have break * error out in CoprocSync for now * error out LiftAttrScope for now * add placeholder to inject_vpthread * refactor to use MakeAttach * handle WhileNode in InplaceOpVerifier * error out in InjectVirtualThread * try handle WhileNode in StoragePlanRewriter * remove WhileNode visitor from storage rewrite * add while loop storage rewrite test * update tests * move test_vectorize_while_fail to test_tir_transform_vectorize.py
1 parent 3a02e0b commit cf36aa6

23 files changed

+695
-23
lines changed

include/tvm/tir/stmt.h

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -861,6 +861,53 @@ class For : public Stmt {
861861
TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode);
862862
};
863863

864+
/*!
865+
* \brief A While loop
866+
*
867+
* \code
868+
*
869+
* while (condition)
870+
* body
871+
*
872+
* \endcode
873+
*/
874+
class WhileNode : public StmtNode {
875+
public:
876+
/*! \brief The loop condition; the body is executed repeatedly while it evaluates to true. */
877+
PrimExpr condition;
878+
/*! \brief The body of the while loop. */
879+
Stmt body;
880+
881+
void VisitAttrs(AttrVisitor* v) {
882+
v->Visit("condition", &condition);
883+
v->Visit("body", &body);
884+
v->Visit("span", &span);
885+
}
886+
887+
bool SEqualReduce(const WhileNode* other, SEqualReducer equal) const {
888+
return equal.DefEqual(condition, other->condition) && equal.DefEqual(body, other->body);
889+
}
890+
891+
void SHashReduce(SHashReducer hash_reduce) const {
892+
hash_reduce.DefHash(condition);
893+
hash_reduce.DefHash(body);
894+
}
895+
896+
static constexpr const char* _type_key = "tir.While";
897+
TVM_DECLARE_FINAL_OBJECT_INFO(WhileNode, StmtNode);
898+
};
899+
900+
/*!
901+
* \brief Managed reference to WhileNode.
902+
* \sa WhileNode
903+
*/
904+
class While : public Stmt {
905+
public:
906+
TVM_DLL While(PrimExpr condition, Stmt body, Span span = Span());
907+
908+
TVM_DEFINE_OBJECT_REF_METHODS(While, Stmt, WhileNode);
909+
};
910+
864911
/*!
865912
* \brief A prefetch hint for a buffer
866913
*/

include/tvm/tir/stmt_functor.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
8686
virtual R VisitStmt_(const AttrStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
8787
virtual R VisitStmt_(const IfThenElseNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
8888
virtual R VisitStmt_(const ForNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
89+
virtual R VisitStmt_(const WhileNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
8990
virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
9091
virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
9192
virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
@@ -111,6 +112,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
111112
IR_STMT_FUNCTOR_DISPATCH(AttrStmtNode);
112113
IR_STMT_FUNCTOR_DISPATCH(IfThenElseNode);
113114
IR_STMT_FUNCTOR_DISPATCH(ForNode);
115+
IR_STMT_FUNCTOR_DISPATCH(WhileNode);
114116
IR_STMT_FUNCTOR_DISPATCH(AllocateNode);
115117
IR_STMT_FUNCTOR_DISPATCH(StoreNode);
116118
IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode);
@@ -152,6 +154,7 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
152154
void VisitStmt_(const IfThenElseNode* op) override;
153155
void VisitStmt_(const LetStmtNode* op) override;
154156
void VisitStmt_(const ForNode* op) override;
157+
void VisitStmt_(const WhileNode* op) override;
155158
void VisitStmt_(const AllocateNode* op) override;
156159
void VisitStmt_(const StoreNode* op) override;
157160
void VisitStmt_(const BufferStoreNode* op) override;
@@ -245,6 +248,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
245248
Stmt VisitStmt_(const IfThenElseNode* op) override;
246249
Stmt VisitStmt_(const LetStmtNode* op) override;
247250
Stmt VisitStmt_(const ForNode* op) override;
251+
Stmt VisitStmt_(const WhileNode* op) override;
248252
Stmt VisitStmt_(const AllocateNode* op) override;
249253
Stmt VisitStmt_(const StoreNode* op) override;
250254
Stmt VisitStmt_(const BufferStoreNode* op) override;

python/tvm/tir/ir_builder.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,35 @@ def _exit_cb():
263263

264264
return WithScope(loop_var, _exit_cb)
265265

266+
def while_loop(self, condition):
267+
"""Create a while loop scope.
268+
269+
Parameters
270+
----------
271+
condition : Expr
272+
The loop condition; the body is executed repeatedly while it holds.
273+
274+
Returns
275+
-------
276+
loop_scope : WithScope
277+
The while scope.
278+
279+
Examples
280+
--------
281+
.. code-block:: python
282+
283+
ib = tvm.tir.ir_builder.create()
284+
iterations = ib.allocate("int32", (1,), name="iterations", scope="local")
285+
with ib.while_loop(iterations[0] < 10):
286+
iterations[0] += 1
287+
"""
288+
self._seq_stack.append([])
289+
290+
def _exit_cb():
291+
self.emit(_stmt.While(condition, self._pop_seq()))
292+
293+
return WithScope(None, _exit_cb)
294+
266295
def if_scope(self, cond):
267296
"""Create an if scope.
268297

python/tvm/tir/stmt.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,31 @@ def __init__(
159159
)
160160

161161

162+
@tvm._ffi.register_object("tir.While")
163+
class While(Stmt):
164+
"""While node.
165+
166+
Parameters
167+
----------
168+
condition : PrimExpr
169+
The loop condition; the body is executed repeatedly while it holds.
170+
171+
body : Stmt
172+
The body statement.
173+
174+
span : Optional[Span]
175+
The location of this statement in the source code.
176+
"""
177+
178+
def __init__(self, condition, body, span=None):
179+
self.__init_handle_by_constructor__(
180+
_ffi_api.While,
181+
condition,
182+
body,
183+
span,
184+
)
185+
186+
162187
@tvm._ffi.register_object("tir.Store")
163188
class Store(Stmt):
164189
"""Store node.

python/tvm/topi/cuda/nms.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ def nms_inner_loop(ib, j):
521521
offset_j = j * 4
522522
num_iter_per_thread = ceil_div(nkeep - (j + 1), nthread_tx)
523523

524-
with ib.for_range(0, num_iter_per_thread) as _k:
524+
with ib.for_range(0, num_iter_per_thread, name="_k") as _k:
525525
k = j + 1 + _k * nthread_tx + tx
526526
offset_k = k * 4
527527

@@ -555,16 +555,22 @@ def nms_inner_loop(ib, j):
555555

556556
with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
557557
# Apply nms
558-
with ib.for_range(0, nkeep) as j:
559-
# Proceed to the inner loop if the box j is still valid
560-
with ib.if_scope(out_scores[i, j] > -1.0):
561-
with ib.if_scope(max_output_size > 0):
562-
# No need to do more iteration if we have already reached max_output_size
563-
# boxes
564-
# TODO(masahi): Add TIR while loop to realize early exit from the outer loop
565-
with ib.if_scope(num_valid_boxes_local[0] < max_output_size):
566-
nms_inner_loop(ib, j)
567-
with ib.else_scope():
558+
with ib.if_scope(max_output_size > 0):
559+
# No need to do more iteration if we have already reached max_output_size boxes
560+
box_idx = ib.allocate("int32", (1,), name="box_idx", scope="local")
561+
box_idx[0] = 0
562+
with ib.while_loop(
563+
tvm.tir.all(box_idx[0] < nkeep, num_valid_boxes_local[0] < max_output_size)
564+
):
565+
# Proceed to the inner loop if the box with id box_idx is still valid
566+
with ib.if_scope(out_scores[i, box_idx[0]] > -1.0):
567+
nms_inner_loop(ib, box_idx[0])
568+
box_idx[0] += 1
569+
570+
with ib.else_scope():
571+
with ib.for_range(0, nkeep, name="j") as j:
572+
# Proceed to the inner loop if the box j is still valid
573+
with ib.if_scope(out_scores[i, j] > -1.0):
568574
nms_inner_loop(ib, j)
569575

570576
with ib.if_scope(tx + 0 == 0):

src/printer/text_printer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
308308
Doc VisitStmt_(const SeqStmtNode* op) override;
309309
Doc VisitStmt_(const EvaluateNode* op) override;
310310
Doc VisitStmt_(const ForNode* op) override;
311+
Doc VisitStmt_(const WhileNode* op) override;
311312
Doc VisitStmt_(const PrefetchNode* op) override;
312313
Doc VisitStmtDefault_(const Object* op) override;
313314

src/printer/tir_text_printer.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,13 @@ Doc TIRTextPrinter::VisitStmt_(const ForNode* op) {
494494
return doc;
495495
}
496496

497+
Doc TIRTextPrinter::VisitStmt_(const WhileNode* op) {
498+
Doc doc;
499+
doc << "while (" << Print(op->condition) << ")";
500+
doc << PrintBody(op->body);
501+
return doc;
502+
}
503+
497504
Doc TIRTextPrinter::VisitStmt_(const PrefetchNode* op) {
498505
Doc doc;
499506
doc << "prefetch(" << Print(op->buffer) << ", " << Print(op->bounds) << ")";

src/target/llvm/codegen_llvm.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1328,6 +1328,20 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) {
13281328
llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), op->loop_var, op->body);
13291329
}
13301330

1331+
void CodeGenLLVM::VisitStmt_(const WhileNode* op) {
1332+
using llvm::BasicBlock;
1333+
BasicBlock* while_cond = BasicBlock::Create(*ctx_, "while_cond", function_);
1334+
BasicBlock* while_body = BasicBlock::Create(*ctx_, "while_body", function_);
1335+
BasicBlock* while_merge = BasicBlock::Create(*ctx_, "while_merge", function_);
1336+
builder_->CreateBr(while_cond);
1337+
builder_->SetInsertPoint(while_cond);
1338+
builder_->CreateCondBr(MakeValue(op->condition), while_body, while_merge);
1339+
builder_->SetInsertPoint(while_body);
1340+
this->VisitStmt(op->body);
1341+
builder_->CreateBr(while_cond);
1342+
builder_->SetInsertPoint(while_merge);
1343+
}
1344+
13311345
void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) {
13321346
using llvm::BasicBlock;
13331347
llvm::Value* cond = MakeValue(op->condition);

src/target/llvm/codegen_llvm.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
152152
// stmt
153153
void VisitStmt_(const StoreNode* op) override;
154154
void VisitStmt_(const ForNode* op) override;
155+
void VisitStmt_(const WhileNode* op) override;
155156
void VisitStmt_(const IfThenElseNode* op) override;
156157
void VisitStmt_(const AllocateNode* op) override;
157158
void VisitStmt_(const AttrStmtNode* op) override;

src/target/source/codegen_c.cc

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -728,7 +728,6 @@ void CodeGenC::VisitStmt_(const StoreNode* op) {
728728
ICHECK(is_one(op->predicate)) << "Predicated store is not supported";
729729
arith::PVar<PrimExpr> base;
730730

731-
732731
if (arith::ramp(base, 1, t.lanes()).Match(op->index)) {
733732
std::string value = this->PrintExpr(op->value);
734733
this->PrintVecStore(op->buffer_var.get(), t, base.Eval(), value);
@@ -899,6 +898,16 @@ void CodeGenC::VisitStmt_(const ForNode* op) {
899898
stream << "}\n";
900899
}
901900

901+
void CodeGenC::VisitStmt_(const WhileNode* op) {
902+
PrintIndent();
903+
stream << "while (" << PrintExpr(op->condition) << ") {\n";
904+
int while_scope = BeginScope();
905+
PrintStmt(op->body);
906+
this->EndScope(while_scope);
907+
PrintIndent();
908+
stream << "}\n";
909+
}
910+
902911
void CodeGenC::VisitStmt_(const IfThenElseNode* op) {
903912
std::string cond = PrintExpr(op->condition);
904913
PrintIndent();

0 commit comments

Comments
 (0)