@@ -6351,61 +6351,62 @@ const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
63516351 return getGEPExpr(GEP, IndexExprs);
63526352}
63536353
6354- APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
6354+ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S,
6355+ const Instruction *CtxI) {
63556356 uint64_t BitWidth = getTypeSizeInBits(S->getType());
63566357 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
63576358 return TrailingZeros >= BitWidth
63586359 ? APInt::getZero(BitWidth)
63596360 : APInt::getOneBitSet(BitWidth, TrailingZeros);
63606361 };
6361- auto GetGCDMultiple = [this](const SCEVNAryExpr *N) {
6362+ auto GetGCDMultiple = [this, CtxI ](const SCEVNAryExpr *N) {
63626363 // The result is GCD of all operands results.
6363- APInt Res = getConstantMultiple(N->getOperand(0));
6364+ APInt Res = getConstantMultiple(N->getOperand(0), CtxI );
63646365 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
63656366 Res = APIntOps::GreatestCommonDivisor(
6366- Res, getConstantMultiple(N->getOperand(I)));
6367+ Res, getConstantMultiple(N->getOperand(I), CtxI ));
63676368 return Res;
63686369 };
63696370
63706371 switch (S->getSCEVType()) {
63716372 case scConstant:
63726373 return cast<SCEVConstant>(S)->getAPInt();
63736374 case scPtrToInt:
6374- return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand());
6375+ return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand(), CtxI );
63756376 case scUDivExpr:
63766377 case scVScale:
63776378 return APInt(BitWidth, 1);
63786379 case scTruncate: {
63796380 // Only multiples that are a power of 2 will hold after truncation.
63806381 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6381- uint32_t TZ = getMinTrailingZeros(T->getOperand());
6382+ uint32_t TZ = getMinTrailingZeros(T->getOperand(), CtxI );
63826383 return GetShiftedByZeros(TZ);
63836384 }
63846385 case scZeroExtend: {
63856386 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6386- return getConstantMultiple(Z->getOperand()).zext(BitWidth);
6387+ return getConstantMultiple(Z->getOperand(), CtxI ).zext(BitWidth);
63876388 }
63886389 case scSignExtend: {
63896390 // Only multiples that are a power of 2 will hold after sext.
63906391 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6391- uint32_t TZ = getMinTrailingZeros(E->getOperand());
6392+ uint32_t TZ = getMinTrailingZeros(E->getOperand(), CtxI );
63926393 return GetShiftedByZeros(TZ);
63936394 }
63946395 case scMulExpr: {
63956396 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
63966397 if (M->hasNoUnsignedWrap()) {
63976398 // The result is the product of all operand results.
6398- APInt Res = getConstantMultiple(M->getOperand(0));
6399+ APInt Res = getConstantMultiple(M->getOperand(0), CtxI );
63996400 for (const SCEV *Operand : M->operands().drop_front())
6400- Res = Res * getConstantMultiple(Operand);
6401+ Res = Res * getConstantMultiple(Operand, CtxI );
64016402 return Res;
64026403 }
64036404
64046405 // If there are no wrap guarantees, find the trailing zeros, which is the
64056406 // sum of trailing zeros for all its operands.
64066407 uint32_t TZ = 0;
64076408 for (const SCEV *Operand : M->operands())
6408- TZ += getMinTrailingZeros(Operand);
6409+ TZ += getMinTrailingZeros(Operand, CtxI );
64096410 return GetShiftedByZeros(TZ);
64106411 }
64116412 case scAddExpr:
@@ -6414,9 +6415,9 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
64146415 if (N->hasNoUnsignedWrap())
64156416 return GetGCDMultiple(N);
64166417 // Find the trailing bits, which is the minimum of its operands.
6417- uint32_t TZ = getMinTrailingZeros(N->getOperand(0));
6418+ uint32_t TZ = getMinTrailingZeros(N->getOperand(0), CtxI );
64186419 for (const SCEV *Operand : N->operands().drop_front())
6419- TZ = std::min(TZ, getMinTrailingZeros(Operand));
6420+ TZ = std::min(TZ, getMinTrailingZeros(Operand, CtxI ));
64206421 return GetShiftedByZeros(TZ);
64216422 }
64226423 case scUMaxExpr:
@@ -6429,7 +6430,7 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
64296430 // ask ValueTracking for known bits
64306431 const SCEVUnknown *U = cast<SCEVUnknown>(S);
64316432 unsigned Known =
6432- computeKnownBits(U->getValue(), getDataLayout(), &AC, nullptr , &DT)
6433+ computeKnownBits(U->getValue(), getDataLayout(), &AC, CtxI , &DT)
64336434 .countMinTrailingZeros();
64346435 return GetShiftedByZeros(Known);
64356436 }
@@ -6439,12 +6440,18 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
64396440 llvm_unreachable("Unknown SCEV kind!");
64406441}
64416442
6442- APInt ScalarEvolution::getConstantMultiple(const SCEV *S) {
6443+ APInt ScalarEvolution::getConstantMultiple(const SCEV *S,
6444+ const Instruction *CtxI) {
6445+ // Skip looking up and updating the cache if there is a context instruction,
6446+ // as the result will only be valid in the specified context.
6447+ if (CtxI)
6448+ return getConstantMultipleImpl(S, CtxI);
6449+
64436450 auto I = ConstantMultipleCache.find(S);
64446451 if (I != ConstantMultipleCache.end())
64456452 return I->second;
64466453
6447- APInt Result = getConstantMultipleImpl(S);
6454+ APInt Result = getConstantMultipleImpl(S, CtxI );
64486455 auto InsertPair = ConstantMultipleCache.insert({S, Result});
64496456 assert(InsertPair.second && "Should insert a new key");
64506457 return InsertPair.first->second;
@@ -6455,8 +6462,9 @@ APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) {
64556462 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
64566463}
64576464
6458- uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) {
6459- return std::min(getConstantMultiple(S).countTrailingZeros(),
6465+ uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S,
6466+ const Instruction *CtxI) {
6467+ return std::min(getConstantMultiple(S, CtxI).countTrailingZeros(),
64606468 (unsigned)getTypeSizeInBits(S->getType()));
64616469}
64626470
@@ -10243,8 +10251,7 @@ const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
1024310251static const SCEV *
1024410252SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
1024510253 SmallVectorImpl<const SCEVPredicate *> *Predicates,
10246-
10247- ScalarEvolution &SE) {
10254+ ScalarEvolution &SE, const Loop *L) {
1024810255 uint32_t BW = A.getBitWidth();
1024910256 assert(BW == SE.getTypeSizeInBits(B->getType()));
1025010257 assert(A != 0 && "A must be non-zero.");
@@ -10260,7 +10267,12 @@ SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
1026010267 //
1026110268 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
1026210269 // is not less than multiplicity of this prime factor for D.
10263- if (SE.getMinTrailingZeros(B) < Mult2) {
10270+ unsigned MinTZ = SE.getMinTrailingZeros(B);
10271+ // Try again with the terminator of the loop predecessor for context-specific
10272+ // result, if MinTZ is too small.
10273+ if (MinTZ < Mult2 && L->getLoopPredecessor())
10274+ MinTZ = SE.getMinTrailingZeros(B, L->getLoopPredecessor()->getTerminator());
10275+ if (MinTZ < Mult2) {
1026410276 // Check if we can prove there's no remainder using URem.
1026510277 const SCEV *URem =
1026610278 SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
@@ -10708,7 +10720,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
1070810720 return getCouldNotCompute();
1070910721 const SCEV *E = SolveLinEquationWithOverflow(
1071010722 StepC->getAPInt(), getNegativeSCEV(Start),
10711- AllowPredicates ? &Predicates : nullptr, *this);
10723+ AllowPredicates ? &Predicates : nullptr, *this, L );
1071210724
1071310725 const SCEV *M = E;
1071410726 if (E != getCouldNotCompute()) {
0 commit comments