Skip to content

Commit 247e385

Browse files
Kenotkf
authored andcommitted
[LateGCLowering] Fix skipped Select lifting
backport 7a4ea21
1 parent f5dbc47 commit 247e385

File tree

2 files changed

+72
-65
lines changed

2 files changed

+72
-65
lines changed

src/llvm-late-gc-lowering.cpp

+56-65
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ struct LateLowerGCFrame: public FunctionPass, private JuliaPassContext {
331331
}
332332

333333
void LiftPhi(State &S, PHINode *Phi);
334-
bool LiftSelect(State &S, SelectInst *SI);
334+
void LiftSelect(State &S, SelectInst *SI);
335335
Value *MaybeExtractScalar(State &S, std::pair<Value*,int> ValExpr, Instruction *InsertBefore);
336336
std::vector<Value*> MaybeExtractVector(State &S, Value *BaseVec, Instruction *InsertBefore);
337337
Value *GetPtrForNumber(State &S, unsigned Num, Instruction *InsertBefore);
@@ -600,12 +600,12 @@ Value *LateLowerGCFrame::GetPtrForNumber(State &S, unsigned Num, Instruction *In
600600
return MaybeExtractScalar(S, std::make_pair(Val, Idx), InsertBefore);
601601
}
602602

603-
bool LateLowerGCFrame::LiftSelect(State &S, SelectInst *SI) {
603+
void LateLowerGCFrame::LiftSelect(State &S, SelectInst *SI) {
604604
if (isa<PointerType>(SI->getType()) ?
605605
S.AllPtrNumbering.count(SI) :
606606
S.AllCompositeNumbering.count(SI)) {
607607
// already visited here--nothing to do
608-
return true;
608+
return;
609609
}
610610
std::vector<int> Numbers;
611611
unsigned NumRoots = 1;
@@ -617,68 +617,60 @@ bool LateLowerGCFrame::LiftSelect(State &S, SelectInst *SI) {
617617
// find the base root for the arguments
618618
Value *TrueBase = MaybeExtractScalar(S, FindBaseValue(S, SI->getTrueValue(), false), SI);
619619
Value *FalseBase = MaybeExtractScalar(S, FindBaseValue(S, SI->getFalseValue(), false), SI);
620-
Value *V_null = ConstantPointerNull::get(cast<PointerType>(T_prjlvalue));
621-
bool didsplit = false;
622-
if (TrueBase != V_null && FalseBase != V_null) {
623-
std::vector<Value*> TrueBases;
624-
std::vector<Value*> FalseBases;
625-
if (!isa<PointerType>(TrueBase->getType())) {
626-
TrueBases = MaybeExtractVector(S, TrueBase, SI);
627-
assert(TrueBases.size() == Numbers.size());
628-
NumRoots = TrueBases.size();
629-
}
630-
if (!isa<PointerType>(FalseBase->getType())) {
631-
FalseBases = MaybeExtractVector(S, FalseBase, SI);
632-
assert(FalseBases.size() == Numbers.size());
633-
NumRoots = FalseBases.size();
634-
}
635-
if (isa<PointerType>(SI->getType()) ?
636-
S.AllPtrNumbering.count(SI) :
637-
S.AllCompositeNumbering.count(SI)) {
638-
// MaybeExtractScalar or MaybeExtractVector handled this for us (recursively, though a PHINode)
639-
return true;
640-
}
641-
// need to handle each element (may just be one scalar)
642-
for (unsigned i = 0; i < NumRoots; ++i) {
643-
Value *TrueElem;
644-
if (isa<PointerType>(TrueBase->getType()))
645-
TrueElem = TrueBase;
646-
else
647-
TrueElem = TrueBases[i];
648-
Value *FalseElem;
649-
if (isa<PointerType>(FalseBase->getType()))
650-
FalseElem = FalseBase;
651-
else
652-
FalseElem = FalseBases[i];
653-
if (TrueElem != V_null || FalseElem != V_null) {
654-
Value *Cond = SI->getCondition();
655-
if (isa<VectorType>(Cond->getType())) {
656-
Cond = ExtractElementInst::Create(Cond,
657-
ConstantInt::get(Type::getInt32Ty(Cond->getContext()), i),
658-
"", SI);
659-
}
660-
SelectInst *SelectBase = SelectInst::Create(Cond, TrueElem, FalseElem, "gclift", SI);
661-
int Number = ++S.MaxPtrNumber;
662-
S.AllPtrNumbering[SelectBase] = Number;
663-
S.ReversePtrNumbering[Number] = SelectBase;
664-
if (isa<PointerType>(SI->getType()))
665-
S.AllPtrNumbering[SI] = Number;
666-
else
667-
Numbers[i] = Number;
668-
didsplit = true;
669-
}
670-
}
671-
if (isa<VectorType>(SI->getType()) && NumRoots != Numbers.size()) {
672-
// broadcast the scalar root number to fill the vector
673-
assert(NumRoots == 1);
674-
int Number = Numbers[0];
675-
Numbers.resize(0);
676-
Numbers.resize(SI->getType()->getVectorNumElements(), Number);
677-
}
620+
std::vector<Value*> TrueBases;
621+
std::vector<Value*> FalseBases;
622+
if (!isa<PointerType>(TrueBase->getType())) {
623+
TrueBases = MaybeExtractVector(S, TrueBase, SI);
624+
assert(TrueBases.size() == Numbers.size());
625+
NumRoots = TrueBases.size();
626+
}
627+
if (!isa<PointerType>(FalseBase->getType())) {
628+
FalseBases = MaybeExtractVector(S, FalseBase, SI);
629+
assert(FalseBases.size() == Numbers.size());
630+
NumRoots = FalseBases.size();
631+
}
632+
if (isa<PointerType>(SI->getType()) ?
633+
S.AllPtrNumbering.count(SI) :
634+
S.AllCompositeNumbering.count(SI)) {
635+
// MaybeExtractScalar or MaybeExtractVector handled this for us (recursively, though a PHINode)
636+
return;
637+
}
638+
// need to handle each element (may just be one scalar)
639+
for (unsigned i = 0; i < NumRoots; ++i) {
640+
Value *TrueElem;
641+
if (isa<PointerType>(TrueBase->getType()))
642+
TrueElem = TrueBase;
643+
else
644+
TrueElem = TrueBases[i];
645+
Value *FalseElem;
646+
if (isa<PointerType>(FalseBase->getType()))
647+
FalseElem = FalseBase;
648+
else
649+
FalseElem = FalseBases[i];
650+
Value *Cond = SI->getCondition();
651+
if (isa<VectorType>(Cond->getType())) {
652+
Cond = ExtractElementInst::Create(Cond,
653+
ConstantInt::get(Type::getInt32Ty(Cond->getContext()), i),
654+
"", SI);
655+
}
656+
SelectInst *SelectBase = SelectInst::Create(Cond, TrueElem, FalseElem, "gclift", SI);
657+
int Number = ++S.MaxPtrNumber;
658+
S.AllPtrNumbering[SelectBase] = Number;
659+
S.ReversePtrNumbering[Number] = SelectBase;
660+
if (isa<PointerType>(SI->getType()))
661+
S.AllPtrNumbering[SI] = Number;
662+
else
663+
Numbers[i] = Number;
664+
}
665+
if (isa<VectorType>(SI->getType()) && NumRoots != Numbers.size()) {
666+
// broadcast the scalar root number to fill the vector
667+
assert(NumRoots == 1);
668+
int Number = Numbers[0];
669+
Numbers.resize(0);
670+
Numbers.resize(SI->getType()->getVectorNumElements(), Number);
678671
}
679672
if (!isa<PointerType>(SI->getType()))
680673
S.AllCompositeNumbering[SI] = Numbers;
681-
return didsplit;
682674
}
683675

684676
void LateLowerGCFrame::LiftPhi(State &S, PHINode *Phi) {
@@ -754,9 +746,8 @@ int LateLowerGCFrame::NumberBase(State &S, Value *CurrentV)
754746
// input IR)
755747
Number = -1;
756748
} else if (isa<SelectInst>(CurrentV) && !isTrackedValue(CurrentV)) {
757-
Number = -1;
758-
if (LiftSelect(S, cast<SelectInst>(CurrentV))) // lifting a scalar pointer (if necessary)
759-
Number = S.AllPtrNumbering.at(CurrentV);
749+
LiftSelect(S, cast<SelectInst>(CurrentV));
750+
Number = S.AllPtrNumbering.at(CurrentV);
760751
return Number;
761752
} else if (isa<PHINode>(CurrentV) && !isTrackedValue(CurrentV)) {
762753
LiftPhi(S, cast<PHINode>(CurrentV));

test/llvmpasses/gcroots.ll

+16
Original file line numberDiff line numberDiff line change
@@ -703,6 +703,22 @@ top:
703703
ret i8 %val
704704
}
705705

706+
define i8 @lost_select_decayed(i1 %arg1) {
707+
; CHECK-LABEL: @lost_select_decayed
708+
; CHECK: %gcframe = alloca %jl_value_t addrspace(10)*, i32 3
709+
; CHECK: [[GEP0:%.*]] = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %gcframe, i32 2
710+
; CHECK: store %jl_value_t addrspace(10)* [[SOMETHING:%.*]], %jl_value_t addrspace(10)** [[GEP0]]
711+
top:
712+
%ptls = call %jl_value_t*** @julia.ptls_states()
713+
%obj1 = call %jl_value_t addrspace(10) *@alloc()
714+
%decayed = addrspacecast %jl_value_t addrspace(10) *%obj1 to %jl_value_t addrspace(11)*
715+
%selected = select i1 %arg1, %jl_value_t addrspace(11)* null, %jl_value_t addrspace(11)* %decayed
716+
%casted = bitcast %jl_value_t addrspace(11)* %selected to i8 addrspace(11)*
717+
call void @jl_safepoint()
718+
%val = load i8, i8 addrspace(11)* %casted
719+
ret i8 %val
720+
}
721+
706722
!0 = !{!"jtbaa"}
707723
!1 = !{!"jtbaa_const", !0, i64 0}
708724
!2 = !{!1, !1, i64 0, i64 1}

0 commit comments

Comments
 (0)