Skip to content

Commit 9aa2956

Browse files
huajsjylc
authored andcommitted
[Tir]Adding detail error messages when MatchCopyPattern function is failed. (apache#10244)
There is an error message to show the body when 'MatchCopyPattern' is failed, but the error message not give the information why this function get failed. Adding the detail error information to help trouble shooting.
1 parent 9cc73bc commit 9aa2956

File tree

1 file changed

+25
-7
lines changed

1 file changed

+25
-7
lines changed

src/tir/transforms/inject_copy_intrin.cc

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,26 +45,35 @@ class CopyIntrinInjector : public StmtMutator {
4545
Stmt VisitStmt_(const AttrStmtNode* op) final {
4646
if (op->attr_key == pragma_key_) {
4747
Stmt ret;
48-
ICHECK(MatchCopyPattern(op->body, &ret)) << "Cannot match copy pattern of " << op->body;
48+
std::string error_info;
49+
ICHECK(MatchCopyPattern(op->body, &ret, &error_info))
50+
<< "Cannot match copy pattern. The error is " << error_info << " The body is "
51+
<< op->body;
4952
return ret;
5053
}
5154
return StmtMutator::VisitStmt_(op);
5255
}
5356

5457
private:
55-
bool MatchCopyPattern(Stmt stmt, Stmt* out) {
58+
bool MatchCopyPattern(Stmt stmt, Stmt* out, std::string* error_info) {
5659
using namespace arith;
5760
Stmt body = stmt;
5861

5962
// strip the loops
6063
std::vector<const ForNode*> loops;
6164
while (const ForNode* op = body.as<ForNode>()) {
62-
if (!is_zero(op->min)) return false;
65+
if (!is_zero(op->min)) {
66+
*error_info = "the 'min' value of body 'Fonode' is 0.";
67+
return false;
68+
}
6369
loops.push_back(op);
6470
body = op->body;
6571
}
6672
const StoreNode* store = body.as<StoreNode>();
67-
if (store == nullptr) return false;
73+
if (store == nullptr) {
74+
*error_info = "the 'StoreNode' of body is a nullptr.";
75+
return false;
76+
}
6877
// Expr sel_cond, sel_true_value, sel_false_value;
6978
// match select or if
7079
PVar<PrimExpr> sel_cond, sel_true_value, sel_false_value;
@@ -84,7 +93,10 @@ class CopyIntrinInjector : public StmtMutator {
8493
if (cast != nullptr) {
8594
load = cast->value.as<LoadNode>();
8695
}
87-
if (load == nullptr) return false;
96+
if (load == nullptr) {
97+
*error_info = "the 'LoadNode' of body is a nullptr.";
98+
return false;
99+
}
88100
if (load->dtype.lanes() != 1) return false;
89101
Array<Var> loop_vars;
90102
for (const ForNode* op : loops) {
@@ -109,7 +121,10 @@ class CopyIntrinInjector : public StmtMutator {
109121
if (has_cond) {
110122
Array<PrimExpr> clip_bound = arith::DetectClipBound(sel_cond.Eval(), loop_vars);
111123
pad_value = sel_false_value.Eval();
112-
if (clip_bound.size() == 0) return false;
124+
if (clip_bound.size() == 0) {
125+
*error_info = "the size of clip bound is 0.";
126+
return false;
127+
}
113128
ICHECK_EQ(src_shape.size(), loop_vars.size());
114129
ICHECK_EQ(clip_bound.size(), loop_vars.size() * 2);
115130
for (size_t i = 0; i < src_shape.size(); ++i) {
@@ -150,7 +165,10 @@ class CopyIntrinInjector : public StmtMutator {
150165
Buffer src = Buffer(load->buffer_var, load->dtype, src_shape, src_strides, src_elem_offset,
151166
load->buffer_var->name_hint, 0, 0, kDefault);
152167
*out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value);
153-
ICHECK(out->defined()) << "flower function did not return correct stmt";
168+
if (!out->defined()) {
169+
*error_info = "flower function did not return correct stmt";
170+
return false;
171+
}
154172
return true;
155173
}
156174

0 commit comments

Comments
 (0)