Skip to content

Commit cbe9fae

Browse files
committed
[CodeExtractor] Fix multiple bugs under certain shape of extracted region
Summary: If the extracted region has multiple exported data flows toward the same BB which is not included in the region, correct resotre instructions and PHI nodes won't be generated inside the exitStub. The solution is simply put the restore instructions right after the definition of output values instead of putting in exitStub. Unittest for this bug is included. Author: myhsu Reviewers: chandlerc, davide, lattner, silvas, davidxl, wmi, kuhar Subscribers: dberlin, kuhar, mgorny, llvm-commits Differential Revision: https://reviews.llvm.org/D37902 llvm-svn: 315041
1 parent 08dd582 commit cbe9fae

File tree

3 files changed

+101
-77
lines changed

3 files changed

+101
-77
lines changed

llvm/lib/Transforms/Utils/CodeExtractor.cpp

Lines changed: 31 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -651,19 +651,6 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
651651
return newFunction;
652652
}
653653

654-
/// FindPhiPredForUseInBlock - Given a value and a basic block, find a PHI
655-
/// that uses the value within the basic block, and return the predecessor
656-
/// block associated with that use, or return 0 if none is found.
657-
static BasicBlock* FindPhiPredForUseInBlock(Value* Used, BasicBlock* BB) {
658-
for (Use &U : Used->uses()) {
659-
PHINode *P = dyn_cast<PHINode>(U.getUser());
660-
if (P && P->getParent() == BB)
661-
return P->getIncomingBlock(U);
662-
}
663-
664-
return nullptr;
665-
}
666-
667654
/// emitCallAndSwitchStatement - This method sets up the caller side by adding
668655
/// the call instruction, splitting any PHI nodes in the header block as
669656
/// necessary.
@@ -736,7 +723,8 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer,
736723
if (!AggregateArgs)
737724
std::advance(OutputArgBegin, inputs.size());
738725

739-
// Reload the outputs passed in by reference
726+
// Reload the outputs passed in by reference.
727+
Function::arg_iterator OAI = OutputArgBegin;
740728
for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
741729
Value *Output = nullptr;
742730
if (AggregateArgs) {
@@ -759,6 +747,34 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer,
759747
if (!Blocks.count(inst->getParent()))
760748
inst->replaceUsesOfWith(outputs[i], load);
761749
}
750+
751+
// Store to argument right after the definition of output value.
752+
auto *OutI = dyn_cast<Instruction>(outputs[i]);
753+
if (!OutI)
754+
continue;
755+
// Find proper insertion point.
756+
Instruction *InsertPt = OutI->getNextNode();
757+
// Let's assume that there is no other guy interleave non-PHI in PHIs.
758+
if (isa<PHINode>(InsertPt))
759+
InsertPt = InsertPt->getParent()->getFirstNonPHI();
760+
761+
assert(OAI != newFunction->arg_end() &&
762+
"Number of output arguments should match "
763+
"the amount of defined values");
764+
if (AggregateArgs) {
765+
Value *Idx[2];
766+
Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
767+
Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
768+
GetElementPtrInst *GEP = GetElementPtrInst::Create(
769+
StructArgTy, &*OAI, Idx, "gep_" + outputs[i]->getName(), InsertPt);
770+
new StoreInst(outputs[i], GEP, InsertPt);
771+
// Since there should be only one struct argument aggregating
772+
// all the output values, we shouldn't increment OAI, which always
773+
// points to the struct argument, in this case.
774+
} else {
775+
new StoreInst(outputs[i], &*OAI, InsertPt);
776+
++OAI;
777+
}
762778
}
763779

764780
// Now we can emit a switch statement using the call as a value.
@@ -801,75 +817,13 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer,
801817
break;
802818
}
803819

804-
ReturnInst *NTRet = ReturnInst::Create(Context, brVal, NewTarget);
820+
ReturnInst::Create(Context, brVal, NewTarget);
805821

806822
// Update the switch instruction.
807823
TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context),
808824
SuccNum),
809825
OldTarget);
810826

