Skip to content

Commit 956297a

Browse files
committed
[SCEV] Improve handling of divisibility information from loop guards.
At the moment, the effectiveness of guards that contain divisibility information (A % B == 0) depends on the order of the conditions. This patch makes using divisibility information independent of the order, by collecting and applying the divisibility information separately. We first collect all conditions in a vector, then collect the divisibility information from all guards. When processing other guards, we apply divisibility info collected earlier. After all guards have been processed, we add the divisibility info, rewriting the existing rewrite. This ensures we apply the divisibility info to the largest rewrite expression. This helps to improve results in a few cases, one in dtcxzyw/llvm-opt-benchmark#2921 and another one in a different large C/C++ based IR corpus.
1 parent 8785276 commit 956297a

File tree

3 files changed

+171
-132
lines changed

3 files changed

+171
-132
lines changed

llvm/include/llvm/Analysis/ScalarEvolution.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -763,6 +763,10 @@ class ScalarEvolution {
763763
getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
764764
bool Sequential = false);
765765

766+
/// Try to match the pattern generated by getURemExpr(A, B). If successful,
767+
/// Assign A and B to LHS and RHS, respectively.
768+
LLVM_ABI bool matchURem(const SCEV *Expr, const SCEV *&LHS, const SCEV *&RHS);
769+
766770
/// Transitively follow the chain of pointer-type operands until reaching a
767771
/// SCEV that does not have a single pointer operand. This returns a
768772
/// SCEVUnknown pointer for well-formed pointer-type expressions, but corner
@@ -2316,10 +2320,6 @@ class ScalarEvolution {
23162320
/// an add rec on said loop.
23172321
void getUsedLoops(const SCEV *S, SmallPtrSetImpl<const Loop *> &LoopsUsed);
23182322

2319-
/// Try to match the pattern generated by getURemExpr(A, B). If successful,
2320-
/// Assign A and B to LHS and RHS, respectively.
2321-
LLVM_ABI bool matchURem(const SCEV *Expr, const SCEV *&LHS, const SCEV *&RHS);
2322-
23232323
/// Look for a SCEV expression with type `SCEVType` and operands `Ops` in
23242324
/// `UniqueSCEVs`. Return if found, else nullptr.
23252325
SCEV *findExistingSCEVInCache(SCEVTypes SCEVType, ArrayRef<const SCEV *> Ops);

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 165 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -15557,6 +15557,123 @@ void ScalarEvolution::LoopGuards::collectFromPHI(
1555715557
}
1555815558
}
1555915559

