-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[InstCombine] Fold max/min when incrementing/decrementing by 1 #142466
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[InstCombine] Fold max/min when incrementing/decrementing by 1 #142466
Conversation
@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-llvm-analysis Author: Alex MacLean (AlexMaclean) ChangesAdd the following folds for integer min max folding in ValueTracking:
These are safe when overflow corresponding to the sign of the comparison is poison. (proof https://alive2.llvm.org/ce/z/oj5iiI). The most common of these patterns is likely the minimum case which occurs in some internal library code when clamping an integer index to a range (The maximum cases are included for completeness). Here is a simplified example: int clampToWidth(int idx, int width) {
if (idx >= width)
return width - 1;
return idx;
} https://cuda.godbolt.org/z/nhPzWrc3W Full diff: https://github.com/llvm/llvm-project/pull/142466.diff 2 Files Affected:
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index fc19b2ccf7964..416d586d52963 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -8388,6 +8388,24 @@ static SelectPatternResult matchMinMax(CmpInst::Predicate Pred,
}
}
+ // (X > Y) ? X : (Y - 1) ==> MIN(X, Y - 1)
+ // (X < Y) ? X : (Y + 1) ==> MAX(X, Y + 1)
+ // When overflow corresponding to the sign of the comparison is poison.
+ // Note that the UMIN case is not possible as we canonicalize to addition.
+ if (CmpLHS == TrueVal) {
+ if (Pred == CmpInst::ICMP_SGT &&
+ match(FalseVal, m_NSWAddLike(m_Specific(CmpRHS), m_ConstantInt<1>())))
+ return {SPF_SMAX, SPNB_NA, false};
+
+ if (Pred == CmpInst::ICMP_SLT &&
+ match(FalseVal, m_NSWAddLike(m_Specific(CmpRHS), m_ConstantInt<-1>())))
+ return {SPF_SMIN, SPNB_NA, false};
+
+ if (Pred == CmpInst::ICMP_UGT &&
+ match(FalseVal, m_NUWAddLike(m_Specific(CmpRHS), m_ConstantInt<1>())))
+ return {SPF_UMAX, SPNB_NA, false};
+ }
+
if (Pred != CmpInst::ICMP_SGT && Pred != CmpInst::ICMP_SLT)
return {SPF_UNKNOWN, SPNB_NA, false};
diff --git a/llvm/test/Transforms/InstCombine/minmax-fold.ll b/llvm/test/Transforms/InstCombine/minmax-fold.ll
index cd376b74fb36c..cf3515614321d 100644
--- a/llvm/test/Transforms/InstCombine/minmax-fold.ll
+++ b/llvm/test/Transforms/InstCombine/minmax-fold.ll
@@ -1596,3 +1596,40 @@ define <2 x i32> @test_umax_smax_vec_neg(<2 x i32> %x) {
%umax = call <2 x i32> @llvm.umax.v2i32(<2 x i32> %smax, <2 x i32> <i32 1, i32 10>)
ret <2 x i32> %umax
}
+
+define i32 @test_smin_sub1_nsw(i32 %x, i32 %w) {
+; CHECK-LABEL: @test_smin_sub1_nsw(
+; CHECK-NEXT: [[SUB:%.*]] = add nsw i32 [[W:%.*]], -1
+; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.smin.i32(i32 [[X:%.*]], i32 [[SUB]])
+; CHECK-NEXT: ret i32 [[R]]
+;
+ %cmp = icmp slt i32 %x, %w
+ %sub = add nsw i32 %w, -1
+ %r = select i1 %cmp, i32 %x, i32 %sub
+ ret i32 %r
+}
+
+define i32 @test_smax_add1_nsw(i32 %x, i32 %w) {
+; CHECK-LABEL: @test_smax_add1_nsw(
+; CHECK-NEXT: [[X2:%.*]] = add nsw i32 [[W:%.*]], 1
+; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.smax.i32(i32 [[X:%.*]], i32 [[X2]])
+; CHECK-NEXT: ret i32 [[R]]
+;
+ %cmp = icmp sgt i32 %x, %w
+ %add = add nsw i32 %w, 1
+ %r = select i1 %cmp, i32 %x, i32 %add
+ ret i32 %r
+}
+
+define i32 @test_umax_add1_nsw(i32 %x, i32 %w) {
+; CHECK-LABEL: @test_umax_add1_nsw(
+; CHECK-NEXT: [[X2:%.*]] = add nuw i32 [[W:%.*]], 1
+; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.umax.i32(i32 [[X:%.*]], i32 [[X2]])
+; CHECK-NEXT: ret i32 [[R]]
+;
+ %cmp = icmp ugt i32 %x, %w
+ %add = add nuw i32 %w, 1
+ %r = select i1 %cmp, i32 %x, i32 %add
+ ret i32 %r
+}
+
|
llvm/lib/Analysis/ValueTracking.cpp
Outdated
if (Pred == CmpInst::ICMP_UGT && | ||
match(FalseVal, m_NUWAddLike(m_Specific(CmpRHS), m_ConstantInt<1>()))) | ||
return {SPF_UMAX, SPNB_NA, false}; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing fold for (X > Y) ? X : (Y -nuw 1) ==> umin(X, Y - 1)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately, it seems this case is not possible. We canonicalize sub nuw %x, 1
to add %x, -1
, discarding the nuw
flag as unsigned wrap is now guaranteed unless %x is 0. I've noted that this case isn't possible in a comment. Is there a work around to this problem you can think of?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we check isKnownNonZero(%x)
instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I could use Sounds good!isKnownNonZero
for this case. Would that make sense?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, I've added this case in
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Miscompilation reproducer: https://alive2.llvm.org/ce/z/nxGw_V
define i8 @src(i8 %x, i8 %w) {
%cmp = icmp ugt i8 %x, %w
%add = add nsw nuw i8 %w, 1
%r = select i1 %cmp, i8 %x, i8 %add
ret i8 %r
}
define i8 @tgt(i8 %x, i8 %w) {
%add = add nsw nuw i8 %w, 1
%r = call i8 @llvm.umax(i8 %x, i8 %add)
ret i8 %r
}
We need to drop nsw/nuw flags in this case.
Thanks! I've fixed this case by ensuring that we now drop the no-wrap flag which doesn't correspond to the sign of the min/max. I've added tests to confirm as well. |
abb02ec
to
5d6aba5
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thank you!
@@ -565,6 +565,59 @@ Instruction *InstCombinerImpl::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, | |||
return nullptr; | |||
} | |||
|
|||
/// Try to fold a select to a min/max intrinsic. Many cases are already handled | |||
/// by matchDecomposedSelectPattern but here we handle the cases where more | |||
/// exensive modification of the IR is required. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
extensive?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
} | ||
|
||
// Note: We must use isKnownNonZero here because "sub nuw %x, 1" will be | ||
// canonicalize to "add %x, -1" discarding the nuw flag. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// canonicalize to "add %x, -1" discarding the nuw flag. | |
// canonicalized to "add %x, -1" discarding the nuw flag. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
Pred = CmpInst::getSwappedPredicate(Pred); | ||
} | ||
|
||
// TODO: consider handeling 'or disjoint' as well, though these would need to |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// TODO: consider handeling 'or disjoint' as well, though these would need to | |
// TODO: consider handling 'or disjoint' as well, though these would need to |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
|
||
// TODO: consider handeling 'or disjoint' as well, though these would need to | ||
// be converted to 'add' instructions. | ||
if (CmpLHS == TVal && isa<Instruction>(FVal)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Early exit?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good. Updated.
if (Pred == CmpInst::ICMP_ULT && | ||
match(FVal, m_Add(m_Specific(CmpRHS), m_AllOnes())) && | ||
isKnownNonZero(CmpRHS, SQ)) { | ||
cast<Instruction>(FVal)->setHasNoSignedWrap(false); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this case the nuw flag is also invalid: https://alive2.llvm.org/ce/z/ZpSaKv
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated. In practice, I'm not sure we could ever reach this point in that case, at least I was not able to construct a test. But I agree it is good to stay on the safe side and not rely on unrelated optimizations to prevent a mis-compilation here.
Add the following folds for integer min max folding in InstCombine:
These are safe when overflow corresponding to the sign of the comparison is poison. (proof https://alive2.llvm.org/ce/z/oj5iiI).
The most common of these patterns is likely the minimum case which occurs in some internal library code when clamping an integer index to a range (The maximum cases are included for completeness). Here is a simplified example:
https://cuda.godbolt.org/z/nhPzWrc3W