Skip to content

Commit 07b29fc

Browse files
authored
[ConstantRange] Improve shlWithNoWrap (#101800)
Closes dtcxzyw/llvm-tools#22.
1 parent 4dee641 commit 07b29fc

File tree

3 files changed

+147
-14
lines changed

3 files changed

+147
-14
lines changed

llvm/lib/IR/ConstantRange.cpp

Lines changed: 95 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1617,21 +1617,107 @@ ConstantRange::shl(const ConstantRange &Other) const {
16171617
return ConstantRange::getNonEmpty(std::move(Min), std::move(Max) + 1);
16181618
}
16191619

1620+
static ConstantRange computeShlNUW(const ConstantRange &LHS,
1621+
const ConstantRange &RHS) {
1622+
unsigned BitWidth = LHS.getBitWidth();
1623+
bool Overflow;
1624+
APInt LHSMin = LHS.getUnsignedMin();
1625+
unsigned RHSMin = RHS.getUnsignedMin().getLimitedValue(BitWidth);
1626+
APInt MinShl = LHSMin.ushl_ov(RHSMin, Overflow);
1627+
if (Overflow)
1628+
return ConstantRange::getEmpty(BitWidth);
1629+
APInt LHSMax = LHS.getUnsignedMax();
1630+
unsigned RHSMax = RHS.getUnsignedMax().getLimitedValue(BitWidth);
1631+
APInt MaxShl = MinShl;
1632+
unsigned MaxShAmt = LHSMax.countLeadingZeros();
1633+
if (RHSMin <= MaxShAmt)
1634+
MaxShl = LHSMax << std::min(RHSMax, MaxShAmt);
1635+
RHSMin = std::max(RHSMin, MaxShAmt + 1);
1636+
RHSMax = std::min(RHSMax, LHSMin.countLeadingZeros());
1637+
if (RHSMin <= RHSMax)
1638+
MaxShl = APIntOps::umax(MaxShl,
1639+
APInt::getHighBitsSet(BitWidth, BitWidth - RHSMin));
1640+
return ConstantRange::getNonEmpty(MinShl, MaxShl + 1);
1641+
}
1642+
1643+
static ConstantRange computeShlNSWWithNNegLHS(const APInt &LHSMin,
1644+
const APInt &LHSMax,
1645+
unsigned RHSMin,
1646+
unsigned RHSMax) {
1647+
unsigned BitWidth = LHSMin.getBitWidth();
1648+
bool Overflow;
1649+
APInt MinShl = LHSMin.sshl_ov(RHSMin, Overflow);
1650+
if (Overflow)
1651+
return ConstantRange::getEmpty(BitWidth);
1652+
APInt MaxShl = MinShl;
1653+
unsigned MaxShAmt = LHSMax.countLeadingZeros() - 1;
1654+
if (RHSMin <= MaxShAmt)
1655+
MaxShl = LHSMax << std::min(RHSMax, MaxShAmt);
1656+
RHSMin = std::max(RHSMin, MaxShAmt + 1);
1657+
RHSMax = std::min(RHSMax, LHSMin.countLeadingZeros() - 1);
1658+
if (RHSMin <= RHSMax)
1659+
MaxShl = APIntOps::umax(MaxShl,
1660+
APInt::getBitsSet(BitWidth, RHSMin, BitWidth - 1));
1661+
return ConstantRange::getNonEmpty(MinShl, MaxShl + 1);
1662+
}
1663+
1664+
static ConstantRange computeShlNSWWithNegLHS(const APInt &LHSMin,
1665+
const APInt &LHSMax,
1666+
unsigned RHSMin, unsigned RHSMax) {
1667+
unsigned BitWidth = LHSMin.getBitWidth();
1668+
bool Overflow;
1669+
APInt MaxShl = LHSMax.sshl_ov(RHSMin, Overflow);
1670+
if (Overflow)
1671+
return ConstantRange::getEmpty(BitWidth);
1672+
APInt MinShl = MaxShl;
1673+
unsigned MaxShAmt = LHSMin.countLeadingOnes() - 1;
1674+
if (RHSMin <= MaxShAmt)
1675+
MinShl = LHSMin.shl(std::min(RHSMax, MaxShAmt));
1676+
RHSMin = std::max(RHSMin, MaxShAmt + 1);
1677+
RHSMax = std::min(RHSMax, LHSMax.countLeadingOnes() - 1);
1678+
if (RHSMin <= RHSMax)
1679+
MinShl = APInt::getSignMask(BitWidth);
1680+
return ConstantRange::getNonEmpty(MinShl, MaxShl + 1);
1681+
}
1682+
1683+
static ConstantRange computeShlNSW(const ConstantRange &LHS,
1684+
const ConstantRange &RHS) {
1685+
unsigned BitWidth = LHS.getBitWidth();
1686+
unsigned RHSMin = RHS.getUnsignedMin().getLimitedValue(BitWidth);
1687+
unsigned RHSMax = RHS.getUnsignedMax().getLimitedValue(BitWidth);
1688+
APInt LHSMin = LHS.getSignedMin();
1689+
APInt LHSMax = LHS.getSignedMax();
1690+
if (LHSMin.isNonNegative())
1691+
return computeShlNSWWithNNegLHS(LHSMin, LHSMax, RHSMin, RHSMax);
1692+
else if (LHSMax.isNegative())
1693+
return computeShlNSWWithNegLHS(LHSMin, LHSMax, RHSMin, RHSMax);
1694+
return computeShlNSWWithNNegLHS(APInt::getZero(BitWidth), LHSMax, RHSMin,
1695+
RHSMax)
1696+
.unionWith(computeShlNSWWithNegLHS(LHSMin, APInt::getAllOnes(BitWidth),
1697+
RHSMin, RHSMax),
1698+
ConstantRange::Signed);
1699+
}
1700+
16201701
ConstantRange ConstantRange::shlWithNoWrap(const ConstantRange &Other,
16211702
unsigned NoWrapKind,
16221703
PreferredRangeType RangeType) const {
16231704
if (isEmptySet() || Other.isEmptySet())
16241705
return getEmpty();
16251706

1626-
ConstantRange Result = shl(Other);
1627-
1628-
if (NoWrapKind & OverflowingBinaryOperator::NoSignedWrap)
1629-
Result = Result.intersectWith(sshl_sat(Other), RangeType);
1630-
1631-
if (NoWrapKind & OverflowingBinaryOperator::NoUnsignedWrap)
1632-
Result = Result.intersectWith(ushl_sat(Other), RangeType);
1633-
1634-
return Result;
1707+
switch (NoWrapKind) {
1708+
case 0:
1709+
return shl(Other);
1710+
case OverflowingBinaryOperator::NoSignedWrap:
1711+
return computeShlNSW(*this, Other);
1712+
case OverflowingBinaryOperator::NoUnsignedWrap:
1713+
return computeShlNUW(*this, Other);
1714+
case OverflowingBinaryOperator::NoSignedWrap |
1715+
OverflowingBinaryOperator::NoUnsignedWrap:
1716+
return computeShlNSW(*this, Other)
1717+
.intersectWith(computeShlNUW(*this, Other), RangeType);
1718+
default:
1719+
llvm_unreachable("Invalid NoWrapKind");
1720+
}
16351721
}
16361722

16371723
ConstantRange

llvm/test/Transforms/CorrelatedValuePropagation/shl.ll

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ define i8 @test4(i8 %a, i8 %b) {
8686
; CHECK-NEXT: br i1 [[CMP]], label [[BB:%.*]], label [[EXIT:%.*]]
8787
; CHECK: bb:
8888
; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i8 [[A:%.*]], [[B]]
89-
; CHECK-NEXT: ret i8 -1
89+
; CHECK-NEXT: ret i8 [[SHL]]
9090
; CHECK: exit:
9191
; CHECK-NEXT: ret i8 0
9292
;
@@ -105,7 +105,7 @@ exit:
105105
define i8 @test5(i8 %b) {
106106
; CHECK-LABEL: @test5(
107107
; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i8 0, [[B:%.*]]
108-
; CHECK-NEXT: ret i8 [[SHL]]
108+
; CHECK-NEXT: ret i8 0
109109
;
110110
%shl = shl i8 0, %b
111111
ret i8 %shl
@@ -474,3 +474,17 @@ define i1 @shl_nuw_nsw_test4(i32 %x, i32 range(i32 0, 32) %k) {
474474
%cmp = icmp eq i64 %shl, -9223372036854775808
475475
ret i1 %cmp
476476
}
477+
478+
define i1 @shl_nuw_nsw_test5(i32 %x) {
479+
; CHECK-LABEL: @shl_nuw_nsw_test5(
480+
; CHECK-NEXT: entry:
481+
; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i32 768, [[X:%.*]]
482+
; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[SHL]], 1846
483+
; CHECK-NEXT: ret i1 true
484+
;
485+
entry:
486+
%shl = shl nuw nsw i32 768, %x
487+
%add = add nuw i32 %shl, 1846
488+
%cmp = icmp sgt i32 %add, 0
489+
ret i1 %cmp
490+
}

llvm/unittests/IR/ConstantRangeTest.cpp

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,12 @@ static bool CheckNonSignWrappedOnly(const ConstantRange &CR1,
228228
return !CR1.isSignWrappedSet() && !CR2.isSignWrappedSet();
229229
}
230230

231+
static bool
232+
CheckNoSignedWrappedLHSAndNoWrappedRHSOnly(const ConstantRange &CR1,
233+
const ConstantRange &CR2) {
234+
return !CR1.isSignWrappedSet() && !CR2.isWrappedSet();
235+
}
236+
231237
static bool CheckNonWrappedOrSignWrappedOnly(const ConstantRange &CR1,
232238
const ConstantRange &CR2) {
233239
return !CR1.isWrappedSet() && !CR1.isSignWrappedSet() &&
@@ -1506,7 +1512,9 @@ TEST_F(ConstantRangeTest, ShlWithNoWrap) {
15061512
using OBO = OverflowingBinaryOperator;
15071513
TestBinaryOpExhaustive(
15081514
[](const ConstantRange &CR1, const ConstantRange &CR2) {
1509-
return CR1.shlWithNoWrap(CR2, OBO::NoUnsignedWrap);
1515+
ConstantRange Res = CR1.shlWithNoWrap(CR2, OBO::NoUnsignedWrap);
1516+
EXPECT_TRUE(CR1.shl(CR2).contains(Res));
1517+
return Res;
15101518
},
15111519
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
15121520
bool IsOverflow;
@@ -1515,7 +1523,7 @@ TEST_F(ConstantRangeTest, ShlWithNoWrap) {
15151523
return std::nullopt;
15161524
return Res;
15171525
},
1518-
PreferSmallest, CheckCorrectnessOnly);
1526+
PreferSmallest, CheckNonWrappedOnly);
15191527
TestBinaryOpExhaustive(
15201528
[](const ConstantRange &CR1, const ConstantRange &CR2) {
15211529
return CR1.shlWithNoWrap(CR2, OBO::NoSignedWrap);
@@ -1527,7 +1535,7 @@ TEST_F(ConstantRangeTest, ShlWithNoWrap) {
15271535
return std::nullopt;
15281536
return Res;
15291537
},
1530-
PreferSmallest, CheckCorrectnessOnly);
1538+
PreferSmallestSigned, CheckNoSignedWrappedLHSAndNoWrappedRHSOnly);
15311539
TestBinaryOpExhaustive(
15321540
[](const ConstantRange &CR1, const ConstantRange &CR2) {
15331541
return CR1.shlWithNoWrap(CR2, OBO::NoUnsignedWrap | OBO::NoSignedWrap);
@@ -1542,6 +1550,31 @@ TEST_F(ConstantRangeTest, ShlWithNoWrap) {
15421550
return Res1;
15431551
},
15441552
PreferSmallest, CheckCorrectnessOnly);
1553+
1554+
EXPECT_EQ(One.shlWithNoWrap(Full, OBO::NoSignedWrap),
1555+
ConstantRange(APInt(16, 10), APInt(16, 20481)));
1556+
EXPECT_EQ(One.shlWithNoWrap(Full, OBO::NoUnsignedWrap),
1557+
ConstantRange(APInt(16, 10), APInt(16, -24575)));
1558+
EXPECT_EQ(One.shlWithNoWrap(Full, OBO::NoSignedWrap | OBO::NoUnsignedWrap),
1559+
ConstantRange(APInt(16, 10), APInt(16, 20481)));
1560+
ConstantRange NegOne(APInt(16, 0xffff));
1561+
EXPECT_EQ(NegOne.shlWithNoWrap(Full, OBO::NoSignedWrap),
1562+
ConstantRange(APInt(16, -32768), APInt(16, 0)));
1563+
EXPECT_EQ(NegOne.shlWithNoWrap(Full, OBO::NoUnsignedWrap), NegOne);
1564+
EXPECT_EQ(ConstantRange(APInt(16, 768))
1565+
.shlWithNoWrap(Full, OBO::NoSignedWrap | OBO::NoUnsignedWrap),
1566+
ConstantRange(APInt(16, 768), APInt(16, 24577)));
1567+
EXPECT_EQ(Full.shlWithNoWrap(ConstantRange(APInt(16, 1), APInt(16, 16)),
1568+
OBO::NoUnsignedWrap),
1569+
ConstantRange(APInt(16, 0), APInt(16, -1)));
1570+
EXPECT_EQ(ConstantRange(APInt(4, 3), APInt(4, -8))
1571+
.shlWithNoWrap(ConstantRange(APInt(4, 0), APInt(4, 4)),
1572+
OBO::NoSignedWrap),
1573+
ConstantRange(APInt(4, 3), APInt(4, -8)));
1574+
EXPECT_EQ(ConstantRange(APInt(4, -1), APInt(4, 0))
1575+
.shlWithNoWrap(ConstantRange(APInt(4, 1), APInt(4, 4)),
1576+
OBO::NoSignedWrap),
1577+
ConstantRange(APInt(4, -8), APInt(4, -1)));
15451578
}
15461579

15471580
TEST_F(ConstantRangeTest, Lshr) {

0 commit comments

Comments
 (0)