15560+
// Checks whether Expr is a non-negative constant, and Divisor is a positive
15561+
// constant, and returns their APInt in ExprVal and in DivisorVal.
15562+
static bool getNonNegExprAndPosDivisor(const SCEV *Expr, const SCEV *Divisor,
15563+
APInt &ExprVal, APInt &DivisorVal) {
15564+
auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
15565+
auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15566+
if (!ConstExpr || !ConstDivisor)
15567+
return false;
15568+
ExprVal = ConstExpr->getAPInt();
15569+
DivisorVal = ConstDivisor->getAPInt();
15570+
return ExprVal.isNonNegative() && !DivisorVal.isNonPositive();
15571+
}
15572+
15573+
// Return a new SCEV that modifies \p Expr to the closest number divisible by
15574+
// \p Divisor and less than or equal to Expr.
15575+
// For now, only handle constant Expr and Divisor.
15576+
static const SCEV *getPreviousSCEVDivisibleByDivisor(const SCEV *Expr,
15577+
const SCEV *Divisor,
15578+
ScalarEvolution &SE) {
15579+
APInt ExprVal;
15580+
APInt DivisorVal;
15581+
if (!getNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15582+
return Expr;
15583+
APInt Rem = ExprVal.urem(DivisorVal);
15584+
// return the SCEV: Expr - Expr % Divisor
15585+
return SE.getConstant(ExprVal - Rem);
15586+
}
15587+
15588+
// Return a new SCEV that modifies \p Expr to the closest number divisible by
15589+
// \p Divisor and greater than or equal to Expr.
15590+
// For now, only handle constant Expr and Divisor.
15591+
static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
15592+
const SCEV *Divisor,
15593+
ScalarEvolution &SE) {
15594+
APInt ExprVal;
15595+
APInt DivisorVal;
15596+
if (!getNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15597+
return Expr;
15598+
APInt Rem = ExprVal.urem(DivisorVal);
15599+
if (!Rem.isZero())
15600+
// return the SCEV: Expr + Divisor - Expr % Divisor
15601+
return SE.getConstant(ExprVal + DivisorVal - Rem);
15602+
return Expr;
15603+
}
15604+
15605+
static bool collectDivisibilityInformation(
15606+
ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS,
15607+
DenseMap<const SCEV *, const SCEV *> &DivInfo,
15608+
DenseMap<const SCEV *, const SCEV *> &Multiples, ScalarEvolution &SE) {
15609+
// If we have LHS == 0, check if LHS is computing a property of some unknown
15610+
// SCEV %v which we can rewrite %v to express explicitly.
15611+
if (Predicate != CmpInst::ICMP_EQ || !match(RHS, m_scev_Zero()))
15612+
return false;
15613+
// If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15614+
// explicitly express that.
15615+
const SCEV *URemLHS = nullptr;
15616+
const SCEV *URemRHS = nullptr;
15617+
if (!SE.matchURem(LHS, URemLHS, URemRHS))
15618+
return false;
15619+
if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
15620+
const auto *Multiple = SE.getMulExpr(SE.getUDivExpr(LHS, URemRHS), URemRHS);
15621+
DivInfo[LHSUnknown] = Multiple;
15622+
Multiples[LHSUnknown] = URemRHS;
15623+
return true;
15624+
}
15625+
return false;
15626+
}
15627+
15628+
// Check if the condition is a divisibility guard (A % B == 0).
15629+
static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS,
15630+
ScalarEvolution &SE) {
15631+
const SCEV *X, *Y;
15632+
return SE.matchURem(LHS, X, Y) && RHS->isZero();
15633+
}
15634+
15635+
// Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15636+
// recursively. This is done by aligning up/down the constant value to the
15637+
// Divisor.
15638+
static const SCEV *applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr,
15639+
const SCEV *Divisor,
15640+
ScalarEvolution &SE) {
15641+
// Return true if \p Expr is a MinMax SCEV expression with a non-negative
15642+
// constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15643+
// the non-constant operand and in \p LHS the constant operand.
15644+
auto IsMinMaxSCEVWithNonNegativeConstant =
15645+
[](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15646+
const SCEV *&RHS) {
15647+
if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15648+
if (MinMax->getNumOperands() != 2)
15649+
return false;
15650+
if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15651+
if (C->getAPInt().isNegative())
15652+
return false;
15653+
SCTy = MinMax->getSCEVType();
15654+
LHS = MinMax->getOperand(0);
15655+
RHS = MinMax->getOperand(1);
15656+
return true;
15657+
}
15658+
}
15659+
return false;
15660+
};
15661+
15662+
const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15663+
SCEVTypes SCTy;
15664+
if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15665+
MinMaxRHS))
15666+
return MinMaxExpr;
15667+
auto IsMin = isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15668+
assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!");
15669+
auto *DivisibleExpr =
15670+
IsMin ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE)
15671+
: getNextSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE);
15672+
SmallVector<const SCEV *> Ops = {
15673+
applyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor, SE), DivisibleExpr};
15674+
return SE.getMinMaxExpr(SCTy, Ops);
15675+
}
15676+
1556015677
void ScalarEvolution::LoopGuards::collectFromBlock(
1556115678
ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
1556215679
const BasicBlock *Block, const BasicBlock *Pred,
@@ -15567,19 +15684,14 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1556715684
SmallVector<const SCEV *> ExprsToRewrite;
1556815685
auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
1556915686
const SCEV *RHS,
15570-
DenseMap<const SCEV *, const SCEV *>
15571-
&RewriteMap) {
15687+
DenseMap<const SCEV *, const SCEV *> &RewriteMap,
15688+
const DenseMap<const SCEV *, const SCEV *>
15689+
&DivInfo) {
1557215690
// WARNING: It is generally unsound to apply any wrap flags to the proposed
1557315691
// replacement SCEV which isn't directly implied by the structure of that
1557415692
// SCEV. In particular, using contextual facts to imply flags is *NOT*
1557515693
// legal. See the scoping rules for flags in the header to understand why.
1557615694

15577-
// If LHS is a constant, apply information to the other expression.
15578-
if (isa<SCEVConstant>(LHS)) {
15579-
std::swap(LHS, RHS);
15580-
Predicate = CmpInst::getSwappedPredicate(Predicate);
15581-
}
15582-
1558315695
// Check for a condition of the form (-C1 + X < C2). InstCombine will
1558415696
// create this form when combining two checks of the form (X u< C2 + C1) and
1558515697
// (X >=u C1).
@@ -15612,115 +15724,6 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1561215724
if (MatchRangeCheckIdiom())
1561315725
return;
1561415726

15615-
// Return true if \p Expr is a MinMax SCEV expression with a non-negative
15616-
// constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15617-
// the non-constant operand and in \p LHS the constant operand.
15618-
auto IsMinMaxSCEVWithNonNegativeConstant =
15619-
[&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15620-
const SCEV *&RHS) {
15621-
if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15622-
if (MinMax->getNumOperands() != 2)
15623-
return false;
15624-
if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15625-
if (C->getAPInt().isNegative())
15626-
return false;
15627-
SCTy = MinMax->getSCEVType();
15628-
LHS = MinMax->getOperand(0);
15629-
RHS = MinMax->getOperand(1);
15630-
return true;
15631-
}
15632-
}
15633-
return false;
15634-
};
15635-
15636-
// Checks whether Expr is a non-negative constant, and Divisor is a positive
15637-
// constant, and returns their APInt in ExprVal and in DivisorVal.
15638-
auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor,
15639-
APInt &ExprVal, APInt &DivisorVal) {
15640-
auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
15641-
auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15642-
if (!ConstExpr || !ConstDivisor)
15643-
return false;
15644-
ExprVal = ConstExpr->getAPInt();
15645-
DivisorVal = ConstDivisor->getAPInt();
15646-
return ExprVal.isNonNegative() && !DivisorVal.isNonPositive();
15647-
};
15648-
15649-
// Return a new SCEV that modifies \p Expr to the closest number divides by
15650-
// \p Divisor and greater or equal than Expr.
15651-
// For now, only handle constant Expr and Divisor.
15652-
auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
15653-
const SCEV *Divisor) {
15654-
APInt ExprVal;
15655-
APInt DivisorVal;
15656-
if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15657-
return Expr;
15658-
APInt Rem = ExprVal.urem(DivisorVal);
15659-
if (!Rem.isZero())
15660-
// return the SCEV: Expr + Divisor - Expr % Divisor
15661-
return SE.getConstant(ExprVal + DivisorVal - Rem);
15662-
return Expr;
15663-
};
15664-
15665-
// Return a new SCEV that modifies \p Expr to the closest number divides by
15666-
// \p Divisor and less or equal than Expr.
15667-
// For now, only handle constant Expr and Divisor.
15668-
auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
15669-
const SCEV *Divisor) {
15670-
APInt ExprVal;
15671-
APInt DivisorVal;
15672-
if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15673-
return Expr;
15674-
APInt Rem = ExprVal.urem(DivisorVal);
15675-
// return the SCEV: Expr - Expr % Divisor
15676-
return SE.getConstant(ExprVal - Rem);
15677-
};
15678-
15679-
// Apply divisibilty by \p Divisor on MinMaxExpr with constant values,
15680-
// recursively. This is done by aligning up/down the constant value to the
15681-
// Divisor.
15682-
std::function<const SCEV *(const SCEV *, const SCEV *)>
15683-
ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
15684-
const SCEV *Divisor) {
15685-
const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15686-
SCEVTypes SCTy;
15687-
if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15688-
MinMaxRHS))
15689-
return MinMaxExpr;
15690-
auto IsMin =
15691-
isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15692-
assert(SE.isKnownNonNegative(MinMaxLHS) &&
15693-
"Expected non-negative operand!");
15694-
auto *DivisibleExpr =
15695-
IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
15696-
: GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
15697-
SmallVector<const SCEV *> Ops = {
15698-
ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
15699-
return SE.getMinMaxExpr(SCTy, Ops);
15700-
};
15701-
15702-
// If we have LHS == 0, check if LHS is computing a property of some unknown
15703-
// SCEV %v which we can rewrite %v to express explicitly.
15704-
if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
15705-
// If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15706-
// explicitly express that.
15707-
const SCEV *URemLHS = nullptr;
15708-
const SCEV *URemRHS = nullptr;
15709-
if (SE.matchURem(LHS, URemLHS, URemRHS)) {
15710-
if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
15711-
auto I = RewriteMap.find(LHSUnknown);
15712-
const SCEV *RewrittenLHS =
15713-
I != RewriteMap.end() ? I->second : LHSUnknown;
15714-
RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15715-
const auto *Multiple =
15716-
SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15717-
RewriteMap[LHSUnknown] = Multiple;
15718-
ExprsToRewrite.push_back(LHSUnknown);
15719-
return;
15720-
}
15721-
}
15722-
}
15723-
1572415727
// Do not apply information for constants or if RHS contains an AddRec.
1572515728
if (isa<SCEVConstant>(LHS) || SE.containsAddRecurrence(RHS))
1572615729
return;
@@ -15751,7 +15754,11 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1575115754

