Skip to content

Commit 6674962

Browse files
committed
[SCEV] Preserve divisibility info when creating UMax/SMax expressions.
Currently we generate (S|U)Max(1, Op) for Op >= 1. This may discard divisibility info of Op. This patch rewrites such SMax/UMax expressions to use the lowest common multiplier for all non-constant operands.
1 parent e8e678c commit 6674962

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15850,12 +15850,17 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1585015850
To = SE.getUMaxExpr(FromRewritten, RHS);
1585115851
if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
1585215852
EnqueueOperands(UMin);
15853+
if (RHS->isOne())
15854+
ExprsToRewrite.push_back(From);
1585315855
break;
1585415856
case CmpInst::ICMP_SGT:
1585515857
case CmpInst::ICMP_SGE:
1585615858
To = SE.getSMaxExpr(FromRewritten, RHS);
15857-
if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15859+
if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten)) {
1585815860
EnqueueOperands(SMin);
15861+
}
15862+
if (RHS->isOne())
15863+
ExprsToRewrite.push_back(From);
1585915864
break;
1586015865
case CmpInst::ICMP_EQ:
1586115866
if (isa<SCEVConstant>(RHS))
@@ -15986,7 +15991,22 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1598615991
for (const SCEV *Expr : ExprsToRewrite) {
1598715992
const SCEV *RewriteTo = Guards.RewriteMap[Expr];
1598815993
Guards.RewriteMap.erase(Expr);
15989-
Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
15994+
const SCEV *Rewritten = Guards.rewrite(RewriteTo);
15995+
15996+
// Try to strengthen divisibility of SMax/UMax expressions coming from >=
15997+
// 1 conditions.
15998+
if (auto *SMax = dyn_cast<SCEVSMaxExpr>(Rewritten)) {
15999+
unsigned MinTrailingZeros = SE.getMinTrailingZeros(SMax->getOperand(1));
16000+
for (const SCEV *Op : drop_begin(SMax->operands(), 2))
16001+
MinTrailingZeros =
16002+
std::min(MinTrailingZeros, SE.getMinTrailingZeros(Op));
16003+
if (MinTrailingZeros != 0)
16004+
Rewritten = SE.getSMaxExpr(
16005+
SE.getConstant(APInt(SMax->getType()->getScalarSizeInBits(), 1)
16006+
.shl(MinTrailingZeros)),
16007+
SMax);
16008+
}
16009+
Guards.RewriteMap.insert({Expr, Rewritten});
1599016010
}
1599116011
}
1599216012
}

llvm/test/Analysis/ScalarEvolution/trip-count-minmax.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ define void @umin(i32 noundef %a, i32 noundef %b) {
6161
; CHECK-NEXT: Loop %for.body: backedge-taken count is (-1 + ((2 * %a) umin (4 * %b)))
6262
; CHECK-NEXT: Loop %for.body: constant max backedge-taken count is i32 2147483646
6363
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + ((2 * %a) umin (4 * %b)))
64-
; CHECK-NEXT: Loop %for.body: Trip multiple is 1
64+
; CHECK-NEXT: Loop %for.body: Trip multiple is 2
6565
;
6666
; void umin(unsigned a, unsigned b) {
6767
; a *= 2;
@@ -157,7 +157,7 @@ define void @smin(i32 noundef %a, i32 noundef %b) {
157157
; CHECK-NEXT: Loop %for.body: backedge-taken count is (-1 + ((2 * %a)<nsw> smin (4 * %b)<nsw>))
158158
; CHECK-NEXT: Loop %for.body: constant max backedge-taken count is i32 2147483646
159159
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + ((2 * %a)<nsw> smin (4 * %b)<nsw>))
160-
; CHECK-NEXT: Loop %for.body: Trip multiple is 1
160+
; CHECK-NEXT: Loop %for.body: Trip multiple is 2
161161
;
162162
; void smin(signed a, signed b) {
163163
; a *= 2;

0 commit comments

Comments
 (0)