811-
// Restore values just before we exit
812-
Function::arg_iterator OAI = OutputArgBegin;
813-
for (unsigned out = 0, e = outputs.size(); out != e; ++out) {
814-
// For an invoke, the normal destination is the only one that is
815-
// dominated by the result of the invocation
816-
BasicBlock *DefBlock = cast<Instruction>(outputs[out])->getParent();
817-
818-
bool DominatesDef = true;
819-
820-
BasicBlock *NormalDest = nullptr;
821-
if (auto *Invoke = dyn_cast<InvokeInst>(outputs[out]))
822-
NormalDest = Invoke->getNormalDest();
823-
824-
if (NormalDest) {
825-
DefBlock = NormalDest;
826-
827-
// Make sure we are looking at the original successor block, not
828-
// at a newly inserted exit block, which won't be in the dominator
829-
// info.
830-
for (const auto &I : ExitBlockMap)
831-
if (DefBlock == I.second) {
832-
DefBlock = I.first;
833-
break;
834-
}
835-
836-
// In the extract block case, if the block we are extracting ends
837-
// with an invoke instruction, make sure that we don't emit a
838-
// store of the invoke value for the unwind block.
839-
if (!DT && DefBlock != OldTarget)
840-
DominatesDef = false;
841-
}
842-
843-
if (DT) {
844-
DominatesDef = DT->dominates(DefBlock, OldTarget);
845-
846-
// If the output value is used by a phi in the target block,
847-
// then we need to test for dominance of the phi's predecessor
848-
// instead. Unfortunately, this a little complicated since we
849-
// have already rewritten uses of the value to uses of the reload.
850-
BasicBlock* pred = FindPhiPredForUseInBlock(Reloads[out],
851-
OldTarget);
852-
if (pred && DT && DT->dominates(DefBlock, pred))
853-
DominatesDef = true;
854-
}
855-
856-
if (DominatesDef) {
857-
if (AggregateArgs) {
858-
Value *Idx[2];
859-
Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
860-
Idx[1] = ConstantInt::get(Type::getInt32Ty(Context),
861-
FirstOut+out);
862-
GetElementPtrInst *GEP = GetElementPtrInst::Create(
863-
StructArgTy, &*OAI, Idx, "gep_" + outputs[out]->getName(),
864-
NTRet);
865-
new StoreInst(outputs[out], GEP, NTRet);
866-
} else {
867-
new StoreInst(outputs[out], &*OAI, NTRet);
868-
}
869-
}
870-
// Advance output iterator even if we don't emit a store
871-
if (!AggregateArgs) ++OAI;
872-
}
873827
}
874828

875829
// rewrite the original branch instruction with this new target

llvm/unittests/Transforms/Utils/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ set(LLVM_LINK_COMPONENTS
99
add_llvm_unittest(UtilsTests
1010
ASanStackFrameLayoutTest.cpp
1111
Cloning.cpp
12+
CodeExtractor.cpp
1213
FunctionComparator.cpp
1314
IntegerDivision.cpp
1415
Local.cpp
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
//===- CodeExtractor.cpp - Unit tests for CodeExtractor -------------------===//
2+
//
3+
// The LLVM Compiler Infrastructure
4+
//
5+
// This file is distributed under the University of Illinois Open Source
6+
// License. See LICENSE.TXT for details.
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#include "llvm/Transforms/Utils/CodeExtractor.h"
11+
#include "llvm/AsmParser/Parser.h"
12+
#include "llvm/IR/BasicBlock.h"
13+
#include "llvm/IR/Dominators.h"
14+
#include "llvm/IR/LLVMContext.h"
15+
#include "llvm/IR/Module.h"
16+
#include "llvm/IR/Verifier.h"
17+
#include "llvm/IRReader/IRReader.h"
18+
#include "llvm/Support/SourceMgr.h"
19+
#include "gtest/gtest.h"
20+
21+
using namespace llvm;
22+
23+
namespace {
24+
TEST(CodeExtractor, ExitStub) {
25+
LLVMContext Ctx;
26+
SMDiagnostic Err;
27+
std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
28+
define i32 @foo(i32 %x, i32 %y, i32 %z) {
29+
header:
30+
%0 = icmp ugt i32 %x, %y
31+
br i1 %0, label %body1, label %body2
32+
33+
body1:
34+
%1 = add i32 %z, 2
35+
br label %notExtracted
36+
37+
body2:
38+
%2 = mul i32 %z, 7
39+
br label %notExtracted
40+
41+
notExtracted:
42+
%3 = phi i32 [ %1, %body1 ], [ %2, %body2 ]
43+
%4 = add i32 %3, %x
44+
ret i32 %4
45+
}
46+
)invalid",
47+
Err, Ctx));
48+
49+
Function *Func = M->getFunction("foo");
50+
SmallVector<BasicBlock *, 3> Candidates;
51+
for (auto &BB : *Func) {
52+
if (BB.getName() == "body1")
53+
Candidates.push_back(&BB);
54+
if (BB.getName() == "body2")
55+
Candidates.push_back(&BB);
56+
}
57+
// CodeExtractor requires the first basic block
58+
// to dominate all the other ones.
59+
Candidates.insert(Candidates.begin(), &Func->getEntryBlock());
60+
61+
DominatorTree DT(*Func);
62+
CodeExtractor CE(Candidates, &DT);
63+
EXPECT_TRUE(CE.isEligible());
64+
65+
Function *Outlined = CE.extractCodeRegion();
66+
EXPECT_TRUE(Outlined);
67+
EXPECT_FALSE(verifyFunction(*Outlined));
68+
}
69+
} // end anonymous namespace

0 commit comments

Comments
 (0)