1575215755
const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
1575315756
const SCEV *DividesBy = nullptr;
15754-
const APInt &Multiple = SE.getConstantMultiple(RewrittenLHS);
15757+
// Apply divisibility information when computing the constant multiple.
15758+
LoopGuards DivGuards(SE);
15759+
DivGuards.RewriteMap = DivInfo;
15760+
const APInt &Multiple =
15761+
SE.getConstantMultiple(DivGuards.rewrite(RewrittenLHS));
1575515762
if (!Multiple.isOne())
1575615763
DividesBy = SE.getConstant(Multiple);
1575715764

@@ -15775,21 +15782,23 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1577515782
[[fallthrough]];
1577615783
case CmpInst::ICMP_SLT: {
1577715784
RHS = SE.getMinusSCEV(RHS, One);
15778-
RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15785+
RHS = DividesBy ? getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE)
15786+
: RHS;
1577915787
break;
1578015788
}
1578115789
case CmpInst::ICMP_UGT:
1578215790
case CmpInst::ICMP_SGT:
1578315791
RHS = SE.getAddExpr(RHS, One);
15784-
RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15792+
RHS = DividesBy ? getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE) : RHS;
1578515793
break;
1578615794
case CmpInst::ICMP_ULE:
1578715795
case CmpInst::ICMP_SLE:
15788-
RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15796+
RHS = DividesBy ? getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE)
15797+
: RHS;
1578915798
break;
1579015799
case CmpInst::ICMP_UGE:
1579115800
case CmpInst::ICMP_SGE:
15792-
RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15801+
RHS = DividesBy ? getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE) : RHS;
1579315802
break;
1579415803
default:
1579515804
break;
@@ -15843,7 +15852,8 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1584315852
case CmpInst::ICMP_NE:
1584415853
if (match(RHS, m_scev_Zero())) {
1584515854
const SCEV *OneAlignedUp =
15846-
DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
15855+
DividesBy ? getNextSCEVDivisibleByDivisor(One, DividesBy, SE)
15856+
: One;
1584715857
To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
1584815858
}
1584915859
break;
@@ -15916,8 +15926,11 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1591615926

