Skip to content

Commit

Permalink
Vector lowering improvents in GC placement
Browse files Browse the repository at this point in the history
Support for vectors of tracked pointer was incomplete in the GC placement
pass. Try to fix as many cases as possible and add some tests. A refactor
to make all of this nicer (vectors weren't originally part of the implementation
might be good), but for now, let's get it correct first.

Fixes #28536
  • Loading branch information
Keno committed Sep 3, 2018
1 parent 2e603aa commit b1dac9f
Show file tree
Hide file tree
Showing 2 changed files with 266 additions and 75 deletions.
236 changes: 161 additions & 75 deletions src/llvm-late-gc-lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,8 @@ struct LateLowerGCFrame: public FunctionPass {
NoteUse(S, BBS, V, BBS.UpExposedUses);
}
Value *MaybeExtractUnion(std::pair<Value*,int> Val, Instruction *InsertBefore);
int LiftPhi(State &S, PHINode *Phi);
int LiftSelect(State &S, SelectInst *SI);
void LiftPhi(State &S, PHINode *Phi, SmallVector<int, 16> &PHINumbers);
bool LiftSelect(State &S, SelectInst *SI);
int Number(State &S, Value *V);
std::vector<int> NumberVector(State &S, Value *Vec);
int NumberBase(State &S, Value *V, Value *Base);
Expand Down Expand Up @@ -383,7 +383,10 @@ struct LateLowerGCFrame: public FunctionPass {
};

static unsigned getValueAddrSpace(Value *V) {
return cast<PointerType>(V->getType())->getAddressSpace();
Type *Ty = V->getType();
if (isa<VectorType>(Ty))
Ty = cast<VectorType>(V->getType())->getElementType();
return cast<PointerType>(Ty)->getAddressSpace();
}

