Skip to content

Commit 037fffd

Browse files
committed
[SCEV] Pass loop pred branch as context instruction to getMinTrailingZ. (llvm#160941)
When computing the backedge taken count, we know that the expression must be valid just before we enter the loop. Using the terminator of the loop predecessor as context instruction for getConstantMultiple, getMinTrailingZeros allows using information from things like alignment assumptions. When a context instruction is used, the result is not cached, as it is only valid at the specific context instruction. Compile-time looks neutral: http://llvm-compile-time-tracker.com/compare.php?from=9be276ec75c087595ebb62fe11b35c1a90371a49&to=745980f5e1c8094ea1293cd145d0ef1390f03029&stat=instructions:u No impact on llvm-opt-benchmark (dtcxzyw/llvm-opt-benchmark#2867), but leads to additonal unrolling in ~90 files across a C/C++ based corpus including LLVM on AArch64 using libc++ (which emits alignment assumptions for things like std::vector::begin). PR: llvm#160941 (cherry picked from commit c7fbe38)
1 parent 921ddad commit 037fffd

File tree

3 files changed

+57
-55
lines changed

3 files changed

+57
-55
lines changed

llvm/include/llvm/Analysis/ScalarEvolution.h

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,10 +1000,14 @@ class ScalarEvolution {
10001000
/// (at every loop iteration). It is, at the same time, the minimum number
10011001
/// of times S is divisible by 2. For example, given {4,+,8} it returns 2.
10021002
/// If S is guaranteed to be 0, it returns the bitwidth of S.
1003-
LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S);
1003+
/// If \p CtxI is not nullptr, return a constant multiple valid at \p CtxI.
1004+
LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S,
1005+
const Instruction *CtxI = nullptr);
10041006

1005-
/// Returns the max constant multiple of S.
1006-
LLVM_ABI APInt getConstantMultiple(const SCEV *S);
1007+
/// Returns the max constant multiple of S. If \p CtxI is not nullptr, return
1008+
/// a constant multiple valid at \p CtxI.
1009+
LLVM_ABI APInt getConstantMultiple(const SCEV *S,
1010+
const Instruction *CtxI = nullptr);
10071011

10081012
// Returns the max constant multiple of S. If S is exactly 0, return 1.
10091013
LLVM_ABI APInt getNonZeroConstantMultiple(const SCEV *S);
@@ -1525,8 +1529,10 @@ class ScalarEvolution {
15251529
/// Return the Value set from which the SCEV expr is generated.
15261530
ArrayRef<Value *> getSCEVValues(const SCEV *S);
15271531

1528-
/// Private helper method for the getConstantMultiple method.
1529-
APInt getConstantMultipleImpl(const SCEV *S);
1532+
/// Private helper method for the getConstantMultiple method. If \p CtxI is
1533+
/// not nullptr, return a constant multiple valid at \p CtxI.
1534+
APInt getConstantMultipleImpl(const SCEV *S,
1535+
const Instruction *Ctx = nullptr);
15301536

15311537
/// Information about the number of times a particular loop exit may be
15321538
/// reached before exiting the loop.

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6344,61 +6344,62 @@ const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
63446344
return getGEPExpr(GEP, IndexExprs);
63456345
}
63466346

6347-
APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
6347+
APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S,
6348+
const Instruction *CtxI) {
63486349
uint64_t BitWidth = getTypeSizeInBits(S->getType());
63496350
auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
63506351
return TrailingZeros >= BitWidth
63516352
? APInt::getZero(BitWidth)
63526353
: APInt::getOneBitSet(BitWidth, TrailingZeros);
63536354
};
6354-
auto GetGCDMultiple = [this](const SCEVNAryExpr *N) {
6355+
auto GetGCDMultiple = [this, CtxI](const SCEVNAryExpr *N) {
63556356
// The result is GCD of all operands results.
6356-
APInt Res = getConstantMultiple(N->getOperand(0));
6357+
APInt Res = getConstantMultiple(N->getOperand(0), CtxI);
63576358
for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
63586359
Res = APIntOps::GreatestCommonDivisor(
6359-
Res, getConstantMultiple(N->getOperand(I)));
6360+
Res, getConstantMultiple(N->getOperand(I), CtxI));
63606361
return Res;
63616362
};
63626363

