Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[InstCombine] Generalize icmp (shl nuw C2, Y), C -> icmp Y, C3 #104696

Merged
merged 6 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2227,18 +2227,24 @@ Instruction *InstCombinerImpl::foldICmpMulConstant(ICmpInst &Cmp,
return NewC ? new ICmpInst(Pred, X, NewC) : nullptr;
}

/// Fold icmp (shl 1, Y), C.
static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl,
const APInt &C) {
/// Fold icmp (shl nuw C2, Y), C.
static Instruction *foldICmpShlLHSC(ICmpInst &Cmp, Instruction *Shl,
const APInt &C) {
Value *Y;
if (!match(Shl, m_Shl(m_One(), m_Value(Y))))
const APInt *C2;
if (!match(Shl, m_NUWShl(m_APInt(C2), m_Value(Y))))
return nullptr;

Type *ShiftType = Shl->getType();
unsigned TypeBits = C.getBitWidth();
bool CIsPowerOf2 = C.isPowerOf2();
ICmpInst::Predicate Pred = Cmp.getPredicate();
if (Cmp.isUnsigned()) {
if (C2->isZero() || C2->ugt(C))
return nullptr;
APInt Div, Rem;
APInt::udivrem(C, *C2, Div, Rem);
bool CIsPowerOf2 = Rem.isZero() && Div.isPowerOf2();

// (1 << Y) pred C -> Y pred Log2(C)
if (!CIsPowerOf2) {
// (1 << Y) < 30 -> Y <= 4
Expand All @@ -2251,9 +2257,9 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl,
Pred = ICmpInst::ICMP_UGT;
}

unsigned CLog2 = C.logBase2();
unsigned CLog2 = Div.logBase2();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test where C and C2 are not powers of two and we need to round up the log? For example,
(48 << X) u>= 144?

I don't see the ceil logic for that case.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, CLog2));
} else if (Cmp.isSigned()) {
} else if (Cmp.isSigned() && C2->isOne()) {
Constant *BitWidthMinusOne = ConstantInt::get(ShiftType, TypeBits - 1);
// (1 << Y) > 0 -> Y != 31
// (1 << Y) > C -> Y != 31 if C is negative.
Expand Down Expand Up @@ -2307,7 +2313,7 @@ Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp,

const APInt *ShiftAmt;
if (!match(Shl->getOperand(1), m_APInt(ShiftAmt)))
return foldICmpShlOne(Cmp, Shl, C);
return foldICmpShlLHSC(Cmp, Shl, C);

// Check that the shift amount is in range. If not, don't perform undefined
// shifts. When the shift is visited, it will be simplified.
Expand Down
106 changes: 106 additions & 0 deletions llvm/test/Transforms/InstCombine/icmp-shl-nuw.ll
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,109 @@ define <2 x i1> @icmp_ugt_16x2(<2 x i32>) {
%d = icmp ugt <2 x i32> %c, <i32 1048575, i32 1048575>
ret <2 x i1> %d
}

; Shift amount is a masked bit-field: %and = (%x >> 12) & 15, so it is in 0..15.
; `2 << %and u< 4` holds only when %and == 0, so the expected output tests
; that bits 12-15 of %x are clear (mask 61440 = 15 << 12).
define i1 @fold_icmp_shl_nuw_c1(i32 %x) {
; CHECK-LABEL: @fold_icmp_shl_nuw_c1(
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[X:%.*]], 61440
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[TMP1]], 0
; CHECK-NEXT: ret i1 [[CMP]]
;
%lshr = lshr i32 %x, 12
%and = and i32 %lshr, 15
%shl = shl nuw i32 2, %and
%cmp = icmp ult i32 %shl, 4
ret i1 %cmp
}

; Power-of-two shifted constant and power-of-two compare constant:
; `16 << %x u< 64` is expected to become `%x u< 2` (64 / 16 = 4 = 2^2).
define i1 @fold_icmp_shl_nuw_c2(i32 %x) {
; CHECK-LABEL: @fold_icmp_shl_nuw_c2(
; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[X:%.*]], 2
; CHECK-NEXT: ret i1 [[CMP]]
;
%shl = shl nuw i32 16, %x
%cmp = icmp ult i32 %shl, 64
ret i1 %cmp
}

; Neither constant is a power of two, but their quotient is:
; `48 << %x u< 192` is expected to become `%x u< 2` (192 / 48 = 4 = 2^2).
define i1 @fold_icmp_shl_nuw_c2_non_pow2(i32 %x) {
; CHECK-LABEL: @fold_icmp_shl_nuw_c2_non_pow2(
; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[X:%.*]], 2
; CHECK-NEXT: ret i1 [[CMP]]
;
%shl = shl nuw i32 48, %x
%cmp = icmp ult i32 %shl, 64
ret i1 %cmp
}

; The compare constant divides evenly, but the quotient is not a power of two
; (60 / 2 = 30). `2 << %x u< 60` is expected to become `%x u< 5`:
; 2 << 4 = 32 is below 60, while 2 << 5 = 64 is not.
define i1 @fold_icmp_shl_nuw_c2_div_non_pow2(i32 %x) {
; CHECK-LABEL: @fold_icmp_shl_nuw_c2_div_non_pow2(
; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[X:%.*]], 5
; CHECK-NEXT: ret i1 [[CMP]]
;
%shl = shl nuw i32 2, %x
%cmp = icmp ult i32 %shl, 60
ret i1 %cmp
}

; uge predicate with a non-power-of-two quotient (144 / 48 = 3), so the
; threshold on the shift amount has to be rounded up.
; `48 << %x u>= 144` is expected to become `%x u> 1`:
; 48 << 1 = 96 is below 144, while 48 << 2 = 192 is not.
define i1 @fold_icmp_shl_nuw_c3(i32 %x) {
; CHECK-LABEL: @fold_icmp_shl_nuw_c3(
; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i32 [[X:%.*]], 1
; CHECK-NEXT: ret i1 [[CMP]]
;
%shl = shl nuw i32 48, %x
%cmp = icmp uge i32 %shl, 144
ret i1 %cmp
}

; The compare constant is not a multiple of the shifted constant
; (63 = 3 * 16 + 15). `16 << %x u< 63` is expected to become `%x u< 2`:
; 16 << 1 = 32 is below 63, while 16 << 2 = 64 is not.
define i1 @fold_icmp_shl_nuw_c2_indivisible(i32 %x) {
; CHECK-LABEL: @fold_icmp_shl_nuw_c2_indivisible(
; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[X:%.*]], 2
; CHECK-NEXT: ret i1 [[CMP]]
;
%shl = shl nuw i32 16, %x
%cmp = icmp ult i32 %shl, 63
ret i1 %cmp
}

; Negative tests

; Same constants as the `16 << %x u< 64` case, but the shl carries no nuw flag.
; The shl/icmp pair is expected to be left unchanged.
define i1 @fold_icmp_shl_c2_without_nuw(i32 %x) {
; CHECK-LABEL: @fold_icmp_shl_c2_without_nuw(
; CHECK-NEXT: [[SHL:%.*]] = shl i32 16, [[X:%.*]]
; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[SHL]], 64
; CHECK-NEXT: ret i1 [[CMP]]
;
%shl = shl i32 16, %x
%cmp = icmp ult i32 %shl, 64
ret i1 %cmp
}

; Make sure this trivial case is folded by InstSimplify.
; The shifted constant is zero: `0 << %x u< 63` is expected to fold to
; constant true.
define i1 @fold_icmp_shl_nuw_c2_precondition1(i32 %x) {
; CHECK-LABEL: @fold_icmp_shl_nuw_c2_precondition1(
; CHECK-NEXT: ret i1 true
;
%shl = shl nuw i32 0, %x
%cmp = icmp ult i32 %shl, 63
ret i1 %cmp
}

; Make sure this trivial case is folded by InstSimplify.
; The shifted constant is already larger than the compare constant (127 > 63):
; `127 << %x u< 63` is expected to fold to constant false.
define i1 @fold_icmp_shl_nuw_c2_precondition2(i32 %x) {
; CHECK-LABEL: @fold_icmp_shl_nuw_c2_precondition2(
; CHECK-NEXT: ret i1 false
;
%shl = shl nuw i32 127, %x
%cmp = icmp ult i32 %shl, 63
ret i1 %cmp
}

; Make sure we don't crash on this case.
; Shifted constant and compare constant are both 1:
; `1 << %x u< 1` is expected to fold to constant false.
define i1 @fold_icmp_shl_nuw_c2_precondition3(i32 %x) {
; CHECK-LABEL: @fold_icmp_shl_nuw_c2_precondition3(
; CHECK-NEXT: ret i1 false
;
%shl = shl nuw i32 1, %x
%cmp = icmp ult i32 %shl, 1
ret i1 %cmp
}
Loading