Skip to content

Commit b9a243d

Browse files
ChuanqiXu9tstellar
authored andcommitted
[Coroutines] Enhance symmetric transfer for constant CmpInst
This fixes bug52896. Simply, some symmetric transfer optimization chances get invalided due to we delete some inlined optimization passes in 822b92a. This would cause stack-overflow in some situations which should be avoided by the design of coroutine. This patch tries to fix this by transforming the constant CmpInst instruction which was done in the deleted passes. Reviewed By: rjmccall, junparser Differential Revision: https://reviews.llvm.org/D116327 (cherry picked from commit 403772f)
1 parent 9d9efb1 commit b9a243d

File tree

2 files changed

+128
-35
lines changed

2 files changed

+128
-35
lines changed

llvm/lib/Transforms/Coroutines/CoroSplit.cpp

Lines changed: 63 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "llvm/Analysis/CFG.h"
3030
#include "llvm/Analysis/CallGraph.h"
3131
#include "llvm/Analysis/CallGraphSCCPass.h"
32+
#include "llvm/Analysis/ConstantFolding.h"
3233
#include "llvm/Analysis/LazyCallGraph.h"
3334
#include "llvm/IR/Argument.h"
3435
#include "llvm/IR/Attributes.h"
@@ -1174,6 +1175,15 @@ scanPHIsAndUpdateValueMap(Instruction *Prev, BasicBlock *NewBlock,
11741175
static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
11751176
DenseMap<Value *, Value *> ResolvedValues;
11761177
BasicBlock *UnconditionalSucc = nullptr;
1178+
assert(InitialInst->getModule());
1179+
const DataLayout &DL = InitialInst->getModule()->getDataLayout();
1180+
1181+
auto TryResolveConstant = [&ResolvedValues](Value *V) {
1182+
auto It = ResolvedValues.find(V);
1183+
if (It != ResolvedValues.end())
1184+
V = It->second;
1185+
return dyn_cast<ConstantInt>(V);
1186+
};
11771187

11781188
Instruction *I = InitialInst;
11791189
while (I->isTerminator() ||
@@ -1190,47 +1200,65 @@ static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
11901200
}
11911201
if (auto *BR = dyn_cast<BranchInst>(I)) {
11921202
if (BR->isUnconditional()) {
1193-
BasicBlock *BB = BR->getSuccessor(0);
1203+
BasicBlock *Succ = BR->getSuccessor(0);
11941204
if (I == InitialInst)
1195-
UnconditionalSucc = BB;
1196-
scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
1197-
I = BB->getFirstNonPHIOrDbgOrLifetime();
1205+
UnconditionalSucc = Succ;
1206+
scanPHIsAndUpdateValueMap(I, Succ, ResolvedValues);
1207+
I = Succ->getFirstNonPHIOrDbgOrLifetime();
1208+
continue;
1209+
}
1210+
1211+
BasicBlock *BB = BR->getParent();
1212+
// Handle the case the condition of the conditional branch is constant.
1213+
// e.g.,
1214+
//
1215+
// br i1 false, label %cleanup, label %CoroEnd
1216+
//
1217+
// It is possible during the transformation. We could continue the
1218+
// simplifying in this case.
1219+
if (ConstantFoldTerminator(BB, /*DeleteDeadConditions=*/true)) {
1220+
// Handle this branch in next iteration.
1221+
I = BB->getTerminator();
11981222
continue;
11991223
}
12001224
} else if (auto *CondCmp = dyn_cast<CmpInst>(I)) {
1225+
// If the case number of suspended switch instruction is reduced to
1226+
// 1, then it is simplified to CmpInst in llvm::ConstantFoldTerminator.
12011227
auto *BR = dyn_cast<BranchInst>(I->getNextNode());
1202-
if (BR && BR->isConditional() && CondCmp == BR->getCondition()) {
1203-
// If the case number of suspended switch instruction is reduced to
1204-
// 1, then it is simplified to CmpInst in llvm::ConstantFoldTerminator.
1205-
// And the comparsion looks like : %cond = icmp eq i8 %V, constant.
1206-
ConstantInt *CondConst = dyn_cast<ConstantInt>(CondCmp->getOperand(1));
1207-
if (CondConst && CondCmp->getPredicate() == CmpInst::ICMP_EQ) {
1208-
Value *V = CondCmp->getOperand(0);
1209-
auto it = ResolvedValues.find(V);
1210-
if (it != ResolvedValues.end())
1211-
V = it->second;
1212-
1213-
if (ConstantInt *Cond0 = dyn_cast<ConstantInt>(V)) {
1214-
BasicBlock *BB = Cond0->equalsInt(CondConst->getZExtValue())
1215-
? BR->getSuccessor(0)
1216-
: BR->getSuccessor(1);
1217-
scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
1218-
I = BB->getFirstNonPHIOrDbgOrLifetime();
1219-
continue;
1220-
}
1221-
}
1222-
}
1228+
if (!BR || !BR->isConditional() || CondCmp != BR->getCondition())
1229+
return false;
1230+
1231+
// And the comparsion looks like : %cond = icmp eq i8 %V, constant.
1232+
// So we try to resolve constant for the first operand only since the
1233+
// second operand should be literal constant by design.
1234+
ConstantInt *Cond0 = TryResolveConstant(CondCmp->getOperand(0));
1235+
auto *Cond1 = dyn_cast<ConstantInt>(CondCmp->getOperand(1));
1236+
if (!Cond0 || !Cond1)
1237+
return false;
1238+
1239+
// Both operands of the CmpInst are Constant. So that we could evaluate
1240+
// it immediately to get the destination.
1241+
auto *ConstResult =
1242+
dyn_cast_or_null<ConstantInt>(ConstantFoldCompareInstOperands(
1243+
CondCmp->getPredicate(), Cond0, Cond1, DL));
1244+
if (!ConstResult)
1245+
return false;
1246+
1247+
CondCmp->replaceAllUsesWith(ConstResult);
1248+
CondCmp->eraseFromParent();
1249+
1250+
// Handle this branch in next iteration.
1251+
I = BR;
1252+
continue;
12231253
} else if (auto *SI = dyn_cast<SwitchInst>(I)) {
1224-
Value *V = SI->getCondition();
1225-
auto it = ResolvedValues.find(V);
1226-
if (it != ResolvedValues.end())
1227-
V = it->second;
1228-
if (ConstantInt *Cond = dyn_cast<ConstantInt>(V)) {
1229-
BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor();
1230-
scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
1231-
I = BB->getFirstNonPHIOrDbgOrLifetime();
1232-
continue;
1233-
}
1254+
ConstantInt *Cond = TryResolveConstant(SI->getCondition());
1255+
if (!Cond)
1256+
return false;
1257+
1258+
BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor();
1259+
scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
1260+
I = BB->getFirstNonPHIOrDbgOrLifetime();
1261+
continue;
12341262
}
12351263
return false;
12361264
}
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
; Tests that coro-split will convert a call before coro.suspend to a musttail call
2+
; while the user of the coro.suspend is a icmpinst.
3+
; RUN: opt < %s -passes='cgscc(coro-split),simplifycfg,early-cse' -S | FileCheck %s
4+
5+
define void @fakeresume1(i8*) {
6+
entry:
7+
ret void;
8+
}
9+
10+
define void @f() #0 {
11+
entry:
12+
%id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
13+
%alloc = call i8* @malloc(i64 16) #3
14+
%vFrame = call noalias nonnull i8* @llvm.coro.begin(token %id, i8* %alloc)
15+
16+
%save = call token @llvm.coro.save(i8* null)
17+
18+
%init_suspend = call i8 @llvm.coro.suspend(token %save, i1 false)
19+
switch i8 %init_suspend, label %coro.end [
20+
i8 0, label %await.ready
21+
i8 1, label %coro.end
22+
]
23+
await.ready:
24+
%save2 = call token @llvm.coro.save(i8* null)
25+
26+
call fastcc void @fakeresume1(i8* align 8 null)
27+
%suspend = call i8 @llvm.coro.suspend(token %save2, i1 true)
28+
%switch = icmp ult i8 %suspend, 2
29+
br i1 %switch, label %cleanup, label %coro.end
30+
31+
cleanup:
32+
%free.handle = call i8* @llvm.coro.free(token %id, i8* %vFrame)
33+
%.not = icmp eq i8* %free.handle, null
34+
br i1 %.not, label %coro.end, label %coro.free
35+
36+
coro.free:
37+
call void @delete(i8* nonnull %free.handle) #2
38+
br label %coro.end
39+
40+
coro.end:
41+
call i1 @llvm.coro.end(i8* null, i1 false)
42+
ret void
43+
}
44+
45+
; CHECK-LABEL: @f.resume(
46+
; CHECK: musttail call fastcc void @fakeresume1(
47+
; CHECK-NEXT: ret void
48+
49+
declare token @llvm.coro.id(i32, i8* readnone, i8* nocapture readonly, i8*) #1
50+
declare i1 @llvm.coro.alloc(token) #2
51+
declare i64 @llvm.coro.size.i64() #3
52+
declare i8* @llvm.coro.begin(token, i8* writeonly) #2
53+
declare token @llvm.coro.save(i8*) #2
54+
declare i8* @llvm.coro.frame() #3
55+
declare i8 @llvm.coro.suspend(token, i1) #2
56+
declare i8* @llvm.coro.free(token, i8* nocapture readonly) #1
57+
declare i1 @llvm.coro.end(i8*, i1) #2
58+
declare i8* @llvm.coro.subfn.addr(i8* nocapture readonly, i8) #1
59+
declare i8* @malloc(i64)
60+
declare void @delete(i8* nonnull) #2
61+
62+
attributes #0 = { "coroutine.presplit"="1" }
63+
attributes #1 = { argmemonly nounwind readonly }
64+
attributes #2 = { nounwind }
65+
attributes #3 = { nounwind readnone }

0 commit comments

Comments
 (0)