63636364
switch (S->getSCEVType()) {
63646365
case scConstant:
63656366
return cast<SCEVConstant>(S)->getAPInt();
63666367
case scPtrToInt:
6367-
return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand());
6368+
return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand(), CtxI);
63686369
case scUDivExpr:
63696370
case scVScale:
63706371
return APInt(BitWidth, 1);
63716372
case scTruncate: {
63726373
// Only multiples that are a power of 2 will hold after truncation.
63736374
const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6374-
uint32_t TZ = getMinTrailingZeros(T->getOperand());
6375+
uint32_t TZ = getMinTrailingZeros(T->getOperand(), CtxI);
63756376
return GetShiftedByZeros(TZ);
63766377
}
63776378
case scZeroExtend: {
63786379
const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6379-
return getConstantMultiple(Z->getOperand()).zext(BitWidth);
6380+
return getConstantMultiple(Z->getOperand(), CtxI).zext(BitWidth);
63806381
}
63816382
case scSignExtend: {
63826383
// Only multiples that are a power of 2 will hold after sext.
63836384
const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6384-
uint32_t TZ = getMinTrailingZeros(E->getOperand());
6385+
uint32_t TZ = getMinTrailingZeros(E->getOperand(), CtxI);
63856386
return GetShiftedByZeros(TZ);
63866387
}
63876388
case scMulExpr: {
63886389
const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
63896390
if (M->hasNoUnsignedWrap()) {
63906391
// The result is the product of all operand results.
6391-
APInt Res = getConstantMultiple(M->getOperand(0));
6392+
APInt Res = getConstantMultiple(M->getOperand(0), CtxI);
63926393
for (const SCEV *Operand : M->operands().drop_front())
6393-
Res = Res * getConstantMultiple(Operand);
6394+
Res = Res * getConstantMultiple(Operand, CtxI);
63946395
return Res;
63956396
}
63966397

63976398
// If there are no wrap guarentees, find the trailing zeros, which is the
63986399
// sum of trailing zeros for all its operands.
63996400
uint32_t TZ = 0;
64006401
for (const SCEV *Operand : M->operands())
6401-
TZ += getMinTrailingZeros(Operand);
6402+
TZ += getMinTrailingZeros(Operand, CtxI);
64026403
return GetShiftedByZeros(TZ);
64036404
}
64046405
case scAddExpr:
@@ -6407,9 +6408,9 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
64076408
if (N->hasNoUnsignedWrap())
64086409
return GetGCDMultiple(N);
64096410
// Find the trailing bits, which is the minimum of its operands.
6410-
uint32_t TZ = getMinTrailingZeros(N->getOperand(0));
6411+
uint32_t TZ = getMinTrailingZeros(N->getOperand(0), CtxI);
64116412
for (const SCEV *Operand : N->operands().drop_front())
6412-
TZ = std::min(TZ, getMinTrailingZeros(Operand));
6413+
TZ = std::min(TZ, getMinTrailingZeros(Operand, CtxI));
64136414
return GetShiftedByZeros(TZ);
64146415
}
64156416
case scUMaxExpr:
@@ -6422,7 +6423,7 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
64226423
// ask ValueTracking for known bits
64236424
const SCEVUnknown *U = cast<SCEVUnknown>(S);
64246425
unsigned Known =
6425-
computeKnownBits(U->getValue(), getDataLayout(), &AC, nullptr, &DT)
6426+
computeKnownBits(U->getValue(), getDataLayout(), &AC, CtxI, &DT)
64266427
.countMinTrailingZeros();
64276428
return GetShiftedByZeros(Known);
64286429
}
@@ -6432,12 +6433,18 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
64326433
llvm_unreachable("Unknown SCEV kind!");
64336434
}
64346435

