@@ -331,7 +331,7 @@ struct LateLowerGCFrame: public FunctionPass, private JuliaPassContext {
331
331
}
332
332
333
333
void LiftPhi (State &S, PHINode *Phi);
334
- bool LiftSelect (State &S, SelectInst *SI);
334
+ void LiftSelect (State &S, SelectInst *SI);
335
335
Value *MaybeExtractScalar (State &S, std::pair<Value*,int > ValExpr, Instruction *InsertBefore);
336
336
std::vector<Value*> MaybeExtractVector (State &S, Value *BaseVec, Instruction *InsertBefore);
337
337
Value *GetPtrForNumber (State &S, unsigned Num, Instruction *InsertBefore);
@@ -600,12 +600,12 @@ Value *LateLowerGCFrame::GetPtrForNumber(State &S, unsigned Num, Instruction *In
600
600
return MaybeExtractScalar (S, std::make_pair (Val, Idx), InsertBefore);
601
601
}
602
602
603
- bool LateLowerGCFrame::LiftSelect (State &S, SelectInst *SI) {
603
+ void LateLowerGCFrame::LiftSelect (State &S, SelectInst *SI) {
604
604
if (isa<PointerType>(SI->getType ()) ?
605
605
S.AllPtrNumbering .count (SI) :
606
606
S.AllCompositeNumbering .count (SI)) {
607
607
// already visited here--nothing to do
608
- return true ;
608
+ return ;
609
609
}
610
610
std::vector<int > Numbers;
611
611
unsigned NumRoots = 1 ;
@@ -617,68 +617,60 @@ bool LateLowerGCFrame::LiftSelect(State &S, SelectInst *SI) {
617
617
// find the base root for the arguments
618
618
Value *TrueBase = MaybeExtractScalar (S, FindBaseValue (S, SI->getTrueValue (), false ), SI);
619
619
Value *FalseBase = MaybeExtractScalar (S, FindBaseValue (S, SI->getFalseValue (), false ), SI);
620
- Value *V_null = ConstantPointerNull::get (cast<PointerType>(T_prjlvalue));
621
- bool didsplit = false ;
622
- if (TrueBase != V_null && FalseBase != V_null) {
623
- std::vector<Value*> TrueBases;
624
- std::vector<Value*> FalseBases;
625
- if (!isa<PointerType>(TrueBase->getType ())) {
626
- TrueBases = MaybeExtractVector (S, TrueBase, SI);
627
- assert (TrueBases.size () == Numbers.size ());
628
- NumRoots = TrueBases.size ();
629
- }
630
- if (!isa<PointerType>(FalseBase->getType ())) {
631
- FalseBases = MaybeExtractVector (S, FalseBase, SI);
632
- assert (FalseBases.size () == Numbers.size ());
633
- NumRoots = FalseBases.size ();
634
- }
635
- if (isa<PointerType>(SI->getType ()) ?
636
- S.AllPtrNumbering .count (SI) :
637
- S.AllCompositeNumbering .count (SI)) {
638
- // MaybeExtractScalar or MaybeExtractVector handled this for us (recursively, though a PHINode)
639
- return true ;
640
- }
641
- // need to handle each element (may just be one scalar)
642
- for (unsigned i = 0 ; i < NumRoots; ++i) {
643
- Value *TrueElem;
644
- if (isa<PointerType>(TrueBase->getType ()))
645
- TrueElem = TrueBase;
646
- else
647
- TrueElem = TrueBases[i];
648
- Value *FalseElem;
649
- if (isa<PointerType>(FalseBase->getType ()))
650
- FalseElem = FalseBase;
651
- else
652
- FalseElem = FalseBases[i];
653
- if (TrueElem != V_null || FalseElem != V_null) {
654
- Value *Cond = SI->getCondition ();
655
- if (isa<VectorType>(Cond->getType ())) {
656
- Cond = ExtractElementInst::Create (Cond,
657
- ConstantInt::get (Type::getInt32Ty (Cond->getContext ()), i),
658
- " " , SI);
659
- }
660
- SelectInst *SelectBase = SelectInst::Create (Cond, TrueElem, FalseElem, " gclift" , SI);
661
- int Number = ++S.MaxPtrNumber ;
662
- S.AllPtrNumbering [SelectBase] = Number;
663
- S.ReversePtrNumbering [Number] = SelectBase;
664
- if (isa<PointerType>(SI->getType ()))
665
- S.AllPtrNumbering [SI] = Number;
666
- else
667
- Numbers[i] = Number;
668
- didsplit = true ;
669
- }
670
- }
671
- if (isa<VectorType>(SI->getType ()) && NumRoots != Numbers.size ()) {
672
- // broadcast the scalar root number to fill the vector
673
- assert (NumRoots == 1 );
674
- int Number = Numbers[0 ];
675
- Numbers.resize (0 );
676
- Numbers.resize (SI->getType ()->getVectorNumElements (), Number);
677
- }
620
+ std::vector<Value*> TrueBases;
621
+ std::vector<Value*> FalseBases;
622
+ if (!isa<PointerType>(TrueBase->getType ())) {
623
+ TrueBases = MaybeExtractVector (S, TrueBase, SI);
624
+ assert (TrueBases.size () == Numbers.size ());
625
+ NumRoots = TrueBases.size ();
626
+ }
627
+ if (!isa<PointerType>(FalseBase->getType ())) {
628
+ FalseBases = MaybeExtractVector (S, FalseBase, SI);
629
+ assert (FalseBases.size () == Numbers.size ());
630
+ NumRoots = FalseBases.size ();
631
+ }
632
+ if (isa<PointerType>(SI->getType ()) ?
633
+ S.AllPtrNumbering .count (SI) :
634
+ S.AllCompositeNumbering .count (SI)) {
635
+ // MaybeExtractScalar or MaybeExtractVector handled this for us (recursively, though a PHINode)
636
+ return ;
637
+ }
638
+ // need to handle each element (may just be one scalar)
639
+ for (unsigned i = 0 ; i < NumRoots; ++i) {
640
+ Value *TrueElem;
641
+ if (isa<PointerType>(TrueBase->getType ()))
642
+ TrueElem = TrueBase;
643
+ else
644
+ TrueElem = TrueBases[i];
645
+ Value *FalseElem;
646
+ if (isa<PointerType>(FalseBase->getType ()))
647
+ FalseElem = FalseBase;
648
+ else
649
+ FalseElem = FalseBases[i];
650
+ Value *Cond = SI->getCondition ();
651
+ if (isa<VectorType>(Cond->getType ())) {
652
+ Cond = ExtractElementInst::Create (Cond,
653
+ ConstantInt::get (Type::getInt32Ty (Cond->getContext ()), i),
654
+ " " , SI);
655
+ }
656
+ SelectInst *SelectBase = SelectInst::Create (Cond, TrueElem, FalseElem, " gclift" , SI);
657
+ int Number = ++S.MaxPtrNumber ;
658
+ S.AllPtrNumbering [SelectBase] = Number;
659
+ S.ReversePtrNumbering [Number] = SelectBase;
660
+ if (isa<PointerType>(SI->getType ()))
661
+ S.AllPtrNumbering [SI] = Number;
662
+ else
663
+ Numbers[i] = Number;
664
+ }
665
+ if (isa<VectorType>(SI->getType ()) && NumRoots != Numbers.size ()) {
666
+ // broadcast the scalar root number to fill the vector
667
+ assert (NumRoots == 1 );
668
+ int Number = Numbers[0 ];
669
+ Numbers.resize (0 );
670
+ Numbers.resize (SI->getType ()->getVectorNumElements (), Number);
678
671
}
679
672
if (!isa<PointerType>(SI->getType ()))
680
673
S.AllCompositeNumbering [SI] = Numbers;
681
- return didsplit;
682
674
}
683
675
684
676
void LateLowerGCFrame::LiftPhi (State &S, PHINode *Phi) {
@@ -754,9 +746,8 @@ int LateLowerGCFrame::NumberBase(State &S, Value *CurrentV)
754
746
// input IR)
755
747
Number = -1 ;
756
748
} else if (isa<SelectInst>(CurrentV) && !isTrackedValue (CurrentV)) {
757
- Number = -1 ;
758
- if (LiftSelect (S, cast<SelectInst>(CurrentV))) // lifting a scalar pointer (if necessary)
759
- Number = S.AllPtrNumbering .at (CurrentV);
749
+ LiftSelect (S, cast<SelectInst>(CurrentV));
750
+ Number = S.AllPtrNumbering .at (CurrentV);
760
751
return Number;
761
752
} else if (isa<PHINode>(CurrentV) && !isTrackedValue (CurrentV)) {
762
753
LiftPhi (S, cast<PHINode>(CurrentV));
0 commit comments