Skip to content

Commit 8260e95

Browse files
wrongtest-intellifylc
authored andcommitted
[TIR] add loop partition hint pragma (apache#9121)
* add loop partition hint pragma * fix unintialized var * fix to remove hint at last * use tir compare for loop partition testcase
1 parent bb37a8c commit 8260e95

File tree

3 files changed

+116
-27
lines changed

3 files changed

+116
-27
lines changed

include/tvm/tir/stmt.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1339,6 +1339,12 @@ constexpr const char* hand_threaded = "hand_threaded";
13391339
* if (mask & 2) the write region should be detected.
13401340
*/
13411341
constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_access";
1342+
1343+
/*!
1344+
* \brief Mark that the loop should be partitioned.
1345+
*/
1346+
constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";
1347+
13421348
/*!
13431349
* \brief Check if attr_key is a pragma key extension
13441350
* \param attr_key The attr key to be compared

src/tir/transforms/loop_partition.cc

Lines changed: 80 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,13 @@ class CandidateSelector final : public StmtExprVisitor {
9898
void VisitStmt_(const ForNode* op) final {
9999
// partition const loop when sets partition_const_loop_
100100
if (!is_const_int(op->min) || !is_const_int(op->extent) || partition_const_loop_) {
101+
// always treat var with hint to be partitioned
101102
const VarNode* var = op->loop_var.get();
103+
if (partition_hint_vars.count(var)) {
104+
candidates.insert(GetRef<Stmt>(op));
105+
StmtExprVisitor::VisitStmt_(op);
106+
return;
107+
}
102108
record_.insert({var, false});
103109
StmtExprVisitor::VisitStmt_(op);
104110
if (record_.at(var) && !no_split_) {
@@ -117,6 +123,12 @@ class CandidateSelector final : public StmtExprVisitor {
117123
Var var = iv->var;
118124
runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag);
119125
if ((scope.rank == 0) && (!is_const_int(op->value) || partition_const_loop_)) {
126+
// always treat var with hint to be partitioned
127+
if (partition_hint_vars.count(var.get())) {
128+
candidates.insert(GetRef<Stmt>(op));
129+
StmtExprVisitor::VisitStmt_(op);
130+
return;
131+
}
120132
record_.insert({var.get(), false});
121133
StmtExprVisitor::VisitStmt_(op);
122134
if (record_.at(var.get()) && !no_split_) {
@@ -125,6 +137,15 @@ class CandidateSelector final : public StmtExprVisitor {
125137
record_.erase(var.get());
126138
return;
127139
}
140+
} else if (op->attr_key == attr::pragma_loop_partition_hint) {
141+
const VarNode* var = nullptr;
142+
if (op->node->IsInstance<VarNode>()) {
143+
var = op->node.as<VarNode>();
144+
} else if (op->node->IsInstance<IterVarNode>()) {
145+
var = op->node.as<IterVarNode>()->var.get();
146+
}
147+
ICHECK(var);
148+
partition_hint_vars.insert(var);
128149
}
129150
StmtExprVisitor::VisitStmt_(op);
130151
}
@@ -162,6 +183,7 @@ class CandidateSelector final : public StmtExprVisitor {
162183
}
163184

164185
std::unordered_set<Stmt, ObjectPtrHash, ObjectPtrEqual> candidates;
186+
std::unordered_set<const VarNode*> partition_hint_vars;
165187

166188
private:
167189
bool in_likely_{false};
@@ -170,15 +192,28 @@ class CandidateSelector final : public StmtExprVisitor {
170192
std::unordered_map<const VarNode*, VarIsUsed> record_;
171193
};
172194

195+
// Finder try best to find partitions for hinted vars
196+
#define DEFINE_PARTITION_FINDER_VISIT_CMP_OP(OpNodeT) \
197+
void VisitExpr_(const OpNodeT* op) final { \
198+
if (has_partition_hint_) { \
199+
DeduceCondition(GetRef<PrimExpr>(op)); \
200+
return; \
201+
} \
202+
StmtExprVisitor::VisitExpr_(op); \
203+
}
204+
173205
// Populate partitions data structure, i.e., for a specific variable,
174-
// find an interval in which each condition
175-
// (currently, "likely" conditions) has fixed true or false value
206+
// find an interval in which each condition has fixed true or false value
176207
class PartitionFinder : public StmtExprVisitor {
177208
public:
178209
explicit PartitionFinder(Var current_var,
179210
const std::unordered_map<const VarNode*, IntSet>& hint_map,
180-
const std::unordered_map<const VarNode*, IntSet>& relax_map)
181-
: current_var_(current_var), hint_map_(hint_map), relax_map_(relax_map) {
211+
const std::unordered_map<const VarNode*, IntSet>& relax_map,
212+
bool has_partition_hint)
213+
: current_var_(current_var),
214+
has_partition_hint_(has_partition_hint),
215+
hint_map_(hint_map),
216+
relax_map_(relax_map) {
182217
for (const auto& kv : hint_map) {
183218
out_vars_.insert(kv.first);
184219
}
@@ -218,33 +253,43 @@ class PartitionFinder : public StmtExprVisitor {
218253

219254
void VisitExpr_(const CallNode* op) final {
220255
if (op->op.same_as(builtin::likely())) {
221-
PrimExpr cond = op->args[0];
222-
if (UsesVar(cond, [this](const VarNode* var) { return var == current_var_.get(); })) {
223-
// For cond, find out the interval, if exists, in which we can prove that cond is
224-
// true. Also find the interval, if exists, in which we can prove that cond is
225-
// false.
226-
IntSet interval = DeduceBound(current_var_, cond, hint_map_, relax_map_);
227-
if (!interval.IsNothing()) {
228-
// cond is true within interval
229-
partitions[{cond, true}] = interval;
230-
}
231-
PrimExpr inverse_cond = InverseCond(cond);
232-
if (inverse_cond.defined()) {
233-
IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_);
234-
if (!interval.IsNothing()) {
235-
// cond is false within interval
236-
partitions[{cond, false}] = interval;
237-
}
238-
}
239-
}
256+
DeduceCondition(op->args[0]);
240257
} else {
241258
StmtExprVisitor::VisitExpr_(op);
242259
}
243260
}
244261

262+
DEFINE_PARTITION_FINDER_VISIT_CMP_OP(GENode);
263+
DEFINE_PARTITION_FINDER_VISIT_CMP_OP(GTNode);
264+
DEFINE_PARTITION_FINDER_VISIT_CMP_OP(LENode);
265+
DEFINE_PARTITION_FINDER_VISIT_CMP_OP(LTNode);
266+
DEFINE_PARTITION_FINDER_VISIT_CMP_OP(EQNode);
267+
DEFINE_PARTITION_FINDER_VISIT_CMP_OP(NENode);
268+
245269
Partition partitions;
246270

247271
private:
272+
void DeduceCondition(const PrimExpr& cond) {
273+
// For cond, find out the interval, if exists, in which we can prove that cond is
274+
// true. Also find the interval, if exists, in which we can prove that cond is
275+
// false.
276+
if (UsesVar(cond, [this](const VarNode* var) { return var == current_var_.get(); })) {
277+
IntSet interval = DeduceBound(current_var_, cond, hint_map_, relax_map_);
278+
if (!interval.IsNothing()) {
279+
// cond is true within interval
280+
partitions[{cond, true}] = interval;
281+
}
282+
PrimExpr inverse_cond = InverseCond(cond);
283+
if (inverse_cond.defined()) {
284+
IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_);
285+
if (!interval.IsNothing()) {
286+
// cond is false within interval
287+
partitions[{cond, false}] = interval;
288+
}
289+
}
290+
}
291+
}
292+
248293
PrimExpr InverseCond(const PrimExpr& cond) {
249294
PrimExpr inverse_cond;
250295
if (const LTNode* op = cond.as<LTNode>()) {
@@ -270,6 +315,7 @@ class PartitionFinder : public StmtExprVisitor {
270315
}
271316

272317
Var current_var_;
318+
bool has_partition_hint_;
273319
std::unordered_set<const VarNode*> out_vars_;
274320
std::unordered_map<const VarNode*, IntSet> hint_map_;
275321
std::unordered_map<const VarNode*, IntSet> relax_map_;
@@ -472,7 +518,8 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim
472518
// include hint of var.
473519
hint_map_.insert({var.get(), IntSet::Interval(min, max)});
474520

475-
PartitionFinder finder(var, hint_map_, relax_map_);
521+
bool has_partition_hint_ = selector.partition_hint_vars.count(var.get());
522+
PartitionFinder finder(var, hint_map_, relax_map_, has_partition_hint_);
476523
finder(body);
477524

478525
hint_map_.erase(var.get());
@@ -601,7 +648,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt b
601648
}
602649
}
603650

604-
class RemoveLikelyTags : public StmtExprMutator {
651+
class RemoveLikelyTagsAndHints : public StmtExprMutator {
605652
public:
606653
PrimExpr VisitExpr_(const CallNode* op) final {
607654
if (op->op.same_as(builtin::likely())) {
@@ -611,12 +658,19 @@ class RemoveLikelyTags : public StmtExprMutator {
611658
return StmtExprMutator::VisitExpr_(op);
612659
}
613660
}
661+
662+
Stmt VisitStmt_(const AttrStmtNode* op) final {
663+
if (op->attr_key == attr::pragma_loop_partition_hint) {
664+
return VisitStmt(op->body);
665+
}
666+
return StmtExprMutator::VisitStmt_(op);
667+
}
614668
};
615669

616670
Stmt LoopPartition(Stmt stmt, bool partition_const_loop, bool no_unroll_loop_with_extent_one) {
617671
stmt = LoopPartitioner(partition_const_loop, no_unroll_loop_with_extent_one)
618672
.VisitAndMutate(std::move(stmt));
619-
stmt = RemoveLikelyTags()(std::move(stmt));
673+
stmt = RemoveLikelyTagsAndHints()(std::move(stmt));
620674
return stmt;
621675
}
622676

tests/python/unittest/test_tir_transform_loop_partition.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import tvm
1818
import tvm.testing
1919
from tvm import te
20+
from tvm import tir
21+
from tvm.script import ty
2022
import numpy
2123

2224

@@ -434,7 +436,6 @@ def test_conv_tiling():
434436
oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16)
435437
bounds = tvm.te.schedule.InferBound(s)
436438
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
437-
438439
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
439440
with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
440441
mod = tvm.tir.transform.LoopPartition()(mod)
@@ -538,6 +539,33 @@ def test_simple_rfactor():
538539
assert not tvm.ir.structural_equal(stmt1.body, stmt2.body)
539540

540541

542+
@tvm.script.tir
543+
def partitioned_concat(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
544+
tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
545+
A = tir.match_buffer(a, [16], dtype="float32")
546+
B = tir.match_buffer(b, [16], dtype="float32")
547+
C = tir.match_buffer(c, [32], dtype="float32")
548+
for i in tir.serial(0, 16):
549+
tir.store(C.data, i, tir.load("float32", A.data, i), True)
550+
for i in tir.serial(0, 16):
551+
tir.store(C.data, i + 16, tir.load("float32", B.data, i + 16), True)
552+
553+
554+
def test_explicit_partition_hint():
555+
A = te.placeholder((16,), name="A")
556+
B = te.placeholder((16,), name="B")
557+
C = te.compute((32,), lambda i: te.if_then_else(i < 16, A[i], B[i]), name="C")
558+
s = te.create_schedule(C.op)
559+
s.normalize()
560+
s[C].pragma(s[C].op.axis[0], "loop_partition_hint")
561+
mod = tvm.driver.build_module.schedule_to_module(s, [A, B, C], "main", None)
562+
with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
563+
mod = tvm.tir.transform.StorageFlatten(64)(mod)
564+
mod = tvm.tir.transform.LoopPartition()(mod)
565+
mod = tvm.tir.transform.Simplify()(mod)
566+
assert tvm.ir.structural_equal(mod["main"], partitioned_concat)
567+
568+
541569
if __name__ == "__main__":
542570
test_basic()
543571
test_const_loop()
@@ -559,3 +587,4 @@ def test_simple_rfactor():
559587
test_double_splitting_with_indivisible_factors()
560588
test_multilevel_splitting_with_indivisble_factors()
561589
test_simple_rfactor()
590+
test_explicit_partition_hint()

0 commit comments

Comments
 (0)