6435-
APInt ScalarEvolution::getConstantMultiple(const SCEV *S) {
6436+
APInt ScalarEvolution::getConstantMultiple(const SCEV *S,
6437+
const Instruction *CtxI) {
6438+
// Skip looking up and updating the cache if there is a context instruction,
6439+
// as the result will only be valid in the specified context.
6440+
if (CtxI)
6441+
return getConstantMultipleImpl(S, CtxI);
6442+
64366443
auto I = ConstantMultipleCache.find(S);
64376444
if (I != ConstantMultipleCache.end())
64386445
return I->second;
64396446

6440-
APInt Result = getConstantMultipleImpl(S);
6447+
APInt Result = getConstantMultipleImpl(S, CtxI);
64416448
auto InsertPair = ConstantMultipleCache.insert({S, Result});
64426449
assert(InsertPair.second && "Should insert a new key");
64436450
return InsertPair.first->second;
@@ -6448,8 +6455,9 @@ APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) {
64486455
return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
64496456
}
64506457

6451-
uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) {
6452-
return std::min(getConstantMultiple(S).countTrailingZeros(),
6458+
uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S,
6459+
const Instruction *CtxI) {
6460+
return std::min(getConstantMultiple(S, CtxI).countTrailingZeros(),
64536461
(unsigned)getTypeSizeInBits(S->getType()));
64546462
}
64556463

@@ -10228,8 +10236,7 @@ const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
1022810236
static const SCEV *
1022910237
SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
1023010238
SmallVectorImpl<const SCEVPredicate *> *Predicates,
10231-
10232-
ScalarEvolution &SE) {
10239+
ScalarEvolution &SE, const Loop *L) {
1023310240
uint32_t BW = A.getBitWidth();
1023410241
assert(BW == SE.getTypeSizeInBits(B->getType()));
1023510242
assert(A != 0 && "A must be non-zero.");
@@ -10245,7 +10252,12 @@ SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
1024510252
//
1024610253
// B is divisible by D if and only if the multiplicity of prime factor 2 for B
1024710254
// is not less than multiplicity of this prime factor for D.
10248-
if (SE.getMinTrailingZeros(B) < Mult2) {
10255+
unsigned MinTZ = SE.getMinTrailingZeros(B);
10256+
// Try again with the terminator of the loop predecessor for context-specific
10257+
// result, if MinTZ s too small.
10258+
if (MinTZ < Mult2 && L->getLoopPredecessor())
10259+
MinTZ = SE.getMinTrailingZeros(B, L->getLoopPredecessor()->getTerminator());
10260+
if (MinTZ < Mult2) {
1024910261
// Check if we can prove there's no remainder using URem.
1025010262
const SCEV *URem =
1025110263
SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
@@ -10693,7 +10705,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
1069310705
return getCouldNotCompute();
1069410706
const SCEV *E = SolveLinEquationWithOverflow(
1069510707
StepC->getAPInt(), getNegativeSCEV(Start),
10696-
AllowPredicates ? &Predicates : nullptr, *this);
10708+
AllowPredicates ? &Predicates : nullptr, *this, L);
1069710709

1069810710
const SCEV *M = E;
1069910711
if (E != getCouldNotCompute()) {

llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll

Lines changed: 12 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -615,22 +615,14 @@ define void @test_ptrs_aligned_by_4_via_assumption(ptr %start, ptr %end) {
615615
; CHECK-LABEL: 'test_ptrs_aligned_by_4_via_assumption'
616616
; CHECK-NEXT: Classifying expressions for: @test_ptrs_aligned_by_4_via_assumption
617617
; CHECK-NEXT: %iv = phi ptr [ %start, %entry ], [ %iv.next, %loop ]
618-
; CHECK-NEXT: --> {%start,+,4}<%loop> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %loop: Computable }
618+
; CHECK-NEXT: --> {%start,+,4}<%loop> U: full-set S: full-set Exits: ((4 * ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4))<nuw> + %start) LoopDispositions: { %loop: Computable }
619619
; CHECK-NEXT: %iv.next = getelementptr i8, ptr %iv, i64 4
620-
; CHECK-NEXT: --> {(4 + %start),+,4}<%loop> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %loop: Computable }
620+
; CHECK-NEXT: --> {(4 + %start),+,4}<%loop> U: full-set S: full-set Exits: (4 + (4 * ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4))<nuw> + %start) LoopDispositions: { %loop: Computable }
621621
; CHECK-NEXT: Determining loop execution counts for: @test_ptrs_aligned_by_4_via_assumption
622-
; CHECK-NEXT: Loop %loop: Unpredictable backedge-taken count.
623-
; CHECK-NEXT: Loop %loop: Unpredictable constant max backedge-taken count.
624-
; CHECK-NEXT: Loop %loop: Unpredictable symbolic max backedge-taken count.
625-
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
626-
; CHECK-NEXT: Predicates:
627-
; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0
628-
; CHECK-NEXT: Loop %loop: Predicated constant max backedge-taken count is i64 4611686018427387903
629-
; CHECK-NEXT: Predicates:
630-
; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0
631-
; CHECK-NEXT: Loop %loop: Predicated symbolic max backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
632-
; CHECK-NEXT: Predicates:
633-
; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0
622+
; CHECK-NEXT: Loop %loop: backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
623+
; CHECK-NEXT: Loop %loop: constant max backedge-taken count is i64 4611686018427387903
624+
; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
625+
; CHECK-NEXT: Loop %loop: Trip multiple is 1
634626
;
635627
entry:
636628
call void @llvm.assume(i1 true) [ "align"(ptr %start, i64 4) ]
@@ -652,22 +644,14 @@ define void @test_ptrs_aligned_by_8_via_assumption(ptr %start, ptr %end) {
652644
; CHECK-LABEL: 'test_ptrs_aligned_by_8_via_assumption'
653645
; CHECK-NEXT: Classifying expressions for: @test_ptrs_aligned_by_8_via_assumption
654646
; CHECK-NEXT: %iv = phi ptr [ %start, %entry ], [ %iv.next, %loop ]
655-
; CHECK-NEXT: --> {%start,+,4}<%loop> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %loop: Computable }
647+
; CHECK-NEXT: --> {%start,+,4}<%loop> U: full-set S: full-set Exits: ((4 * ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4))<nuw> + %start) LoopDispositions: { %loop: Computable }
656648
; CHECK-NEXT: %iv.next = getelementptr i8, ptr %iv, i64 4
657-
; CHECK-NEXT: --> {(4 + %start),+,4}<%loop> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %loop: Computable }
649+
; CHECK-NEXT: --> {(4 + %start),+,4}<%loop> U: full-set S: full-set Exits: (4 + (4 * ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4))<nuw> + %start) LoopDispositions: { %loop: Computable }
658650
; CHECK-NEXT: Determining loop execution counts for: @test_ptrs_aligned_by_8_via_assumption
659-
; CHECK-NEXT: Loop %loop: Unpredictable backedge-taken count.
660-
; CHECK-NEXT: Loop %loop: Unpredictable constant max backedge-taken count.
661-
; CHECK-NEXT: Loop %loop: Unpredictable symbolic max backedge-taken count.
662-
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
663-
; CHECK-NEXT: Predicates:
664-
; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0
665-
; CHECK-NEXT: Loop %loop: Predicated constant max backedge-taken count is i64 4611686018427387903
666-
; CHECK-NEXT: Predicates:
667-
; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0
668-
; CHECK-NEXT: Loop %loop: Predicated symbolic max backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
669-
; CHECK-NEXT: Predicates:
670-
; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0
651+
; CHECK-NEXT: Loop %loop: backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
652+
; CHECK-NEXT: Loop %loop: constant max backedge-taken count is i64 4611686018427387903
653+
; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
654+
; CHECK-NEXT: Loop %loop: Trip multiple is 1
671655
;
672656
entry:
673657
call void @llvm.assume(i1 true) [ "align"(ptr %start, i64 8) ]

0 commit comments

Comments
 (0)