Skip to content

Commit 1db51d8

Browse files
committed
[Transform] Rewrite LowerSwitch using APInt
This rewrite fixes #59316. Previously LowerSwitch uses int64_t, which will crash on case branches using integers with more than 64 bits. Using APInt fixes this problem. This patch also includes a test Reviewed By: RKSimon Differential Revision: https://reviews.llvm.org/D140747
1 parent c3ab645 commit 1db51d8

File tree

2 files changed

+115
-44
lines changed

2 files changed

+115
-44
lines changed

llvm/lib/Transforms/Utils/LowerSwitch.cpp

Lines changed: 51 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ using namespace llvm;
5252
namespace {
5353

5454
struct IntRange {
55-
int64_t Low, High;
55+
APInt Low, High;
5656
};
5757

5858
} // end anonymous namespace
@@ -66,8 +66,8 @@ bool IsInRanges(const IntRange &R, const std::vector<IntRange> &Ranges) {
6666
// then check if the Low field is <= R.Low. If so, we
6767
// have a Range that covers R.
6868
auto I = llvm::lower_bound(
69-
Ranges, R, [](IntRange A, IntRange B) { return A.High < B.High; });
70-
return I != Ranges.end() && I->Low <= R.Low;
69+
Ranges, R, [](IntRange A, IntRange B) { return A.High.slt(B.High); });
70+
return I != Ranges.end() && I->Low.sle(R.Low);
7171
}
7272

