diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 34c9e0fde4f428..8034befb6b9449 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -2227,18 +2227,24 @@ Instruction *InstCombinerImpl::foldICmpMulConstant(ICmpInst &Cmp, return NewC ? new ICmpInst(Pred, X, NewC) : nullptr; } -/// Fold icmp (shl 1, Y), C. -static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, - const APInt &C) { +/// Fold icmp (shl nuw C2, Y), C. +static Instruction *foldICmpShlLHSC(ICmpInst &Cmp, Instruction *Shl, + const APInt &C) { Value *Y; - if (!match(Shl, m_Shl(m_One(), m_Value(Y)))) + const APInt *C2; + if (!match(Shl, m_NUWShl(m_APInt(C2), m_Value(Y)))) return nullptr; Type *ShiftType = Shl->getType(); unsigned TypeBits = C.getBitWidth(); - bool CIsPowerOf2 = C.isPowerOf2(); ICmpInst::Predicate Pred = Cmp.getPredicate(); if (Cmp.isUnsigned()) { + APInt Div, Rem; + APInt::udivrem(C, *C2, Div, Rem); + if (!Rem.isZero()) + return nullptr; + bool CIsPowerOf2 = Div.isPowerOf2(); + // (1 << Y) pred C -> Y pred Log2(C) if (!CIsPowerOf2) { // (1 << Y) < 30 -> Y <= 4 @@ -2251,9 +2257,9 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, Pred = ICmpInst::ICMP_UGT; } - unsigned CLog2 = C.logBase2(); + unsigned CLog2 = Div.logBase2(); return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, CLog2)); - } else if (Cmp.isSigned()) { + } else if (Cmp.isSigned() && C2->isOne()) { Constant *BitWidthMinusOne = ConstantInt::get(ShiftType, TypeBits - 1); // (1 << Y) > 0 -> Y != 31 // (1 << Y) > C -> Y != 31 if C is negative. @@ -2307,7 +2313,7 @@ Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp, const APInt *ShiftAmt; if (!match(Shl->getOperand(1), m_APInt(ShiftAmt))) - return foldICmpShlOne(Cmp, Shl, C); + return foldICmpShlLHSC(Cmp, Shl, C); // Check that the shift amount is in range. If not, don't perform undefined // shifts. When the shift is visited, it will be simplified. diff --git a/llvm/test/Transforms/InstCombine/icmp-shl-nuw.ll b/llvm/test/Transforms/InstCombine/icmp-shl-nuw.ll index 8cfd3228999ac5..46671f83610fd1 100644 --- a/llvm/test/Transforms/InstCombine/icmp-shl-nuw.ll +++ b/llvm/test/Transforms/InstCombine/icmp-shl-nuw.ll @@ -93,10 +93,8 @@ define <2 x i1> @icmp_ugt_16x2(<2 x i32>) { define i1 @fold_icmp_shl_nuw_c1(i32 %x) { ; CHECK-LABEL: @fold_icmp_shl_nuw_c1( -; CHECK-NEXT: [[LSHR:%.*]] = lshr i32 [[X:%.*]], 12 -; CHECK-NEXT: [[AND:%.*]] = and i32 [[LSHR]], 15 -; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i32 2, [[AND]] -; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[SHL]], 4 +; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[X:%.*]], 61440 +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[TMP1]], 0 ; CHECK-NEXT: ret i1 [[CMP]] ; %lshr = lshr i32 %x, 12 @@ -108,8 +106,7 @@ define i1 @fold_icmp_shl_nuw_c1(i32 %x) { define i1 @fold_icmp_shl_nuw_c2(i32 %x) { ; CHECK-LABEL: @fold_icmp_shl_nuw_c2( -; CHECK-NEXT: [[SHL:%.*]] = shl nuw i32 16, [[X:%.*]] -; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[SHL]], 64 +; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[X:%.*]], 2 ; CHECK-NEXT: ret i1 [[CMP]] ; %shl = shl nuw i32 16, %x @@ -119,8 +116,7 @@ define i1 @fold_icmp_shl_nuw_c2(i32 %x) { define i1 @fold_icmp_shl_nuw_c2_non_pow2(i32 %x) { ; CHECK-LABEL: @fold_icmp_shl_nuw_c2_non_pow2( -; CHECK-NEXT: [[SHL:%.*]] = shl nuw i32 48, [[X:%.*]] -; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[SHL]], 192 +; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[X:%.*]], 2 ; CHECK-NEXT: ret i1 [[CMP]] ; %shl = shl nuw i32 48, %x @@ -130,8 +126,7 @@ define i1 @fold_icmp_shl_nuw_c2_non_pow2(i32 %x) { define i1 @fold_icmp_shl_nuw_c2_div_non_pow2(i32 %x) { ; CHECK-LABEL: @fold_icmp_shl_nuw_c2_div_non_pow2( -; CHECK-NEXT: [[SHL:%.*]] = shl nuw i32 2, [[X:%.*]] -; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[SHL]], 60 +; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[X:%.*]], 5 ; CHECK-NEXT: ret i1 [[CMP]] ; %shl = shl nuw i32 2, %x