From 1db51d8eb2d220a4f0000555ada310990098cf5b Mon Sep 17 00:00:00 2001 From: Peter Rong Date: Wed, 28 Dec 2022 16:51:07 -0800 Subject: [PATCH] [Transform] Rewrite LowerSwitch using APInt This rewrite fixes https://github.com/llvm/llvm-project/issues/59316. Previously LowerSwitch uses int64_t, which will crash on case branches using integers with more than 64 bits. Using APInt fixes this problem. This patch also includes a test Reviewed By: RKSimon Differential Revision: https://reviews.llvm.org/D140747 --- llvm/lib/Transforms/Utils/LowerSwitch.cpp | 95 +++++++++++---------- llvm/test/Transforms/LowerSwitch/pr59316.ll | 64 ++++++++++++++ 2 files changed, 115 insertions(+), 44 deletions(-) create mode 100644 llvm/test/Transforms/LowerSwitch/pr59316.ll diff --git a/llvm/lib/Transforms/Utils/LowerSwitch.cpp b/llvm/lib/Transforms/Utils/LowerSwitch.cpp index 9e3095fa291f81..26aebdfff6408d 100644 --- a/llvm/lib/Transforms/Utils/LowerSwitch.cpp +++ b/llvm/lib/Transforms/Utils/LowerSwitch.cpp @@ -52,7 +52,7 @@ using namespace llvm; namespace { struct IntRange { - int64_t Low, High; + APInt Low, High; }; } // end anonymous namespace @@ -66,8 +66,8 @@ bool IsInRanges(const IntRange &R, const std::vector &Ranges) { // then check if the Low field is <= R.Low. If so, we // have a Range that covers R. auto I = llvm::lower_bound( - Ranges, R, [](IntRange A, IntRange B) { return A.High < B.High; }); - return I != Ranges.end() && I->Low <= R.Low; + Ranges, R, [](IntRange A, IntRange B) { return A.High.slt(B.High); }); + return I != Ranges.end() && I->Low.sle(R.Low); } struct CaseRange { @@ -116,15 +116,14 @@ raw_ostream &operator<<(raw_ostream &O, const CaseVector &C) { /// 2) Removed if subsequent incoming values now share the same case, i.e., /// multiple outcome edges are condensed into one. This is necessary to keep the /// number of phi values equal to the number of branches to SuccBB. -void FixPhis( - BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB, - const unsigned NumMergedCases = std::numeric_limits::max()) { +void FixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB, + const APInt &NumMergedCases) { for (auto &I : SuccBB->phis()) { PHINode *PN = cast(&I); // Only update the first occurrence if NewBB exists. unsigned Idx = 0, E = PN->getNumIncomingValues(); - unsigned LocalNumMergedCases = NumMergedCases; + APInt LocalNumMergedCases = NumMergedCases; for (; Idx != E && NewBB; ++Idx) { if (PN->getIncomingBlock(Idx) == OrigBB) { PN->setIncomingBlock(Idx, NewBB); @@ -139,10 +138,10 @@ void FixPhis( // Remove additional occurrences coming from condensed cases and keep the // number of incoming values equal to the number of branches to SuccBB. SmallVector Indices; - for (; LocalNumMergedCases > 0 && Idx < E; ++Idx) + for (; LocalNumMergedCases.ugt(0) && Idx < E; ++Idx) if (PN->getIncomingBlock(Idx) == OrigBB) { Indices.push_back(Idx); - LocalNumMergedCases--; + LocalNumMergedCases -= 1; } // Remove incoming values in the reverse order to prevent invalidating // *successive* index. @@ -209,8 +208,8 @@ BasicBlock *NewLeafBlock(CaseRange &Leaf, Value *Val, ConstantInt *LowerBound, for (BasicBlock::iterator I = Succ->begin(); isa(I); ++I) { PHINode *PN = cast(I); // Remove all but one incoming entries from the cluster - uint64_t Range = Leaf.High->getSExtValue() - Leaf.Low->getSExtValue(); - for (uint64_t j = 0; j < Range; ++j) { + APInt Range = Leaf.High->getValue() - Leaf.Low->getValue(); + for (APInt j(Range.getBitWidth(), 0, true); j.slt(Range); ++j) { PN->removeIncomingValue(OrigBlock); } @@ -241,8 +240,7 @@ BasicBlock *SwitchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, // emitting the code that checks if the value actually falls in the range // because the bounds already tell us so. if (Begin->Low == LowerBound && Begin->High == UpperBound) { - unsigned NumMergedCases = 0; - NumMergedCases = UpperBound->getSExtValue() - LowerBound->getSExtValue(); + APInt NumMergedCases = UpperBound->getValue() - LowerBound->getValue(); FixPhis(Begin->BB, OrigBlock, Predecessor, NumMergedCases); return Begin->BB; } @@ -273,17 +271,17 @@ BasicBlock *SwitchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, if (!UnreachableRanges.empty()) { // Check if the gap between LHS's highest and NewLowerBound is unreachable. - int64_t GapLow = LHS.back().High->getSExtValue() + 1; - int64_t GapHigh = NewLowerBound->getSExtValue() - 1; + APInt GapLow = LHS.back().High->getValue() + 1; + APInt GapHigh = NewLowerBound->getValue() - 1; IntRange Gap = {GapLow, GapHigh}; - if (GapHigh >= GapLow && IsInRanges(Gap, UnreachableRanges)) + if (GapHigh.sge(GapLow) && IsInRanges(Gap, UnreachableRanges)) NewUpperBound = LHS.back().High; } - LLVM_DEBUG(dbgs() << "LHS Bounds ==> [" << LowerBound->getSExtValue() << ", " - << NewUpperBound->getSExtValue() << "]\n" - << "RHS Bounds ==> [" << NewLowerBound->getSExtValue() - << ", " << UpperBound->getSExtValue() << "]\n"); + LLVM_DEBUG(dbgs() << "LHS Bounds ==> [" << LowerBound->getValue() << ", " + << NewUpperBound->getValue() << "]\n" + << "RHS Bounds ==> [" << NewLowerBound->getValue() << ", " + << UpperBound->getValue() << "]\n"); // Create a new node that checks if the value is < pivot. Go to the // left branch if it is and right branch if not. @@ -327,14 +325,15 @@ unsigned Clusterify(CaseVector &Cases, SwitchInst *SI) { if (Cases.size() >= 2) { CaseItr I = Cases.begin(); for (CaseItr J = std::next(I), E = Cases.end(); J != E; ++J) { - int64_t nextValue = J->Low->getSExtValue(); - int64_t currentValue = I->High->getSExtValue(); + const APInt &nextValue = J->Low->getValue(); + const APInt ¤tValue = I->High->getValue(); BasicBlock *nextBB = J->BB; BasicBlock *currentBB = I->BB; // If the two neighboring cases go to the same destination, merge them // into a single case. - assert(nextValue > currentValue && "Cases should be strictly ascending"); + assert(nextValue.sgt(currentValue) && + "Cases should be strictly ascending"); if ((nextValue == currentValue + 1) && (currentBB == nextBB)) { I->High = J->High; // FIXME: Combine branch weights. @@ -369,6 +368,10 @@ void ProcessSwitchInst(SwitchInst *SI, // Prepare cases vector. CaseVector Cases; const unsigned NumSimpleCases = Clusterify(Cases, SI); + IntegerType *IT = cast(SI->getCondition()->getType()); + const unsigned BitWidth = IT->getBitWidth(); + APInt SignedZero(BitWidth, 0); + APInt UnsignedMax = APInt::getMaxValue(BitWidth); LLVM_DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size() << ". Total non-default cases: " << NumSimpleCases << "\nCase clusters: " << Cases << "\n"); @@ -377,7 +380,7 @@ void ProcessSwitchInst(SwitchInst *SI, if (Cases.empty()) { BranchInst::Create(Default, OrigBlock); // Remove all the references from Default's PHIs to OrigBlock, but one. - FixPhis(Default, OrigBlock, OrigBlock); + FixPhis(Default, OrigBlock, OrigBlock, UnsignedMax); SI->eraseFromParent(); return; } @@ -414,8 +417,8 @@ void ProcessSwitchInst(SwitchInst *SI, // the unlikely event that some of them survived, we just conservatively // maintain the invariant that all the cases lie between the bounds. This // may, however, still render the default case effectively unreachable. - APInt Low = Cases.front().Low->getValue(); - APInt High = Cases.back().High->getValue(); + const APInt &Low = Cases.front().Low->getValue(); + const APInt &High = Cases.back().High->getValue(); APInt Min = APIntOps::smin(ValRange.getSignedMin(), Low); APInt Max = APIntOps::smax(ValRange.getSignedMax(), High); @@ -427,35 +430,38 @@ void ProcessSwitchInst(SwitchInst *SI, std::vector UnreachableRanges; if (DefaultIsUnreachableFromSwitch) { - DenseMap Popularity; - unsigned MaxPop = 0; + DenseMap Popularity; + APInt MaxPop(SignedZero); BasicBlock *PopSucc = nullptr; - IntRange R = {std::numeric_limits::min(), - std::numeric_limits::max()}; + APInt SignedMax = APInt::getSignedMaxValue(BitWidth); + APInt SignedMin = APInt::getSignedMinValue(BitWidth); + IntRange R = {SignedMin, SignedMax}; UnreachableRanges.push_back(R); for (const auto &I : Cases) { - int64_t Low = I.Low->getSExtValue(); - int64_t High = I.High->getSExtValue(); + const APInt &Low = I.Low->getValue(); + const APInt &High = I.High->getValue(); IntRange &LastRange = UnreachableRanges.back(); - if (LastRange.Low == Low) { + if (LastRange.Low.eq(Low)) { // There is nothing left of the previous range. UnreachableRanges.pop_back(); } else { // Terminate the previous range. - assert(Low > LastRange.Low); + assert(Low.sgt(LastRange.Low)); LastRange.High = Low - 1; } - if (High != std::numeric_limits::max()) { - IntRange R = {High + 1, std::numeric_limits::max()}; + if (High.ne(SignedMax)) { + IntRange R = {High + 1, SignedMax}; UnreachableRanges.push_back(R); } // Count popularity. - int64_t N = High - Low + 1; - unsigned &Pop = Popularity[I.BB]; - if ((Pop += N) > MaxPop) { + APInt N = High - Low + 1; + assert(N.sge(SignedZero) && "Popularity shouldn't be negative."); + // Explict insert to make sure the bitwidth of APInts match + APInt &Pop = Popularity.insert({I.BB, APInt(SignedZero)}).first->second; + if ((Pop += N).sgt(MaxPop)) { MaxPop = Pop; PopSucc = I.BB; } @@ -464,10 +470,10 @@ void ProcessSwitchInst(SwitchInst *SI, /* UnreachableRanges should be sorted and the ranges non-adjacent. */ for (auto I = UnreachableRanges.begin(), E = UnreachableRanges.end(); I != E; ++I) { - assert(I->Low <= I->High); + assert(I->Low.sle(I->High)); auto Next = I + 1; if (Next != E) { - assert(Next->Low > I->High); + assert(Next->Low.sgt(I->High)); } } #endif @@ -480,7 +486,8 @@ void ProcessSwitchInst(SwitchInst *SI, // Use the most popular block as the new default, reducing the number of // cases. - assert(MaxPop > 0 && PopSucc); + assert(MaxPop.sgt(SignedZero) && PopSucc && + "Max populartion shouldn't be negative."); Default = PopSucc; llvm::erase_if(Cases, [PopSucc](const CaseRange &R) { return R.BB == PopSucc; }); @@ -491,7 +498,7 @@ void ProcessSwitchInst(SwitchInst *SI, SI->eraseFromParent(); // As all the cases have been replaced with a single branch, only keep // one entry in the PHI nodes. - for (unsigned I = 0; I < (MaxPop - 1); ++I) + for (APInt I(SignedZero); I.slt(MaxPop - 1); ++I) PopSucc->removePredecessor(OrigBlock); return; } @@ -512,7 +519,7 @@ void ProcessSwitchInst(SwitchInst *SI, // that SwitchBlock is the same as Default, under which the PHIs in Default // are fixed inside SwitchConvert(). if (SwitchBlock != Default) - FixPhis(Default, OrigBlock, nullptr); + FixPhis(Default, OrigBlock, nullptr, UnsignedMax); // Branch to our shiny new if-then stuff... BranchInst::Create(SwitchBlock, OrigBlock); diff --git a/llvm/test/Transforms/LowerSwitch/pr59316.ll b/llvm/test/Transforms/LowerSwitch/pr59316.ll new file mode 100644 index 00000000000000..2e4226c71ea7d3 --- /dev/null +++ b/llvm/test/Transforms/LowerSwitch/pr59316.ll @@ -0,0 +1,64 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -passes=lowerswitch -S | FileCheck %s + +define i64 @f(i1 %bool, i128 %i128) { +; CHECK-LABEL: @f( +; CHECK-NEXT: BB: +; CHECK-NEXT: br label [[NODEBLOCK1:%.*]] +; CHECK: NodeBlock1: +; CHECK-NEXT: [[PIVOT2:%.*]] = icmp slt i128 [[I128:%.*]], 16201310291018008446 +; CHECK-NEXT: br i1 [[PIVOT2]], label [[LEAFBLOCK:%.*]], label [[NODEBLOCK:%.*]] +; CHECK: NodeBlock: +; CHECK-NEXT: [[PIVOT:%.*]] = icmp slt i128 [[I128]], 16201310291018008447 +; CHECK-NEXT: br i1 [[PIVOT]], label [[SW_C3:%.*]], label [[SW_C2:%.*]] +; CHECK: LeafBlock: +; CHECK-NEXT: [[SWITCHLEAF:%.*]] = icmp eq i128 [[I128]], 16201310291018008445 +; CHECK-NEXT: br i1 [[SWITCHLEAF]], label [[SW_C4:%.*]], label [[SW_C1:%.*]] +; CHECK: BB1: +; CHECK-NEXT: unreachable +; CHECK: SW_C1: +; CHECK-NEXT: br i1 [[BOOL:%.*]], label [[BB1:%.*]], label [[SW_C1]] +; CHECK: SW_C2: +; CHECK-NEXT: ret i64 0 +; CHECK: SW_C3: +; CHECK-NEXT: ret i64 1 +; CHECK: SW_C4: +; CHECK-NEXT: ret i64 2 +; +BB: + switch i128 %i128, label %BB1 [ + i128 627, label %SW_C1 + i128 16201310291018008447, label %SW_C2 + i128 16201310291018008446, label %SW_C3 + i128 16201310291018008445, label %SW_C4 + ] + +BB1: ; preds = %SW_C1, %BB + unreachable + +SW_C1: ; preds = %SW_C1, %BB + br i1 %bool, label %BB1, label %SW_C1 + +SW_C2: ; preds = %BB + ret i64 0 + +SW_C3: ; preds = %BB + ret i64 1 + +SW_C4: ; preds = %BB + ret i64 2 +} + +define i64 @f_empty(i1 %bool, i128 %i128) { +; CHECK-LABEL: @f_empty( +; CHECK-NEXT: BB: +; CHECK-NEXT: br label [[BB1:%.*]] +; CHECK: BB1: +; CHECK-NEXT: unreachable +; +BB: + switch i128 %i128, label %BB1 [] + +BB1: ; preds = %BB + unreachable +}