Skip to content

Commit 3b193b7

Browse files
FrankYzyylc
authored andcommitted
[Relay] Add a non-recursive LetNode VisitExpr_ for LabelOps Pass to avoid stack overflow (apache#8917)
* Add a non-recursive Let VisitExpr_ for LabelOps * fake commit to retrigger CI * fake commit to retrigger the CI * fix CI issue * fix CI issue
1 parent 584d054 commit 3b193b7

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

src/relay/transforms/label_ops.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,25 @@ class LabelOpsMutator : public MixedModeMutator {
7777
}
7878
return std::move(f);
7979
}
80+
Expr VisitExpr_(const LetNode* op) final {
81+
auto pre_visit = [this](const LetNode* op) {
82+
this->Mutate(op->var);
83+
this->Mutate(op->value);
84+
};
85+
auto post_visit = [this](const LetNode* op) {
86+
Var var = Downcast<Var>(this->Mutate(op->var));
87+
auto value = this->Mutate(op->value);
88+
auto body = this->Mutate(op->body);
89+
auto expr = GetRef<Expr>(op);
90+
if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) {
91+
this->memo_[expr] = expr;
92+
} else {
93+
this->memo_[expr] = Let(var, value, body);
94+
}
95+
};
96+
ExpandANormalForm(op, pre_visit, post_visit);
97+
return memo_[GetRef<Expr>(op)];
98+
}
8099

81100
Expr Rewrite_(const CallNode* op, const Expr& post) final {
82101
auto updated = MixedModeMutator::Rewrite_(op, post);

0 commit comments

Comments
 (0)