Skip to content

Commit 23d922e

Browse files
committed
[SCEV] Preserve divisibility info when creating UMax/SMax expressions.
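Currently we generate (S|U)Max(1, Op) for conditions of the form Op >= 1, which may discard divisibility information about Op. This patch rewrites such SMax/UMax expressions to use, in place of the constant 1, a common divisor of all non-constant operands (the largest power of two known to divide each of them, derived from their minimum number of trailing zero bits).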
1 parent c6a4e84 commit 23d922e

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15857,12 +15857,17 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1585715857
To = SE.getUMaxExpr(FromRewritten, RHS);
1585815858
if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
1585915859
EnqueueOperands(UMin);
15860+
if (RHS->isOne())
15861+
ExprsToRewrite.push_back(From);
1586015862
break;
1586115863
case CmpInst::ICMP_SGT:
1586215864
case CmpInst::ICMP_SGE:
1586315865
To = SE.getSMaxExpr(FromRewritten, RHS);
15864-
if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15866+
if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten)) {
1586515867
EnqueueOperands(SMin);
15868+
}
15869+
if (RHS->isOne())
15870+
ExprsToRewrite.push_back(From);
1586615871
break;
1586715872
case CmpInst::ICMP_EQ:
1586815873
if (isa<SCEVConstant>(RHS))
@@ -15993,7 +15998,22 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1599315998
for (const SCEV *Expr : ExprsToRewrite) {
1599415999
const SCEV *RewriteTo = Guards.RewriteMap[Expr];
1599516000
Guards.RewriteMap.erase(Expr);
15996-
Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
16001+
const SCEV *Rewritten = Guards.rewrite(RewriteTo);
16002+
16003+
// Try to strengthen divisibility of SMax/UMax expressions coming from >=
16004+
// 1 conditions.
16005+
if (auto *SMax = dyn_cast<SCEVSMaxExpr>(Rewritten)) {
16006+
unsigned MinTrailingZeros = SE.getMinTrailingZeros(SMax->getOperand(1));
16007+
for (const SCEV *Op : drop_begin(SMax->operands(), 2))
16008+
MinTrailingZeros =
16009+
std::min(MinTrailingZeros, SE.getMinTrailingZeros(Op));
16010+
if (MinTrailingZeros != 0)
16011+
Rewritten = SE.getSMaxExpr(
16012+
SE.getConstant(APInt(SMax->getType()->getScalarSizeInBits(), 1)
16013+
.shl(MinTrailingZeros)),
16014+
SMax);
16015+
}
16016+
Guards.RewriteMap.insert({Expr, Rewritten});
1599716017
}
1599816018
}
1599916019
}

llvm/test/Analysis/ScalarEvolution/trip-count-minmax.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ define void @umin(i32 noundef %a, i32 noundef %b) {
6161
; CHECK-NEXT: Loop %for.body: backedge-taken count is (-1 + ((2 * %a) umin (4 * %b)))
6262
; CHECK-NEXT: Loop %for.body: constant max backedge-taken count is i32 2147483646
6363
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + ((2 * %a) umin (4 * %b)))
64-
; CHECK-NEXT: Loop %for.body: Trip multiple is 1
64+
; CHECK-NEXT: Loop %for.body: Trip multiple is 2
6565
;
6666
; void umin(unsigned a, unsigned b) {
6767
; a *= 2;
@@ -157,7 +157,7 @@ define void @smin(i32 noundef %a, i32 noundef %b) {
157157
; CHECK-NEXT: Loop %for.body: backedge-taken count is (-1 + ((2 * %a)<nsw> smin (4 * %b)<nsw>))
158158
; CHECK-NEXT: Loop %for.body: constant max backedge-taken count is i32 2147483646
159159
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + ((2 * %a)<nsw> smin (4 * %b)<nsw>))
160-
; CHECK-NEXT: Loop %for.body: Trip multiple is 1
160+
; CHECK-NEXT: Loop %for.body: Trip multiple is 2
161161
;
162162
; void smin(signed a, signed b) {
163163
; a *= 2;

0 commit comments

Comments
 (0)