static bool isSpecialPtr(Type *Ty) {
Expand Down Expand Up @@ -508,42 +511,108 @@ Value *LateLowerGCFrame::MaybeExtractUnion(std::pair<Value*,int> Val, Instructio
return Val.first;
}

int LateLowerGCFrame::LiftSelect(State &S, SelectInst *SI) {
Value *TrueBase = MaybeExtractUnion(FindBaseValue(S, SI->getTrueValue(), false), SI);
Value *FalseBase = MaybeExtractUnion(FindBaseValue(S, SI->getFalseValue(), false), SI);
if (getValueAddrSpace(TrueBase) != AddressSpace::Tracked)
TrueBase = ConstantPointerNull::get(cast<PointerType>(FalseBase->getType()));
if (getValueAddrSpace(FalseBase) != AddressSpace::Tracked)
FalseBase = ConstantPointerNull::get(cast<PointerType>(TrueBase->getType()));
if (getValueAddrSpace(TrueBase) != AddressSpace::Tracked)
return -1;
Value *SelectBase = SelectInst::Create(SI->getCondition(),
TrueBase, FalseBase, "gclift", SI);
int Number = ++S.MaxPtrNumber;
S.PtrNumbering[SelectBase] = S.AllPtrNumbering[SelectBase] =
S.AllPtrNumbering[SI] = Number;
S.ReversePtrNumbering[Number] = SelectBase;
return Number;
static Value *GetPtrForNumber(State &S, unsigned Num, Instruction *InsertionPoint)
{
Value *Val = S.ReversePtrNumbering[Num];
if (isSpecialPtrVec(Val->getType())) {
const std::vector<int> &AllNums = S.AllVectorNumbering[Val];
unsigned Idx = 0;
for (; Idx < AllNums.size(); ++Idx) {
if ((unsigned)AllNums[Idx] == Num)
break;
}
Val = ExtractElementInst::Create(Val, ConstantInt::get(
Type::getInt32Ty(Val->getContext()), Idx), "", InsertionPoint);
}
return Val;
}

bool LateLowerGCFrame::LiftSelect(State &S, SelectInst *SI) {
if (isSpecialPtrVec(SI->getType())) {
VectorType *VT = cast<VectorType>(SI->getType());
std::vector<int> TrueNumbers = NumberVector(S, SI->getTrueValue());
std::vector<int> FalseNumbers = NumberVector(S, SI->getFalseValue());
std::vector<int> Numbers;
for (unsigned i = 0; i < VT->getNumElements(); ++i) {
SelectInst *LSI = SelectInst::Create(SI->getCondition(),
TrueNumbers[i] < 0 ?
ConstantPointerNull::get(cast<PointerType>(T_prjlvalue)) :
GetPtrForNumber(S, TrueNumbers[i], SI),
FalseNumbers[i] < 0 ?
ConstantPointerNull::get(cast<PointerType>(T_prjlvalue)) :
GetPtrForNumber(S, FalseNumbers[i], SI),
"gclift", SI);
int Number = ++S.MaxPtrNumber;
Numbers.push_back(Number);
S.PtrNumbering[LSI] = S.AllPtrNumbering[LSI] = Number;
S.ReversePtrNumbering[Number] = LSI;
}
S.AllVectorNumbering[SI] = Numbers;
} else {
Value *TrueBase = MaybeExtractUnion(FindBaseValue(S, SI->getTrueValue(), false), SI);
Value *FalseBase = MaybeExtractUnion(FindBaseValue(S, SI->getFalseValue(), false), SI);
if (getValueAddrSpace(TrueBase) != AddressSpace::Tracked)
TrueBase = ConstantPointerNull::get(cast<PointerType>(FalseBase->getType()));
if (getValueAddrSpace(FalseBase) != AddressSpace::Tracked)
FalseBase = ConstantPointerNull::get(cast<PointerType>(TrueBase->getType()));
if (getValueAddrSpace(TrueBase) != AddressSpace::Tracked)
return false;
Value *SelectBase = SelectInst::Create(SI->getCondition(),
TrueBase, FalseBase, "gclift", SI);
int Number = ++S.MaxPtrNumber;
S.PtrNumbering[SelectBase] = S.AllPtrNumbering[SelectBase] =
S.AllPtrNumbering[SI] = Number;
S.ReversePtrNumbering[Number] = SelectBase;
}
return true;
}

int LateLowerGCFrame::LiftPhi(State &S, PHINode *Phi)
void LateLowerGCFrame::LiftPhi(State &S, PHINode *Phi, SmallVector<int, 16> &PHINumbers)
{
PHINode *lift = PHINode::Create(T_prjlvalue, Phi->getNumIncomingValues(), "gclift", Phi);
for (unsigned i = 0; i < Phi->getNumIncomingValues(); ++i) {
Value *Incoming = Phi->getIncomingValue(i);
Value *Base = MaybeExtractUnion(FindBaseValue(S, Incoming, false),
Phi->getIncomingBlock(i)->getTerminator());
if (getValueAddrSpace(Base) != AddressSpace::Tracked)
Base = ConstantPointerNull::get(cast<PointerType>(T_prjlvalue));
if (Base->getType() != T_prjlvalue)
Base = new BitCastInst(Base, T_prjlvalue, "", Phi->getIncomingBlock(i)->getTerminator());
lift->addIncoming(Base, Phi->getIncomingBlock(i));
if (isSpecialPtrVec(Phi->getType())) {
VectorType *VT = cast<VectorType>(Phi->getType());
std::vector<PHINode *> lifted;
for (unsigned i = 0; i < VT->getNumElements(); ++i) {
lifted.push_back(PHINode::Create(T_prjlvalue, Phi->getNumIncomingValues(), "gclift", Phi));
}
for (unsigned i = 0; i < Phi->getNumIncomingValues(); ++i) {
std::vector<int> Numbers = NumberVector(S, Phi->getIncomingValue(i));
BasicBlock *IncomingBB = Phi->getIncomingBlock(i);
Instruction *Terminator = IncomingBB->getTerminator();
for (unsigned i = 0; i < VT->getNumElements(); ++i) {
if (Numbers[i] < 0)
lifted[i]->addIncoming(ConstantPointerNull::get(cast<PointerType>(T_prjlvalue)), IncomingBB);
else
lifted[i]->addIncoming(GetPtrForNumber(S, Numbers[i], Terminator), IncomingBB);
}
}
std::vector<int> Numbers;
for (unsigned i = 0; i < VT->getNumElements(); ++i) {
int Number = ++S.MaxPtrNumber;
PHINumbers.push_back(Number);
Numbers.push_back(Number);
S.PtrNumbering[lifted[i]] = S.AllPtrNumbering[lifted[i]] = Number;
S.ReversePtrNumbering[Number] = lifted[i];
}
S.AllVectorNumbering[Phi] = Numbers;
} else {
PHINode *lift = PHINode::Create(T_prjlvalue, Phi->getNumIncomingValues(), "gclift", Phi);
for (unsigned i = 0; i < Phi->getNumIncomingValues(); ++i) {
Value *Incoming = Phi->getIncomingValue(i);
Value *Base = MaybeExtractUnion(FindBaseValue(S, Incoming, false),
Phi->getIncomingBlock(i)->getTerminator());
if (getValueAddrSpace(Base) != AddressSpace::Tracked)
Base = ConstantPointerNull::get(cast<PointerType>(T_prjlvalue));
if (Base->getType() != T_prjlvalue)
Base = new BitCastInst(Base, T_prjlvalue, "", Phi->getIncomingBlock(i)->getTerminator());
lift->addIncoming(Base, Phi->getIncomingBlock(i));
}
int Number = ++S.MaxPtrNumber;
PHINumbers.push_back(Number);
S.PtrNumbering[lift] = S.AllPtrNumbering[lift] =
S.AllPtrNumbering[Phi] = Number;
S.ReversePtrNumbering[Number] = lift;
}
int Number = ++S.MaxPtrNumber;
S.PtrNumbering[lift] = S.AllPtrNumbering[lift] =
S.AllPtrNumbering[Phi] = Number;
S.ReversePtrNumbering[Number] = lift;
return Number;
}

int LateLowerGCFrame::NumberBase(State &S, Value *V, Value *CurrentV)
Expand All @@ -566,12 +635,14 @@ int LateLowerGCFrame::NumberBase(State &S, Value *V, Value *CurrentV)
// input IR)
Number = -1;
} else if (isa<SelectInst>(CurrentV) && !isUnion && getValueAddrSpace(CurrentV) != AddressSpace::Tracked) {
int Number = LiftSelect(S, cast<SelectInst>(CurrentV));
S.AllPtrNumbering[V] = Number;
Number = -1;
if (LiftSelect(S, cast<SelectInst>(CurrentV)))
Number = S.AllPtrNumbering[V] = S.AllPtrNumbering.at(CurrentV);
return Number;
} else if (isa<PHINode>(CurrentV) && !isUnion && getValueAddrSpace(CurrentV) != AddressSpace::Tracked) {
int Number = LiftPhi(S, cast<PHINode>(CurrentV));
S.AllPtrNumbering[V] = Number;
SmallVector<int, 16> PHINumbers;
LiftPhi(S, cast<PHINode>(CurrentV), PHINumbers);
Number = S.AllPtrNumbering[V] = S.AllPtrNumbering.at(CurrentV);
return Number;
} else if (isa<ExtractValueInst>(CurrentV) && !isUnion) {
assert(false && "TODO: Extract");
Expand Down Expand Up @@ -630,15 +701,23 @@ std::vector<int> LateLowerGCFrame::NumberVectorBase(State &S, Value *CurrentV) {
Numbers = NumberVectorBase(S, IEI->getOperand(0));
int ElNumber = Number(S, IEI->getOperand(1));
Numbers[idx] = ElNumber;
} else if (isa<LoadInst>(CurrentV) || isa<CallInst>(CurrentV) || isa<PHINode>(CurrentV)) {
} else if (isa<SelectInst>(CurrentV) && getValueAddrSpace(CurrentV) != AddressSpace::Tracked) {
LiftSelect(S, cast<SelectInst>(CurrentV));
Numbers = S.AllVectorNumbering[CurrentV];
} else if (isa<PHINode>(CurrentV) && getValueAddrSpace(CurrentV) != AddressSpace::Tracked) {
SmallVector<int, 16> PHINumbers;
LiftPhi(S, cast<PHINode>(CurrentV), PHINumbers);
Numbers = S.AllVectorNumbering[CurrentV];
} else if (isa<LoadInst>(CurrentV) || isa<CallInst>(CurrentV) || isa<PHINode>(CurrentV) ||
isa<SelectInst>(CurrentV)) {
// This is simple, we can just number them sequentially
for (unsigned i = 0; i < cast<VectorType>(CurrentV->getType())->getNumElements(); ++i) {
int Num = ++S.MaxPtrNumber;
Numbers.push_back(Num);
S.ReversePtrNumbering[Num] = CurrentV;
}
} else {
assert(false && "Unexpected vector generating operating");
assert(false && "Unexpected vector generating operation");
}
S.AllVectorNumbering[CurrentV] = Numbers;
return Numbers;
Expand Down Expand Up @@ -1148,40 +1227,63 @@ State LateLowerGCFrame::LocalScan(Function &F) {
NoteOperandUses(S, BBS, I, BBS.UpExposedUsesUnrooted);
} else if (SelectInst *SI = dyn_cast<SelectInst>(&I)) {
// We need to insert an extra select for the GC root
if (!isSpecialPtr(SI->getType()) && !isUnionRep(SI->getType()))
if (!isSpecialPtr(SI->getType()) && !isSpecialPtrVec(SI->getType()) &&
!isUnionRep(SI->getType()))
continue;
if (!isUnionRep(SI->getType()) && getValueAddrSpace(SI) != AddressSpace::Tracked) {
if (S.AllPtrNumbering.find(SI) != S.AllPtrNumbering.end())
if (isSpecialPtrVec(SI->getType()) ?
S.AllVectorNumbering.find(SI) != S.AllVectorNumbering.end() :
S.AllPtrNumbering.find(SI) != S.AllPtrNumbering.end())
continue;
auto Num = LiftSelect(S, SI);
if (Num < 0)
if (!LiftSelect(S, SI))
continue;
auto SelectBase = cast<SelectInst>(S.ReversePtrNumbering[Num]);
SmallVector<int, 1> RefinedPtr{Number(S, SelectBase->getTrueValue()),
Number(S, SelectBase->getFalseValue())};
S.Refinements[Num] = std::move(RefinedPtr);
if (!isSpecialPtrVec(SI->getType())) {
// TODO: Refinements for vector select
int Num = S.AllPtrNumbering[SI];
if (Num < 0)
continue;
auto SelectBase = cast<SelectInst>(S.ReversePtrNumbering[Num]);
SmallVector<int, 2> RefinedPtr{Number(S, SelectBase->getTrueValue()),
Number(S, SelectBase->getFalseValue())};
S.Refinements[Num] = std::move(RefinedPtr);
}
} else {
SmallVector<int, 1> RefinedPtr{Number(S, SI->getTrueValue()),
Number(S, SI->getFalseValue())};
SmallVector<int, 2> RefinedPtr;
if (!isSpecialPtrVec(SI->getType())) {
RefinedPtr = {
Number(S, SI->getTrueValue()),
Number(S, SI->getFalseValue())
};
}
MaybeNoteDef(S, BBS, SI, BBS.Safepoints, std::move(RefinedPtr));
NoteOperandUses(S, BBS, I, BBS.UpExposedUsesUnrooted);
}
} else if (PHINode *Phi = dyn_cast<PHINode>(&I)) {
if (!isSpecialPtr(Phi->getType()) && !isUnionRep(Phi->getType())) {
if (!isSpecialPtr(Phi->getType()) && !isSpecialPtrVec(Phi->getType()) &&
!isUnionRep(Phi->getType())) {
continue;
}
auto nIncoming = Phi->getNumIncomingValues();
// We need to insert an extra phi for the GC root
if (!isUnionRep(Phi->getType()) && getValueAddrSpace(Phi) != AddressSpace::Tracked) {
if (S.AllPtrNumbering.find(Phi) != S.AllPtrNumbering.end())
if (isSpecialPtrVec(Phi->getType()) ?
S.AllVectorNumbering.find(Phi) != S.AllVectorNumbering.end() :
S.AllPtrNumbering.find(Phi) != S.AllPtrNumbering.end())
continue;
auto Num = LiftPhi(S, Phi);
auto lift = cast<PHINode>(S.ReversePtrNumbering[Num]);
S.Refinements[Num] = GetPHIRefinements(lift, S);
PHINumbers.push_back(Num);
LiftPhi(S, Phi, PHINumbers);
} else {
MaybeNoteDef(S, BBS, Phi, BBS.Safepoints, GetPHIRefinements(Phi, S));
PHINumbers.push_back(Number(S, Phi));
SmallVector<int, 1> PHIRefinements;
if (!isSpecialPtrVec(Phi->getType()))
PHIRefinements = GetPHIRefinements(Phi, S);
MaybeNoteDef(S, BBS, Phi, BBS.Safepoints, std::move(PHIRefinements));
if (isSpecialPtrVec(Phi->getType())) {
// TODO: Vector refinements
std::vector<int> Nums = NumberVector(S, Phi);
for (int Num : Nums)
PHINumbers.push_back(Num);
} else {
PHINumbers.push_back(Number(S, Phi));
}
for (unsigned i = 0; i < nIncoming; ++i) {
BBState &IncomingBBS = S.BBStates[Phi->getIncomingBlock(i)];
NoteUse(S, IncomingBBS, Phi->getIncomingValue(i), IncomingBBS.PhiOuts);
Expand Down Expand Up @@ -1776,22 +1878,6 @@ bool LateLowerGCFrame::CleanupIR(Function &F, State *S) {
return ChangesMade;
}

static Value *GetPtrForNumber(State &S, unsigned Num, Instruction *InsertionPoint)
{
Value *Val = S.ReversePtrNumbering[Num];
if (isSpecialPtrVec(Val->getType())) {
const std::vector<int> &AllNums = S.AllVectorNumbering[Val];
unsigned Idx = 0;
for (; Idx < AllNums.size(); ++Idx) {
if ((unsigned)AllNums[Idx] == Num)
break;
}
Val = ExtractElementInst::Create(Val, ConstantInt::get(
Type::getInt32Ty(Val->getContext()), Idx), "", InsertionPoint);
}
return Val;
}

static void AddInPredLiveOuts(BasicBlock *BB, BitVector &LiveIn, State &S)
{
bool First = true;
Expand Down
Loading

0 comments on commit b1dac9f

Please sign in to comment.