@@ -6344,61 +6344,62 @@ const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
63446344 return getGEPExpr(GEP, IndexExprs);
63456345}
63466346
6347- APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
6347+ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S,
6348+ const Instruction *CtxI) {
63486349 uint64_t BitWidth = getTypeSizeInBits(S->getType());
63496350 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
63506351 return TrailingZeros >= BitWidth
63516352 ? APInt::getZero(BitWidth)
63526353 : APInt::getOneBitSet(BitWidth, TrailingZeros);
63536354 };
6354- auto GetGCDMultiple = [this](const SCEVNAryExpr *N) {
6355+ auto GetGCDMultiple = [this, CtxI ](const SCEVNAryExpr *N) {
63556356 // The result is GCD of all operands results.
6356- APInt Res = getConstantMultiple(N->getOperand(0));
6357+ APInt Res = getConstantMultiple(N->getOperand(0), CtxI );
63576358 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
63586359 Res = APIntOps::GreatestCommonDivisor(
6359- Res, getConstantMultiple(N->getOperand(I)));
6360+ Res, getConstantMultiple(N->getOperand(I), CtxI ));
63606361 return Res;
63616362 };
63626363
63636364 switch (S->getSCEVType()) {
63646365 case scConstant:
63656366 return cast<SCEVConstant>(S)->getAPInt();
63666367 case scPtrToInt:
6367- return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand());
6368+ return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand(), CtxI );
63686369 case scUDivExpr:
63696370 case scVScale:
63706371 return APInt(BitWidth, 1);
63716372 case scTruncate: {
63726373 // Only multiples that are a power of 2 will hold after truncation.
63736374 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6374- uint32_t TZ = getMinTrailingZeros(T->getOperand());
6375+ uint32_t TZ = getMinTrailingZeros(T->getOperand(), CtxI );
63756376 return GetShiftedByZeros(TZ);
63766377 }
63776378 case scZeroExtend: {
63786379 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6379- return getConstantMultiple(Z->getOperand()).zext(BitWidth);
6380+ return getConstantMultiple(Z->getOperand(), CtxI ).zext(BitWidth);
63806381 }
63816382 case scSignExtend: {
63826383 // Only multiples that are a power of 2 will hold after sext.
63836384 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6384- uint32_t TZ = getMinTrailingZeros(E->getOperand());
6385+ uint32_t TZ = getMinTrailingZeros(E->getOperand(), CtxI );
63856386 return GetShiftedByZeros(TZ);
63866387 }
63876388 case scMulExpr: {
63886389 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
63896390 if (M->hasNoUnsignedWrap()) {
63906391 // The result is the product of all operand results.
6391- APInt Res = getConstantMultiple(M->getOperand(0));
6392+ APInt Res = getConstantMultiple(M->getOperand(0), CtxI );
63926393 for (const SCEV *Operand : M->operands().drop_front())
6393- Res = Res * getConstantMultiple(Operand);
6394+ Res = Res * getConstantMultiple(Operand, CtxI );
63946395 return Res;
63956396 }
63966397
63976398 // If there are no wrap guarantees, find the trailing zeros, which is the
63986399 // sum of trailing zeros for all its operands.
63996400 uint32_t TZ = 0;
64006401 for (const SCEV *Operand : M->operands())
6401- TZ += getMinTrailingZeros(Operand);
6402+ TZ += getMinTrailingZeros(Operand, CtxI );
64026403 return GetShiftedByZeros(TZ);
64036404 }
64046405 case scAddExpr:
@@ -6407,9 +6408,9 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
64076408 if (N->hasNoUnsignedWrap())
64086409 return GetGCDMultiple(N);
64096410 // Find the trailing bits, which is the minimum of its operands.
6410- uint32_t TZ = getMinTrailingZeros(N->getOperand(0));
6411+ uint32_t TZ = getMinTrailingZeros(N->getOperand(0), CtxI );
64116412 for (const SCEV *Operand : N->operands().drop_front())
6412- TZ = std::min(TZ, getMinTrailingZeros(Operand));
6413+ TZ = std::min(TZ, getMinTrailingZeros(Operand, CtxI ));
64136414 return GetShiftedByZeros(TZ);
64146415 }
64156416 case scUMaxExpr:
@@ -6422,7 +6423,7 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
64226423 // ask ValueTracking for known bits
64236424 const SCEVUnknown *U = cast<SCEVUnknown>(S);
64246425 unsigned Known =
6425- computeKnownBits(U->getValue(), getDataLayout(), &AC, nullptr , &DT)
6426+ computeKnownBits(U->getValue(), getDataLayout(), &AC, CtxI , &DT)
64266427 .countMinTrailingZeros();
64276428 return GetShiftedByZeros(Known);
64286429 }
@@ -6432,12 +6433,18 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
64326433 llvm_unreachable("Unknown SCEV kind!");
64336434}
64346435
6435- APInt ScalarEvolution::getConstantMultiple(const SCEV *S) {
6436+ APInt ScalarEvolution::getConstantMultiple(const SCEV *S,
6437+ const Instruction *CtxI) {
6438+ // Skip looking up and updating the cache if there is a context instruction,
6439+ // as the result will only be valid in the specified context.
6440+ if (CtxI)
6441+ return getConstantMultipleImpl(S, CtxI);
6442+
64366443 auto I = ConstantMultipleCache.find(S);
64376444 if (I != ConstantMultipleCache.end())
64386445 return I->second;
64396446
6440- APInt Result = getConstantMultipleImpl(S);
6447+ APInt Result = getConstantMultipleImpl(S, CtxI );
64416448 auto InsertPair = ConstantMultipleCache.insert({S, Result});
64426449 assert(InsertPair.second && "Should insert a new key");
64436450 return InsertPair.first->second;
@@ -6448,8 +6455,9 @@ APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) {
64486455 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
64496456}
64506457
// getMinTrailingZeros: returns the smaller of (a) the trailing-zero count of
// getConstantMultiple(S, CtxI) and (b) the bit width of S's type, so the
// result never exceeds that bit width. CtxI is forwarded unchanged to
// getConstantMultiple. The '-' lines below are the previous single-argument
// form; the '+' lines are the form that takes a context instruction.
6451- uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) {
6452- return std::min(getConstantMultiple(S).countTrailingZeros(),
6458+ uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S,
6459+ const Instruction *CtxI) {
6460+ return std::min(getConstantMultiple(S, CtxI).countTrailingZeros(),
64536461 (unsigned)getTypeSizeInBits(S->getType()));
64546462}
64556463
@@ -10228,8 +10236,7 @@ const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
1022810236static const SCEV *
1022910237SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
1023010238 SmallVectorImpl<const SCEVPredicate *> *Predicates,
10231-
10232- ScalarEvolution &SE) {
10239+ ScalarEvolution &SE, const Loop *L) {
1023310240 uint32_t BW = A.getBitWidth();
1023410241 assert(BW == SE.getTypeSizeInBits(B->getType()));
1023510242 assert(A != 0 && "A must be non-zero.");
@@ -10245,7 +10252,12 @@ SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
1024510252 //
1024610253 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
1024710254 // is not less than multiplicity of this prime factor for D.
10248- if (SE.getMinTrailingZeros(B) < Mult2) {
10255+ unsigned MinTZ = SE.getMinTrailingZeros(B);
10256+ // Try again with the terminator of the loop predecessor for context-specific
10257+ // result, if MinTZ is too small.
10258+ if (MinTZ < Mult2 && L->getLoopPredecessor())
10259+ MinTZ = SE.getMinTrailingZeros(B, L->getLoopPredecessor()->getTerminator());
10260+ if (MinTZ < Mult2) {
1024910261 // Check if we can prove there's no remainder using URem.
1025010262 const SCEV *URem =
1025110263 SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
@@ -10693,7 +10705,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
1069310705 return getCouldNotCompute();
1069410706 const SCEV *E = SolveLinEquationWithOverflow(
1069510707 StepC->getAPInt(), getNegativeSCEV(Start),
10696- AllowPredicates ? &Predicates : nullptr, *this);
10708+ AllowPredicates ? &Predicates : nullptr, *this, L );
1069710709
1069810710 const SCEV *M = E;
1069910711 if (E != getCouldNotCompute()) {
0 commit comments