1591715927
// Now apply the information from the collected conditions to
1591815928
// Guards.RewriteMap. Conditions are processed in reverse order, so the
15919-
// earliest conditions is processed first. This ensures the SCEVs with the
15929+
// earliest conditions is processed first, except guards with divisibility
15930+
// information, which are moved to the back. This ensures the SCEVs with the
1592015931
// shortest dependency chains are constructed first.
15932+
SmallVector<std::tuple<CmpInst::Predicate, const SCEV *, const SCEV *>>
15933+
GuardsToProcess;
1592115934
for (auto [Term, EnterIfTrue] : reverse(Terms)) {
1592215935
SmallVector<Value *, 8> Worklist;
1592315936
SmallPtrSet<Value *, 8> Visited;
@@ -15932,7 +15945,12 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1593215945
EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
1593315946
const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
1593415947
const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
15935-
CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap);
15948+
// If LHS is a constant, apply information to the other expression.
15949+
if (isa<SCEVConstant>(LHS)) {
15950+
std::swap(LHS, RHS);
15951+
Predicate = CmpInst::getSwappedPredicate(Predicate);
15952+
}
15953+
GuardsToProcess.emplace_back(Predicate, LHS, RHS);
1593615954
continue;
1593715955
}
1593815956

@@ -15945,6 +15963,28 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1594515963
}
1594615964
}
1594715965

15966+
// Process divisibility guards in reverse order to populate DivInfo early.
15967+
DenseMap<const SCEV *, const SCEV *> Multiples;
15968+
DenseMap<const SCEV *, const SCEV *> DivInfo;
15969+
for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) {
15970+
if (!isDivisibilityGuard(LHS, RHS, SE))
15971+
continue;
15972+
collectDivisibilityInformation(Predicate, LHS, RHS, DivInfo, Multiples, SE);
15973+
}
15974+
15975+
for (const auto &[Predicate, LHS, RHS] : GuardsToProcess)
15976+
CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap, DivInfo);
15977+
15978+
// Apply divisibility information last. This ensures it is applied to the
15979+
// outermost expression after other rewrites for the given value.
15980+
for (const auto &[K, V] : Multiples) {
15981+
Guards.RewriteMap[K] = SE.getMulExpr(
15982+
SE.getUDivExpr(applyDivisibilityOnMinMaxExpr(Guards.rewrite(K), V, SE),
15983+
V),
15984+
V);
15985+
ExprsToRewrite.push_back(K);
15986+
}
15987+
1594815988
// Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
1594915989
// the replacement expressions are contained in the ranges of the replaced
1595015990
// expressions.

0 commit comments

Comments
 (0)