-
Notifications
You must be signed in to change notification settings - Fork 12.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SCEV] Use Step and Start to check if SCEVWrapPredicate is implied. #118184
Conversation
@llvm/pr-subscribers-llvm-analysis Author: Florian Hahn (fhahn) ChangesA SCEVWrapPredicate A implies B, if
See https://alive2.llvm.org/ce/z/n2T4ss (first pair with known constants as strides, second pair with variable strides). Note that this is limited to steps of the same size, due to NSUW having slightly different semantics than regular NUW. We should be able to remove this restriction for NSSW (which matches NSW) in the future. Full diff: https://github.com/llvm/llvm-project/pull/118184.diff 4 Files Affected:
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 885c5985f9d23a..27df25cbf2b7d2 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -241,7 +241,7 @@ class SCEVPredicate : public FoldingSetNode {
virtual bool isAlwaysTrue() const = 0;
/// Returns true if this predicate implies \p N.
- virtual bool implies(const SCEVPredicate *N) const = 0;
+ virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const = 0;
/// Prints a textual representation of this predicate with an indentation of
/// \p Depth.
@@ -286,7 +286,7 @@ class SCEVComparePredicate final : public SCEVPredicate {
const SCEV *LHS, const SCEV *RHS);
/// Implementation of the SCEVPredicate interface
- bool implies(const SCEVPredicate *N) const override;
+ bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override;
void print(raw_ostream &OS, unsigned Depth = 0) const override;
bool isAlwaysTrue() const override;
@@ -393,7 +393,7 @@ class SCEVWrapPredicate final : public SCEVPredicate {
/// Implementation of the SCEVPredicate interface
const SCEVAddRecExpr *getExpr() const;
- bool implies(const SCEVPredicate *N) const override;
+ bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override;
void print(raw_ostream &OS, unsigned Depth = 0) const override;
bool isAlwaysTrue() const override;
@@ -418,16 +418,17 @@ class SCEVUnionPredicate final : public SCEVPredicate {
SmallVector<const SCEVPredicate *, 16> Preds;
/// Adds a predicate to this union.
- void add(const SCEVPredicate *N);
+ void add(const SCEVPredicate *N, ScalarEvolution &SE);
public:
- SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds);
+ SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
+ ScalarEvolution &SE);
ArrayRef<const SCEVPredicate *> getPredicates() const { return Preds; }
/// Implementation of the SCEVPredicate interface
bool isAlwaysTrue() const override;
- bool implies(const SCEVPredicate *N) const override;
+ bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override;
void print(raw_ostream &OS, unsigned Depth) const override;
/// We estimate the complexity of a union predicate as the size number of
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index c3f296b9ff3347..a98e53450b9520 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -5725,8 +5725,9 @@ bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
return true;
auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
- if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) &&
- !Preds->implies(SE.getEqualPredicate(Expr2, Expr1)))
+ if (Expr1 != Expr2 &&
+ !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
+ !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
return false;
return true;
};
@@ -14823,7 +14824,7 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
bool addOverflowAssumption(const SCEVPredicate *P) {
if (!NewPreds) {
// Check if we've already made this assumption.
- return Pred && Pred->implies(P);
+ return Pred && Pred->implies(P, SE);
}
NewPreds->push_back(P);
return true;
@@ -14904,7 +14905,8 @@ SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
assert(LHS != RHS && "LHS and RHS are the same SCEV");
}
-bool SCEVComparePredicate::implies(const SCEVPredicate *N) const {
+bool SCEVComparePredicate::implies(const SCEVPredicate *N,
+ ScalarEvolution &SE) const {
const auto *Op = dyn_cast<SCEVComparePredicate>(N);
if (!Op)
@@ -14934,10 +14936,52 @@ SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
-bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
+bool SCEVWrapPredicate::implies(const SCEVPredicate *N,
+ ScalarEvolution &SE) const {
const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
+ if (!Op)
+ return false;
+
+ if (setFlags(Flags, Op->Flags) != Flags)
+ return false;
+
+ if (Op->AR == AR)
+ return true;
+
+ if (Flags != SCEVWrapPredicate::IncrementNSSW &&
+ Flags != SCEVWrapPredicate::IncrementNUSW)
+ return false;
- return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
+ bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
+ const SCEV *Step = AR->getStepRecurrence(SE);
+ const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
+
+ // If both steps are positive, this implies N, if N's start and step are
+ // ULE/SLE (for NSUW/NSSW) than this'.
+ if (SE.isKnownPositive(Step) && SE.isKnownPositive(OpStep)) {
+ const SCEV *OpStart = Op->AR->getStart();
+ const SCEV *Start = AR->getStart();
+ if (SE.getTypeSizeInBits(Step->getType()) >
+ SE.getTypeSizeInBits(OpStep->getType())) {
+ OpStep = SE.getZeroExtendExpr(OpStep, Step->getType());
+ } else {
+ Step = IsNUW ? SE.getNoopOrZeroExtend(Step, OpStep->getType())
+ : SE.getNoopOrSignExtend(Step, OpStep->getType());
+ }
+ if (SE.getTypeSizeInBits(Start->getType()) >
+ SE.getTypeSizeInBits(OpStart->getType())) {
+ OpStart = IsNUW ? SE.getZeroExtendExpr(OpStart, Start->getType())
+ : SE.getSignExtendExpr(OpStart, Start->getType());
+ } else {
+ Start = IsNUW ? SE.getNoopOrZeroExtend(Start, OpStart->getType())
+ : SE.getNoopOrSignExtend(Start, OpStart->getType());
+ }
+
+ CmpInst::Predicate Pred = IsNUW ? CmpInst::ICMP_ULE : CmpInst::ICMP_SLE;
+ return SE.isKnownPredicate(Pred, OpStep, Step) &&
+ SE.isKnownPredicate(Pred, OpStart, Start);
+ }
+ return false;
}
bool SCEVWrapPredicate::isAlwaysTrue() const {
@@ -14981,10 +15025,11 @@ SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
}
/// Union predicates don't get cached so create a dummy set ID for it.
-SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds)
- : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
+SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
+ ScalarEvolution &SE)
+ : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
for (const auto *P : Preds)
- add(P);
+ add(P, SE);
}
bool SCEVUnionPredicate::isAlwaysTrue() const {
@@ -14992,13 +15037,15 @@ bool SCEVUnionPredicate::isAlwaysTrue() const {
[](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
}
-bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
+bool SCEVUnionPredicate::implies(const SCEVPredicate *N,
+ ScalarEvolution &SE) const {
if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
- return all_of(Set->Preds,
- [this](const SCEVPredicate *I) { return this->implies(I); });
+ return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
+ return this->implies(I, SE);
+ });
return any_of(Preds,
- [N](const SCEVPredicate *I) { return I->implies(N); });
+ [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
}
void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
@@ -15006,15 +15053,15 @@ void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
Pred->print(OS, Depth);
}
-void SCEVUnionPredicate::add(const SCEVPredicate *N) {
+void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
for (const auto *Pred : Set->Preds)
- add(Pred);
+ add(Pred, SE);
return;
}
// Only add predicate if it is not already implied by this union predicate.
- if (!implies(N))
+ if (!implies(N, SE))
Preds.push_back(N);
}
@@ -15022,7 +15069,7 @@ PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
Loop &L)
: SE(SE), L(L) {
SmallVector<const SCEVPredicate*, 4> Empty;
- Preds = std::make_unique<SCEVUnionPredicate>(Empty);
+ Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
}
void ScalarEvolution::registerUser(const SCEV *User,
@@ -15086,12 +15133,12 @@ unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() {
}
void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
- if (Preds->implies(&Pred))
+ if (Preds->implies(&Pred, SE))
return;
SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
NewPreds.push_back(&Pred);
- Preds = std::make_unique<SCEVUnionPredicate>(NewPreds);
+ Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
updateGeneration();
}
@@ -15158,9 +15205,10 @@ const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
PredicatedScalarEvolution::PredicatedScalarEvolution(
const PredicatedScalarEvolution &Init)
- : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
- Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())),
- Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
+ : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
+ Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
+ SE)),
+ Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
for (auto I : Init.FlagsMap)
FlagsMap.insert(I);
}
diff --git a/llvm/test/Analysis/LoopAccessAnalysis/memcheck-wrapping-pointers.ll b/llvm/test/Analysis/LoopAccessAnalysis/memcheck-wrapping-pointers.ll
index 6dbb4a0c0129a6..ae10ab841420fd 100644
--- a/llvm/test/Analysis/LoopAccessAnalysis/memcheck-wrapping-pointers.ll
+++ b/llvm/test/Analysis/LoopAccessAnalysis/memcheck-wrapping-pointers.ll
@@ -29,20 +29,19 @@ target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
; CHECK-NEXT: Run-time memory checks:
; CHECK-NEXT: Check 0:
; CHECK-NEXT: Comparing group
-; CHECK-NEXT: %arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom
-; CHECK-NEXT: Against group
; CHECK-NEXT: %arrayidx4 = getelementptr inbounds i32, ptr %b, i64 %conv11
+; CHECK-NEXT: Against group
+; CHECK-NEXT: %arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom
; CHECK-NEXT: Grouped accesses:
; CHECK-NEXT: Group
-; CHECK-NEXT: (Low: (4 + %a) High: (4 + (4 * (1 umax %x)) + %a))
-; CHECK-NEXT: Member: {(4 + %a),+,4}<%for.body>
-; CHECK-NEXT: Group
; CHECK-NEXT: (Low: %b High: ((4 * (1 umax %x)) + %b))
; CHECK-NEXT: Member: {%b,+,4}<%for.body>
+; CHECK-NEXT: Group
+; CHECK-NEXT: (Low: (4 + %a) High: (4 + (4 * (1 umax %x)) + %a))
+; CHECK-NEXT: Member: {(4 + %a),+,4}<%for.body>
; CHECK: Non vectorizable stores to invariant address were not found in loop.
; CHECK-NEXT: SCEV assumptions:
; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: <nusw>
-; CHECK-NEXT: {0,+,1}<%for.body> Added Flags: <nusw>
; CHECK: Expressions re-written:
; CHECK-NEXT: [PSE] %arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom:
; CHECK-NEXT: ((4 * (zext i32 {1,+,1}<%for.body> to i64))<nuw><nsw> + %a)<nuw>
@@ -85,7 +84,6 @@ exit:
; CHECK: Memory dependences are safe
; CHECK: SCEV assumptions:
; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: <nusw>
-; CHECK-NEXT: {0,+,1}<%for.body> Added Flags: <nusw>
define void @test2(i64 %x, ptr %a) {
entry:
br label %for.body
diff --git a/llvm/test/Analysis/LoopAccessAnalysis/nssw-predicate-implied.ll b/llvm/test/Analysis/LoopAccessAnalysis/nssw-predicate-implied.ll
index 1a07805c2614f8..4f595b44ae5fde 100644
--- a/llvm/test/Analysis/LoopAccessAnalysis/nssw-predicate-implied.ll
+++ b/llvm/test/Analysis/LoopAccessAnalysis/nssw-predicate-implied.ll
@@ -3,7 +3,7 @@
target datalayout = "e-m:o-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-n32:64-S128-Fn32"
-; FIXME: {0,+,3} implies {0,+,2}.
+; {0,+,3} [nssw] implies {0,+,2} [nssw]
define void @wrap_check_iv.3_implies_iv.2(i32 noundef %N, ptr %dst, ptr %src) {
; CHECK-LABEL: 'wrap_check_iv.3_implies_iv.2'
; CHECK-NEXT: loop:
@@ -26,7 +26,6 @@ define void @wrap_check_iv.3_implies_iv.2(i32 noundef %N, ptr %dst, ptr %src) {
; CHECK-NEXT: Non vectorizable stores to invariant address were not found in loop.
; CHECK-NEXT: SCEV assumptions:
; CHECK-NEXT: {0,+,3}<%loop> Added Flags: <nssw>
-; CHECK-NEXT: {0,+,2}<%loop> Added Flags: <nssw>
; CHECK-EMPTY:
; CHECK-NEXT: Expressions re-written:
; CHECK-NEXT: [PSE] %gep.iv.2 = getelementptr inbounds i32, ptr %src, i64 %ext.iv.2:
@@ -59,7 +58,7 @@ exit:
ret void
}
-; FIXME: {2,+,2} implies {0,+,2}.
+; {2,+,2} [nssw] implies {0,+,2} [nssw].
define void @wrap_check_iv.3_implies_iv.2_different_start(i32 noundef %N, ptr %dst, ptr %src) {
; CHECK-LABEL: 'wrap_check_iv.3_implies_iv.2_different_start'
; CHECK-NEXT: loop:
@@ -82,7 +81,6 @@ define void @wrap_check_iv.3_implies_iv.2_different_start(i32 noundef %N, ptr %d
; CHECK-NEXT: Non vectorizable stores to invariant address were not found in loop.
; CHECK-NEXT: SCEV assumptions:
; CHECK-NEXT: {2,+,2}<%loop> Added Flags: <nssw>
-; CHECK-NEXT: {0,+,2}<%loop> Added Flags: <nssw>
; CHECK-EMPTY:
; CHECK-NEXT: Expressions re-written:
; CHECK-NEXT: [PSE] %gep.iv.2 = getelementptr inbounds i32, ptr %src, i64 %ext.iv.2:
|
When adding a new predicate to a union predicate, some of the existing predicates may be implied by the new predicate. Remove any existing predicates that are already implied by the new predicate. Depends on llvm#118184.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM w/minor comments.
if (!Op) | ||
return false; | ||
|
||
if (setFlags(Flags, Op->Flags) != Flags) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You could use intersection here, and prove that flags are weaker than implied.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like it the moment flags are either SCEVWrapPredicate::IncrementNSSW
or SCEVWrapPredicate::IncrementNUSW
and never both. Left as is for now.
SE.getTypeSizeInBits(OpStep->getType())) { | ||
OpStep = SE.getZeroExtendExpr(OpStep, Step->getType()); | ||
} else { | ||
Step = IsNUW ? SE.getNoopOrZeroExtend(Step, OpStep->getType()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If both step and opstep are positive, then aren't sext and zext the same thing?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adjusted, thanks!
A SCEVWrapPredicate A implies B, if * they have the same flag, * both steps are positive and * B's start and step are ULE/SLE (for NSUW/NSSW) than A's. See https://alive2.llvm.org/ce/z/n2T4ss (first pair with known constants as strides, second pair with variable strides). Note that this is limited to steps of the same size, due to NSUW having slightly different semantics than regular NUW. We should be able to remove this restriction for NSSW (which matches NSW) in the future.
81bcaa7
to
4a10b2b
Compare
When adding a new predicate to a union predicate, some of the existing predicates may be implied by the new predicate. Remove any existing predicates that are already implied by the new predicate. Depends on llvm#118184.
Hi @fhahn The following starts failing with this patch:
|
When adding a new predicate to a union predicate, some of the existing predicates may be implied by the new predicate. Remove any existing predicates that are already implied by the new predicate. Depends on llvm#118184.
@mikaelholmen should be fixed by bailing out on mixed int/pointer AddRecs: 95eb49a |
Yep, thanks! |
…nes. (#118185) When adding a new predicate to a union predicate, some of the existing predicates may be implied by the new predicate. Remove any existing predicates that are already implied by the new predicate. Depends on llvm/llvm-project#118184 to show the main benefit. PR: llvm/llvm-project#118185
A SCEVWrapPredicate A implies B, if
See https://alive2.llvm.org/ce/z/n2T4ss (first pair with known constants as strides, second pair with variable strides).
Note that this is limited to steps of the same size, due to NSUW having slightly different semantics than regular NUW. We should be able to remove this restriction for NSSW (which matches NSW) in the future.