@@ -15557,6 +15557,123 @@ void ScalarEvolution::LoopGuards::collectFromPHI(
1555715557 }
1555815558}
1555915559
15560+ // Checks whether Expr is a non-negative constant, and Divisor is a positive
15561+ // constant, and returns their APInt in ExprVal and in DivisorVal.
15562+ static bool getNonNegExprAndPosDivisor(const SCEV *Expr, const SCEV *Divisor,
15563+ APInt &ExprVal, APInt &DivisorVal) {
15564+ auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
15565+ auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15566+ if (!ConstExpr || !ConstDivisor)
15567+ return false;
15568+ ExprVal = ConstExpr->getAPInt();
15569+ DivisorVal = ConstDivisor->getAPInt();
15570+ return ExprVal.isNonNegative() && !DivisorVal.isNonPositive();
15571+ }
15572+
15573+ // Return a new SCEV that modifies \p Expr to the closest number divisible by
15574+ // \p Divisor and less than or equal to Expr.
15575+ // For now, only handle constant Expr and Divisor.
15576+ static const SCEV *getPreviousSCEVDivisibleByDivisor(const SCEV *Expr,
15577+ const SCEV *Divisor,
15578+ ScalarEvolution &SE) {
15579+ APInt ExprVal;
15580+ APInt DivisorVal;
15581+ if (!getNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15582+ return Expr;
15583+ APInt Rem = ExprVal.urem(DivisorVal);
15584+ // return the SCEV: Expr - Expr % Divisor
15585+ return SE.getConstant(ExprVal - Rem);
15586+ }
15587+
15588+ // Return a new SCEV that modifies \p Expr to the closest number divisible by
15589+ // \p Divisor and greater than or equal to Expr.
15590+ // For now, only handle constant Expr and Divisor.
15591+ static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
15592+ const SCEV *Divisor,
15593+ ScalarEvolution &SE) {
15594+ APInt ExprVal;
15595+ APInt DivisorVal;
15596+ if (!getNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15597+ return Expr;
15598+ APInt Rem = ExprVal.urem(DivisorVal);
15599+ if (!Rem.isZero())
15600+ // return the SCEV: Expr + Divisor - Expr % Divisor
15601+ return SE.getConstant(ExprVal + DivisorVal - Rem);
15602+ return Expr;
15603+ }
15604+
15605+ static bool collectDivisibilityInformation(
15606+ ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS,
15607+ DenseMap<const SCEV *, const SCEV *> &DivInfo,
15608+ DenseMap<const SCEV *, const SCEV *> &Multiples, ScalarEvolution &SE) {
15609+ // If we have LHS == 0, check if LHS is computing a property of some unknown
15610+ // SCEV %v which we can rewrite %v to express explicitly.
15611+ if (Predicate != CmpInst::ICMP_EQ || !match(RHS, m_scev_Zero()))
15612+ return false;
15613+ // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15614+ // explicitly express that.
15615+ const SCEV *URemLHS = nullptr;
15616+ const SCEV *URemRHS = nullptr;
15617+ if (!SE.matchURem(LHS, URemLHS, URemRHS))
15618+ return false;
15619+ if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
15620+ const auto *Multiple = SE.getMulExpr(SE.getUDivExpr(LHS, URemRHS), URemRHS);
15621+ DivInfo[LHSUnknown] = Multiple;
15622+ Multiples[LHSUnknown] = URemRHS;
15623+ return true;
15624+ }
15625+ return false;
15626+ }
15627+
15628+ // Check if the condition is a divisibility guard (A % B == 0).
15629+ static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS,
15630+ ScalarEvolution &SE) {
15631+ const SCEV *X, *Y;
15632+ return SE.matchURem(LHS, X, Y) && RHS->isZero();
15633+ }
15634+
15635+ // Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15636+ // recursively. This is done by aligning up/down the constant value to the
15637+ // Divisor.
15638+ static const SCEV *applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr,
15639+ const SCEV *Divisor,
15640+ ScalarEvolution &SE) {
15641+ // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15642+ // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15643+ // the non-constant operand and in \p LHS the constant operand.
15644+ auto IsMinMaxSCEVWithNonNegativeConstant =
15645+ [](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15646+ const SCEV *&RHS) {
15647+ if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15648+ if (MinMax->getNumOperands() != 2)
15649+ return false;
15650+ if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15651+ if (C->getAPInt().isNegative())
15652+ return false;
15653+ SCTy = MinMax->getSCEVType();
15654+ LHS = MinMax->getOperand(0);
15655+ RHS = MinMax->getOperand(1);
15656+ return true;
15657+ }
15658+ }
15659+ return false;
15660+ };
15661+
15662+ const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15663+ SCEVTypes SCTy;
15664+ if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15665+ MinMaxRHS))
15666+ return MinMaxExpr;
15667+ auto IsMin = isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15668+ assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!");
15669+ auto *DivisibleExpr =
15670+ IsMin ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE)
15671+ : getNextSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE);
15672+ SmallVector<const SCEV *> Ops = {
15673+ applyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor, SE), DivisibleExpr};
15674+ return SE.getMinMaxExpr(SCTy, Ops);
15675+ }
15676+
1556015677void ScalarEvolution::LoopGuards::collectFromBlock(
1556115678 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
1556215679 const BasicBlock *Block, const BasicBlock *Pred,
@@ -15567,19 +15684,14 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1556715684 SmallVector<const SCEV *> ExprsToRewrite;
1556815685 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
1556915686 const SCEV *RHS,
15570- DenseMap<const SCEV *, const SCEV *>
15571- &RewriteMap) {
15687+ DenseMap<const SCEV *, const SCEV *> &RewriteMap,
15688+ const DenseMap<const SCEV *, const SCEV *>
15689+ &DivInfo) {
1557215690 // WARNING: It is generally unsound to apply any wrap flags to the proposed
1557315691 // replacement SCEV which isn't directly implied by the structure of that
1557415692 // SCEV. In particular, using contextual facts to imply flags is *NOT*
1557515693 // legal. See the scoping rules for flags in the header to understand why.
1557615694
15577- // If LHS is a constant, apply information to the other expression.
15578- if (isa<SCEVConstant>(LHS)) {
15579- std::swap(LHS, RHS);
15580- Predicate = CmpInst::getSwappedPredicate(Predicate);
15581- }
15582-
1558315695 // Check for a condition of the form (-C1 + X < C2). InstCombine will
1558415696 // create this form when combining two checks of the form (X u< C2 + C1) and
1558515697 // (X >=u C1).
@@ -15612,115 +15724,6 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1561215724 if (MatchRangeCheckIdiom())
1561315725 return;
1561415726
15615- // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15616- // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15617- // the non-constant operand and in \p LHS the constant operand.
15618- auto IsMinMaxSCEVWithNonNegativeConstant =
15619- [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15620- const SCEV *&RHS) {
15621- if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15622- if (MinMax->getNumOperands() != 2)
15623- return false;
15624- if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15625- if (C->getAPInt().isNegative())
15626- return false;
15627- SCTy = MinMax->getSCEVType();
15628- LHS = MinMax->getOperand(0);
15629- RHS = MinMax->getOperand(1);
15630- return true;
15631- }
15632- }
15633- return false;
15634- };
15635-
15636- // Checks whether Expr is a non-negative constant, and Divisor is a positive
15637- // constant, and returns their APInt in ExprVal and in DivisorVal.
15638- auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor,
15639- APInt &ExprVal, APInt &DivisorVal) {
15640- auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
15641- auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15642- if (!ConstExpr || !ConstDivisor)
15643- return false;
15644- ExprVal = ConstExpr->getAPInt();
15645- DivisorVal = ConstDivisor->getAPInt();
15646- return ExprVal.isNonNegative() && !DivisorVal.isNonPositive();
15647- };
15648-
15649- // Return a new SCEV that modifies \p Expr to the closest number divides by
15650- // \p Divisor and greater or equal than Expr.
15651- // For now, only handle constant Expr and Divisor.
15652- auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
15653- const SCEV *Divisor) {
15654- APInt ExprVal;
15655- APInt DivisorVal;
15656- if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15657- return Expr;
15658- APInt Rem = ExprVal.urem(DivisorVal);
15659- if (!Rem.isZero())
15660- // return the SCEV: Expr + Divisor - Expr % Divisor
15661- return SE.getConstant(ExprVal + DivisorVal - Rem);
15662- return Expr;
15663- };
15664-
15665- // Return a new SCEV that modifies \p Expr to the closest number divides by
15666- // \p Divisor and less or equal than Expr.
15667- // For now, only handle constant Expr and Divisor.
15668- auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
15669- const SCEV *Divisor) {
15670- APInt ExprVal;
15671- APInt DivisorVal;
15672- if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15673- return Expr;
15674- APInt Rem = ExprVal.urem(DivisorVal);
15675- // return the SCEV: Expr - Expr % Divisor
15676- return SE.getConstant(ExprVal - Rem);
15677- };
15678-
15679- // Apply divisibilty by \p Divisor on MinMaxExpr with constant values,
15680- // recursively. This is done by aligning up/down the constant value to the
15681- // Divisor.
15682- std::function<const SCEV *(const SCEV *, const SCEV *)>
15683- ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
15684- const SCEV *Divisor) {
15685- const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15686- SCEVTypes SCTy;
15687- if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15688- MinMaxRHS))
15689- return MinMaxExpr;
15690- auto IsMin =
15691- isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15692- assert(SE.isKnownNonNegative(MinMaxLHS) &&
15693- "Expected non-negative operand!");
15694- auto *DivisibleExpr =
15695- IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
15696- : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
15697- SmallVector<const SCEV *> Ops = {
15698- ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
15699- return SE.getMinMaxExpr(SCTy, Ops);
15700- };
15701-
15702- // If we have LHS == 0, check if LHS is computing a property of some unknown
15703- // SCEV %v which we can rewrite %v to express explicitly.
15704- if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
15705- // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15706- // explicitly express that.
15707- const SCEV *URemLHS = nullptr;
15708- const SCEV *URemRHS = nullptr;
15709- if (SE.matchURem(LHS, URemLHS, URemRHS)) {
15710- if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
15711- auto I = RewriteMap.find(LHSUnknown);
15712- const SCEV *RewrittenLHS =
15713- I != RewriteMap.end() ? I->second : LHSUnknown;
15714- RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15715- const auto *Multiple =
15716- SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15717- RewriteMap[LHSUnknown] = Multiple;
15718- ExprsToRewrite.push_back(LHSUnknown);
15719- return;
15720- }
15721- }
15722- }
15723-
1572415727 // Do not apply information for constants or if RHS contains an AddRec.
1572515728 if (isa<SCEVConstant>(LHS) || SE.containsAddRecurrence(RHS))
1572615729 return;
@@ -15751,7 +15754,11 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1575115754
1575215755 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
1575315756 const SCEV *DividesBy = nullptr;
15754- const APInt &Multiple = SE.getConstantMultiple(RewrittenLHS);
15757+ // Apply divisibility information when computing the constant multiple.
15758+ LoopGuards DivGuards(SE);
15759+ DivGuards.RewriteMap = DivInfo;
15760+ const APInt &Multiple =
15761+ SE.getConstantMultiple(DivGuards.rewrite(RewrittenLHS));
1575515762 if (!Multiple.isOne())
1575615763 DividesBy = SE.getConstant(Multiple);
1575715764
@@ -15775,21 +15782,23 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1577515782 [[fallthrough]];
1577615783 case CmpInst::ICMP_SLT: {
1577715784 RHS = SE.getMinusSCEV(RHS, One);
15778- RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15785+ RHS = DividesBy ? getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE)
15786+ : RHS;
1577915787 break;
1578015788 }
1578115789 case CmpInst::ICMP_UGT:
1578215790 case CmpInst::ICMP_SGT:
1578315791 RHS = SE.getAddExpr(RHS, One);
15784- RHS = DividesBy ? GetNextSCEVDividesByDivisor (RHS, DividesBy) : RHS;
15792+ RHS = DividesBy ? getNextSCEVDivisibleByDivisor (RHS, DividesBy, SE ) : RHS;
1578515793 break;
1578615794 case CmpInst::ICMP_ULE:
1578715795 case CmpInst::ICMP_SLE:
15788- RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15796+ RHS = DividesBy ? getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE)
15797+ : RHS;
1578915798 break;
1579015799 case CmpInst::ICMP_UGE:
1579115800 case CmpInst::ICMP_SGE:
15792- RHS = DividesBy ? GetNextSCEVDividesByDivisor (RHS, DividesBy) : RHS;
15801+ RHS = DividesBy ? getNextSCEVDivisibleByDivisor (RHS, DividesBy, SE ) : RHS;
1579315802 break;
1579415803 default:
1579515804 break;
@@ -15843,7 +15852,8 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1584315852 case CmpInst::ICMP_NE:
1584415853 if (match(RHS, m_scev_Zero())) {
1584515854 const SCEV *OneAlignedUp =
15846- DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
15855+ DividesBy ? getNextSCEVDivisibleByDivisor(One, DividesBy, SE)
15856+ : One;
1584715857 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
1584815858 }
1584915859 break;
@@ -15916,8 +15926,11 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1591615926
1591715927 // Now apply the information from the collected conditions to
1591815928 // Guards.RewriteMap. Conditions are processed in reverse order, so the
15919- // earliest conditions is processed first. This ensures the SCEVs with the
15929+ // earliest conditions is processed first, except guards with divisibility
15930+ // information, which are moved to the back. This ensures the SCEVs with the
1592015931 // shortest dependency chains are constructed first.
15932+ SmallVector<std::tuple<CmpInst::Predicate, const SCEV *, const SCEV *>>
15933+ GuardsToProcess;
1592115934 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
1592215935 SmallVector<Value *, 8> Worklist;
1592315936 SmallPtrSet<Value *, 8> Visited;
@@ -15932,7 +15945,12 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1593215945 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
1593315946 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
1593415947 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
15935- CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap);
15948+ // If LHS is a constant, apply information to the other expression.
15949+ if (isa<SCEVConstant>(LHS)) {
15950+ std::swap(LHS, RHS);
15951+ Predicate = CmpInst::getSwappedPredicate(Predicate);
15952+ }
15953+ GuardsToProcess.emplace_back(Predicate, LHS, RHS);
1593615954 continue;
1593715955 }
1593815956
@@ -15945,6 +15963,28 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1594515963 }
1594615964 }
1594715965
15966+ // Process divisibility guards in reverse order to populate DivInfo early.
15967+ DenseMap<const SCEV *, const SCEV *> Multiples;
15968+ DenseMap<const SCEV *, const SCEV *> DivInfo;
15969+ for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) {
15970+ if (!isDivisibilityGuard(LHS, RHS, SE))
15971+ continue;
15972+ collectDivisibilityInformation(Predicate, LHS, RHS, DivInfo, Multiples, SE);
15973+ }
15974+
15975+ for (const auto &[Predicate, LHS, RHS] : GuardsToProcess)
15976+ CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap, DivInfo);
15977+
15978+ // Apply divisibility information last. This ensures it is applied to the
15979+ // outermost expression after other rewrites for the given value.
15980+ for (const auto &[K, V] : Multiples) {
15981+ Guards.RewriteMap[K] = SE.getMulExpr(
15982+ SE.getUDivExpr(applyDivisibilityOnMinMaxExpr(Guards.rewrite(K), V, SE),
15983+ V),
15984+ V);
15985+ ExprsToRewrite.push_back(K);
15986+ }
15987+
1594815988 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
1594915989 // the replacement expressions are contained in the ranges of the replaced
1595015990 // expressions.
0 commit comments