7373
struct CaseRange {
@@ -116,15 +116,14 @@ raw_ostream &operator<<(raw_ostream &O, const CaseVector &C) {
116116
/// 2) Removed if subsequent incoming values now share the same case, i.e.,
117117
/// multiple outcome edges are condensed into one. This is necessary to keep the
118118
/// number of phi values equal to the number of branches to SuccBB.
119-
void FixPhis(
120-
BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB,
121-
const unsigned NumMergedCases = std::numeric_limits<unsigned>::max()) {
119+
void FixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB,
120+
const APInt &NumMergedCases) {
122121
for (auto &I : SuccBB->phis()) {
123122
PHINode *PN = cast<PHINode>(&I);
124123

125124
// Only update the first occurrence if NewBB exists.
126125
unsigned Idx = 0, E = PN->getNumIncomingValues();
127-
unsigned LocalNumMergedCases = NumMergedCases;
126+
APInt LocalNumMergedCases = NumMergedCases;
128127
for (; Idx != E && NewBB; ++Idx) {
129128
if (PN->getIncomingBlock(Idx) == OrigBB) {
130129
PN->setIncomingBlock(Idx, NewBB);
@@ -139,10 +138,10 @@ void FixPhis(
139138
// Remove additional occurrences coming from condensed cases and keep the
140139
// number of incoming values equal to the number of branches to SuccBB.
141140
SmallVector<unsigned, 8> Indices;
142-
for (; LocalNumMergedCases > 0 && Idx < E; ++Idx)
141+
for (; LocalNumMergedCases.ugt(0) && Idx < E; ++Idx)
143142
if (PN->getIncomingBlock(Idx) == OrigBB) {
144143
Indices.push_back(Idx);
145-
LocalNumMergedCases--;
144+
LocalNumMergedCases -= 1;
146145
}
147146
// Remove incoming values in the reverse order to prevent invalidating
148147
// *successive* index.
@@ -209,8 +208,8 @@ BasicBlock *NewLeafBlock(CaseRange &Leaf, Value *Val, ConstantInt *LowerBound,
209208
for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) {
210209
PHINode *PN = cast<PHINode>(I);
211210
// Remove all but one incoming entries from the cluster
212-
uint64_t Range = Leaf.High->getSExtValue() - Leaf.Low->getSExtValue();
213-
for (uint64_t j = 0; j < Range; ++j) {
211+
APInt Range = Leaf.High->getValue() - Leaf.Low->getValue();
212+
for (APInt j(Range.getBitWidth(), 0, true); j.slt(Range); ++j) {
214213
PN->removeIncomingValue(OrigBlock);
215214
}
216215

@@ -241,8 +240,7 @@ BasicBlock *SwitchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound,
241240
// emitting the code that checks if the value actually falls in the range
242241
// because the bounds already tell us so.
243242
if (Begin->Low == LowerBound && Begin->High == UpperBound) {
244-
unsigned NumMergedCases = 0;
245-
NumMergedCases = UpperBound->getSExtValue() - LowerBound->getSExtValue();
243+
APInt NumMergedCases = UpperBound->getValue() - LowerBound->getValue();
246244
FixPhis(Begin->BB, OrigBlock, Predecessor, NumMergedCases);
247245
return Begin->BB;
248246
}
@@ -273,17 +271,17 @@ BasicBlock *SwitchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound,
273271

274272
if (!UnreachableRanges.empty()) {
275273
// Check if the gap between LHS's highest and NewLowerBound is unreachable.
276-
int64_t GapLow = LHS.back().High->getSExtValue() + 1;
277-
int64_t GapHigh = NewLowerBound->getSExtValue() - 1;
274+
APInt GapLow = LHS.back().High->getValue() + 1;
275+
APInt GapHigh = NewLowerBound->getValue() - 1;
278276
IntRange Gap = {GapLow, GapHigh};
279-
if (GapHigh >= GapLow && IsInRanges(Gap, UnreachableRanges))
277+
if (GapHigh.sge(GapLow) && IsInRanges(Gap, UnreachableRanges))
280278
NewUpperBound = LHS.back().High;
281279
}
282280

283-
LLVM_DEBUG(dbgs() << "LHS Bounds ==> [" << LowerBound->getSExtValue() << ", "
284-
<< NewUpperBound->getSExtValue() << "]\n"
285-
<< "RHS Bounds ==> [" << NewLowerBound->getSExtValue()
286-
<< ", " << UpperBound->getSExtValue() << "]\n");
281+
LLVM_DEBUG(dbgs() << "LHS Bounds ==> [" << LowerBound->getValue() << ", "
282+
<< NewUpperBound->getValue() << "]\n"
283+
<< "RHS Bounds ==> [" << NewLowerBound->getValue() << ", "
284+
<< UpperBound->getValue() << "]\n");
287285

288286
// Create a new node that checks if the value is < pivot. Go to the
289287
// left branch if it is and right branch if not.
@@ -327,14 +325,15 @@ unsigned Clusterify(CaseVector &Cases, SwitchInst *SI) {
327325
if (Cases.size() >= 2) {
328326
CaseItr I = Cases.begin();
329327
for (CaseItr J = std::next(I), E = Cases.end(); J != E; ++J) {
330-
int64_t nextValue = J->Low->getSExtValue();
331-
int64_t currentValue = I->High->getSExtValue();
328+
const APInt &nextValue = J->Low->getValue();
329+
const APInt &currentValue = I->High->getValue();
332330
BasicBlock *nextBB = J->BB;
333331
BasicBlock *currentBB = I->BB;
334332

335333
// If the two neighboring cases go to the same destination, merge them
336334
// into a single case.
337-
assert(nextValue > currentValue && "Cases should be strictly ascending");
335+
assert(nextValue.sgt(currentValue) &&
336+
"Cases should be strictly ascending");
338337
if ((nextValue == currentValue + 1) && (currentBB == nextBB)) {
339338
I->High = J->High;
340339
// FIXME: Combine branch weights.
@@ -369,6 +368,10 @@ void ProcessSwitchInst(SwitchInst *SI,
369368
// Prepare cases vector.
370369
CaseVector Cases;
371370
const unsigned NumSimpleCases = Clusterify(Cases, SI);
371+
IntegerType *IT = cast<IntegerType>(SI->getCondition()->getType());
372+
const unsigned BitWidth = IT->getBitWidth();
373+
APInt SignedZero(BitWidth, 0);
374+
APInt UnsignedMax = APInt::getMaxValue(BitWidth);
372375
LLVM_DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size()
373376
<< ". Total non-default cases: " << NumSimpleCases
374377
<< "\nCase clusters: " << Cases << "\n");
@@ -377,7 +380,7 @@ void ProcessSwitchInst(SwitchInst *SI,
377380
if (Cases.empty()) {
378381
BranchInst::Create(Default, OrigBlock);
379382
// Remove all the references from Default's PHIs to OrigBlock, but one.
380-
FixPhis(Default, OrigBlock, OrigBlock);
383+
FixPhis(Default, OrigBlock, OrigBlock, UnsignedMax);
381384
SI->eraseFromParent();
382385
return;
383386
}
@@ -414,8 +417,8 @@ void ProcessSwitchInst(SwitchInst *SI,
414417
// the unlikely event that some of them survived, we just conservatively
415418
// maintain the invariant that all the cases lie between the bounds. This
416419
// may, however, still render the default case effectively unreachable.
417-
APInt Low = Cases.front().Low->getValue();
418-
APInt High = Cases.back().High->getValue();
420+
const APInt &Low = Cases.front().Low->getValue();
421+
const APInt &High = Cases.back().High->getValue();
419422
APInt Min = APIntOps::smin(ValRange.getSignedMin(), Low);
420423
APInt Max = APIntOps::smax(ValRange.getSignedMax(), High);
421424

@@ -427,35 +430,38 @@ void ProcessSwitchInst(SwitchInst *SI,
427430
std::vector<IntRange> UnreachableRanges;
428431

429432
if (DefaultIsUnreachableFromSwitch) {
430-
DenseMap<BasicBlock *, unsigned> Popularity;
431-
unsigned MaxPop = 0;
433+
DenseMap<BasicBlock *, APInt> Popularity;
434+
APInt MaxPop(SignedZero);
432435
BasicBlock *PopSucc = nullptr;
433436

434-
IntRange R = {std::numeric_limits<int64_t>::min(),
435-
std::numeric_limits<int64_t>::max()};
437+
APInt SignedMax = APInt::getSignedMaxValue(BitWidth);
438+
APInt SignedMin = APInt::getSignedMinValue(BitWidth);
439+
IntRange R = {SignedMin, SignedMax};
436440
UnreachableRanges.push_back(R);
437441
for (const auto &I : Cases) {
438-
int64_t Low = I.Low->getSExtValue();
439-
int64_t High = I.High->getSExtValue();
442+
const APInt &Low = I.Low->getValue();
443+
const APInt &High = I.High->getValue();
440444

441445
IntRange &LastRange = UnreachableRanges.back();
442-
if (LastRange.Low == Low) {
446+
if (LastRange.Low.eq(Low)) {
443447
// There is nothing left of the previous range.
444448
UnreachableRanges.pop_back();
445449
} else {
446450
// Terminate the previous range.
447-
assert(Low > LastRange.Low);
451+
assert(Low.sgt(LastRange.Low));
448452
LastRange.High = Low - 1;
449453
}
450-
if (High != std::numeric_limits<int64_t>::max()) {
451-
IntRange R = {High + 1, std::numeric_limits<int64_t>::max()};
454+
if (High.ne(SignedMax)) {
455+
IntRange R = {High + 1, SignedMax};
452456
UnreachableRanges.push_back(R);
453457
}
454458

455459
// Count popularity.
456-
int64_t N = High - Low + 1;
457-
unsigned &Pop = Popularity[I.BB];
458-
if ((Pop += N) > MaxPop) {
460+
APInt N = High - Low + 1;
461+
assert(N.sge(SignedZero) && "Popularity shouldn't be negative.");
462+
// Explict insert to make sure the bitwidth of APInts match
463+
APInt &Pop = Popularity.insert({I.BB, APInt(SignedZero)}).first->second;
464+
if ((Pop += N).sgt(MaxPop)) {
459465
MaxPop = Pop;
460466
PopSucc = I.BB;
461467
}
@@ -464,10 +470,10 @@ void ProcessSwitchInst(SwitchInst *SI,
464470
/* UnreachableRanges should be sorted and the ranges non-adjacent. */
465471
for (auto I = UnreachableRanges.begin(), E = UnreachableRanges.end();
466472
I != E; ++I) {
467-
assert(I->Low <= I->High);
473+
assert(I->Low.sle(I->High));
468474
auto Next = I + 1;
469475
if (Next != E) {
470-
assert(Next->Low > I->High);
476+
assert(Next->Low.sgt(I->High));
471477
}
472478
}
473479
#endif
@@ -480,7 +486,8 @@ void ProcessSwitchInst(SwitchInst *SI,
480486

481487
// Use the most popular block as the new default, reducing the number of
482488
// cases.
483-
assert(MaxPop > 0 && PopSucc);
489+
assert(MaxPop.sgt(SignedZero) && PopSucc &&
490+
"Max populartion shouldn't be negative.");
484491
Default = PopSucc;
485492
llvm::erase_if(Cases,
486493
[PopSucc](const CaseRange &R) { return R.BB == PopSucc; });
@@ -491,7 +498,7 @@ void ProcessSwitchInst(SwitchInst *SI,
491498
SI->eraseFromParent();
492499
// As all the cases have been replaced with a single branch, only keep
493500
// one entry in the PHI nodes.
494-
for (unsigned I = 0; I < (MaxPop - 1); ++I)
501+
for (APInt I(SignedZero); I.slt(MaxPop - 1); ++I)
495502
PopSucc->removePredecessor(OrigBlock);
496503
return;
497504
}
@@ -512,7 +519,7 @@ void ProcessSwitchInst(SwitchInst *SI,
512519
// that SwitchBlock is the same as Default, under which the PHIs in Default
513520
// are fixed inside SwitchConvert().
514521
if (SwitchBlock != Default)
515-
FixPhis(Default, OrigBlock, nullptr);
522+
FixPhis(Default, OrigBlock, nullptr, UnsignedMax);
516523

517524
// Branch to our shiny new if-then stuff...
518525
BranchInst::Create(SwitchBlock, OrigBlock);
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2+
; RUN: opt < %s -passes=lowerswitch -S | FileCheck %s
3+
4+
define i64 @f(i1 %bool, i128 %i128) {
5+
; CHECK-LABEL: @f(
6+
; CHECK-NEXT: BB:
7+
; CHECK-NEXT: br label [[NODEBLOCK1:%.*]]
8+
; CHECK: NodeBlock1:
9+
; CHECK-NEXT: [[PIVOT2:%.*]] = icmp slt i128 [[I128:%.*]], 16201310291018008446
10+
; CHECK-NEXT: br i1 [[PIVOT2]], label [[LEAFBLOCK:%.*]], label [[NODEBLOCK:%.*]]
11+
; CHECK: NodeBlock:
12+
; CHECK-NEXT: [[PIVOT:%.*]] = icmp slt i128 [[I128]], 16201310291018008447
13+
; CHECK-NEXT: br i1 [[PIVOT]], label [[SW_C3:%.*]], label [[SW_C2:%.*]]
14+
; CHECK: LeafBlock:
15+
; CHECK-NEXT: [[SWITCHLEAF:%.*]] = icmp eq i128 [[I128]], 16201310291018008445
16+
; CHECK-NEXT: br i1 [[SWITCHLEAF]], label [[SW_C4:%.*]], label [[SW_C1:%.*]]
17+
; CHECK: BB1:
18+
; CHECK-NEXT: unreachable
19+
; CHECK: SW_C1:
20+
; CHECK-NEXT: br i1 [[BOOL:%.*]], label [[BB1:%.*]], label [[SW_C1]]
21+
; CHECK: SW_C2:
22+
; CHECK-NEXT: ret i64 0
23+
; CHECK: SW_C3:
24+
; CHECK-NEXT: ret i64 1
25+
; CHECK: SW_C4:
26+
; CHECK-NEXT: ret i64 2
27+
;
28+
BB:
29+
switch i128 %i128, label %BB1 [
30+
i128 627, label %SW_C1
31+
i128 16201310291018008447, label %SW_C2
32+
i128 16201310291018008446, label %SW_C3
33+
i128 16201310291018008445, label %SW_C4
34+
]
35+
36+
BB1: ; preds = %SW_C1, %BB
37+
unreachable
38+
39+
SW_C1: ; preds = %SW_C1, %BB
40+
br i1 %bool, label %BB1, label %SW_C1
41+
42+
SW_C2: ; preds = %BB
43+
ret i64 0
44+
45+
SW_C3: ; preds = %BB
46+
ret i64 1
47+
48+
SW_C4: ; preds = %BB
49+
ret i64 2
50+
}
51+
52+
define i64 @f_empty(i1 %bool, i128 %i128) {
53+
; CHECK-LABEL: @f_empty(
54+
; CHECK-NEXT: BB:
55+
; CHECK-NEXT: br label [[BB1:%.*]]
56+
; CHECK: BB1:
57+
; CHECK-NEXT: unreachable
58+
;
59+
BB:
60+
switch i128 %i128, label %BB1 []
61+
62+
BB1: ; preds = %BB
63+
unreachable
64+
}

0 commit comments

Comments
 (0)