Skip to content

[InstCombine][InstSimplify] Pass SimplifyQuery to computeKnownBits directly. NFC. #74246

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 3, 2023

Conversation

dtcxzyw
Copy link
Member

@dtcxzyw dtcxzyw commented Dec 3, 2023

This patch passes SimplifyQuery to computeKnownBits directly in InstSimplify and InstCombine.
As the DomConditionCache in #73662 is only used in InstCombine, it is inconvenient to introduce a new argument DC to computeKnownBits.

#74242 will be fixed by this patch and #73662.

@dtcxzyw dtcxzyw requested a review from goldsteinn December 3, 2023 17:40
@dtcxzyw dtcxzyw requested a review from nikic as a code owner December 3, 2023 17:40
@llvmbot llvmbot added llvm:analysis Includes value tracking, cost tables and constant folding llvm:transforms labels Dec 3, 2023
@llvmbot
Copy link
Member

llvmbot commented Dec 3, 2023

@llvm/pr-subscribers-llvm-analysis

Author: Yingwei Zheng (dtcxzyw)

Changes

This patch passes SimplifyQuery to computeKnownBits directly in InstSimplify and InstCombine.
As the DomConditionCache in #73662 is only used in InstCombine, it is inconvenient to introduce a new argument DC to computeKnownBits.

#74242 will be fixed by this patch and #73662.


Full diff: https://github.com/llvm/llvm-project/pull/74246.diff

2 Files Affected:

  • (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+19-21)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp (+2-4)
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index cef9f6ec179ba..2a45acf63aa2c 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -811,7 +811,7 @@ static Value *simplifySubInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
     if (IsNUW)
       return Constant::getNullValue(Op0->getType());
 
-    KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    KnownBits Known = computeKnownBits(Op1, /* Depth */ 0, Q);
     if (Known.Zero.isMaxSignedValue()) {
       // Op1 is either 0 or the minimum signed value. If the sub is NSW, then
       // Op1 must be 0 because negating the minimum signed value is undefined.
@@ -1063,7 +1063,7 @@ static bool isDivZero(Value *X, Value *Y, const SimplifyQuery &Q,
   //      ("computeConstantRangeIncludingKnownBits")?
   const APInt *C;
   if (match(Y, m_APInt(C)) &&
-      computeKnownBits(X, Q.DL, 0, Q.AC, Q.CxtI, Q.DT).getMaxValue().ult(*C))
+      computeKnownBits(X, /* Depth */ 0, Q).getMaxValue().ult(*C))
     return true;
 
   // Try again for any divisor:
@@ -1125,8 +1125,7 @@ static Value *simplifyDivRem(Instruction::BinaryOps Opcode, Value *Op0,
   if (Op0 == Op1)
     return IsDiv ? ConstantInt::get(Ty, 1) : Constant::getNullValue(Ty);
 
-
-  KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+  KnownBits Known = computeKnownBits(Op1, /* Depth */ 0, Q);
   // X / 0 -> poison
   // X % 0 -> poison
   // If the divisor is known to be zero, just return poison. This can happen in
@@ -1195,7 +1194,7 @@ static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
   // less trailing zeros, then the result must be poison.
   const APInt *DivC;
   if (IsExact && match(Op1, m_APInt(DivC)) && DivC->countr_zero()) {
-    KnownBits KnownOp0 = computeKnownBits(Op0, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    KnownBits KnownOp0 = computeKnownBits(Op0, /* Depth */ 0, Q);
     if (KnownOp0.countMaxTrailingZeros() < DivC->countr_zero())
       return PoisonValue::get(Op0->getType());
   }
@@ -1355,7 +1354,7 @@ static Value *simplifyShift(Instruction::BinaryOps Opcode, Value *Op0,
 
   // If any bits in the shift amount make that value greater than or equal to
   // the number of bits in the type, the shift is undefined.
-  KnownBits KnownAmt = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+  KnownBits KnownAmt = computeKnownBits(Op1, /* Depth */ 0, Q);
   if (KnownAmt.getMinValue().uge(KnownAmt.getBitWidth()))
     return PoisonValue::get(Op0->getType());
 
@@ -1368,7 +1367,7 @@ static Value *simplifyShift(Instruction::BinaryOps Opcode, Value *Op0,
   // Check for nsw shl leading to a poison value.
   if (IsNSW) {
     assert(Opcode == Instruction::Shl && "Expected shl for nsw instruction");
-    KnownBits KnownVal = computeKnownBits(Op0, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    KnownBits KnownVal = computeKnownBits(Op0, /* Depth */ 0, Q);
     KnownBits KnownShl = KnownBits::shl(KnownVal, KnownAmt);
 
     if (KnownVal.Zero.isSignBitSet())
@@ -1404,8 +1403,7 @@ static Value *simplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0,
   // The low bit cannot be shifted out of an exact shift if it is set.
   // TODO: Generalize by counting trailing zeros (see fold for exact division).
   if (IsExact) {
-    KnownBits Op0Known =
-        computeKnownBits(Op0, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT);
+    KnownBits Op0Known = computeKnownBits(Op0, /* Depth */ 0, Q);
     if (Op0Known.One[0])
       return Op0;
   }
@@ -1477,7 +1475,7 @@ static Value *simplifyLShrInst(Value *Op0, Value *Op1, bool IsExact,
   if (Q.IIQ.UseInstrInfo && match(Op1, m_APInt(ShRAmt)) &&
       match(Op0, m_c_Or(m_NUWShl(m_Value(X), m_APInt(ShLAmt)), m_Value(Y))) &&
       *ShRAmt == *ShLAmt) {
-    const KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    const KnownBits YKnown = computeKnownBits(Y, /* Depth */ 0, Q);
     const unsigned EffWidthY = YKnown.countMaxActiveBits();
     if (ShRAmt->uge(EffWidthY))
       return X;
@@ -2105,7 +2103,7 @@ static Value *simplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
       match(Op0, m_Add(m_Value(Shift), m_AllOnes())) &&
       isKnownToBeAPowerOfTwo(Shift, Q.DL, /*OrZero*/ false, 0, Q.AC, Q.CxtI,
                              Q.DT)) {
-    KnownBits Known = computeKnownBits(Shift, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    KnownBits Known = computeKnownBits(Shift, /* Depth */ 0, Q);
     // Use getActiveBits() to make use of the additional power of two knowledge
     if (PowerC->getActiveBits() >= Known.getMaxValue().getActiveBits())
       return ConstantInt::getNullValue(Op1->getType());
@@ -2169,10 +2167,10 @@ static Value *simplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
                         m_Value(Y)))) {
     const unsigned Width = Op0->getType()->getScalarSizeInBits();
     const unsigned ShftCnt = ShAmt->getLimitedValue(Width);
-    const KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    const KnownBits YKnown = computeKnownBits(Y, /* Depth */ 0, Q);
     const unsigned EffWidthY = YKnown.countMaxActiveBits();
     if (EffWidthY <= ShftCnt) {
-      const KnownBits XKnown = computeKnownBits(X, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+      const KnownBits XKnown = computeKnownBits(X, /* Depth */ 0, Q);
       const unsigned EffWidthX = XKnown.countMaxActiveBits();
       const APInt EffBitsY = APInt::getLowBitsSet(Width, EffWidthY);
       const APInt EffBitsX = APInt::getLowBitsSet(Width, EffWidthX) << ShftCnt;
@@ -2968,7 +2966,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
       return getTrue(ITy);
     break;
   case ICmpInst::ICMP_SLT: {
-    KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    KnownBits LHSKnown = computeKnownBits(LHS, /* Depth */ 0, Q);
     if (LHSKnown.isNegative())
       return getTrue(ITy);
     if (LHSKnown.isNonNegative())
@@ -2976,7 +2974,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
     break;
   }
   case ICmpInst::ICMP_SLE: {
-    KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    KnownBits LHSKnown = computeKnownBits(LHS, /* Depth */ 0, Q);
     if (LHSKnown.isNegative())
       return getTrue(ITy);
     if (LHSKnown.isNonNegative() &&
@@ -2985,7 +2983,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
     break;
   }
   case ICmpInst::ICMP_SGE: {
-    KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    KnownBits LHSKnown = computeKnownBits(LHS, /* Depth */ 0, Q);
     if (LHSKnown.isNegative())
       return getFalse(ITy);
     if (LHSKnown.isNonNegative())
@@ -2993,7 +2991,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
     break;
   }
   case ICmpInst::ICMP_SGT: {
-    KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    KnownBits LHSKnown = computeKnownBits(LHS, /* Depth */ 0, Q);
     if (LHSKnown.isNegative())
       return getFalse(ITy);
     if (LHSKnown.isNonNegative() &&
@@ -3070,8 +3068,8 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
       return getTrue(ITy);
 
     if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) {
-      KnownBits RHSKnown = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
-      KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+      KnownBits RHSKnown = computeKnownBits(RHS, /* Depth */ 0, Q);
+      KnownBits YKnown = computeKnownBits(Y, /* Depth */ 0, Q);
       if (RHSKnown.isNonNegative() && YKnown.isNegative())
         return Pred == ICmpInst::ICMP_SLT ? getTrue(ITy) : getFalse(ITy);
       if (RHSKnown.isNegative() || YKnown.isNonNegative())
@@ -3094,7 +3092,7 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
       break;
     case ICmpInst::ICMP_SGT:
     case ICmpInst::ICMP_SGE: {
-      KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+      KnownBits Known = computeKnownBits(RHS, /* Depth */ 0, Q);
       if (!Known.isNonNegative())
         break;
       [[fallthrough]];
@@ -3105,7 +3103,7 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
       return getFalse(ITy);
     case ICmpInst::ICMP_SLT:
     case ICmpInst::ICMP_SLE: {
-      KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+      KnownBits Known = computeKnownBits(RHS, /* Depth */ 0, Q);
       if (!Known.isNonNegative())
         break;
       [[fallthrough]];
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index c72eb0f74de8e..b7958978c450c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -962,15 +962,13 @@ static bool setShiftFlags(BinaryOperator &I, const SimplifyQuery &Q) {
   }
 
   // Compute what we know about shift count.
-  KnownBits KnownCnt =
-      computeKnownBits(I.getOperand(1), Q.DL, /*Depth*/ 0, Q.AC, Q.CxtI, Q.DT);
+  KnownBits KnownCnt = computeKnownBits(I.getOperand(1), /* Depth */ 0, Q);
   unsigned BitWidth = KnownCnt.getBitWidth();
   // Since shift produces a poison value if RHS is equal to or larger than the
   // bit width, we can safely assume that RHS is less than the bit width.
   uint64_t MaxCnt = KnownCnt.getMaxValue().getLimitedValue(BitWidth - 1);
 
-  KnownBits KnownAmt =
-      computeKnownBits(I.getOperand(0), Q.DL, /*Depth*/ 0, Q.AC, Q.CxtI, Q.DT);
+  KnownBits KnownAmt = computeKnownBits(I.getOperand(0), /* Depth */ 0, Q);
   bool Changed = false;
 
   if (I.getOpcode() == Instruction::Shl) {

@llvmbot
Copy link
Member

llvmbot commented Dec 3, 2023

@llvm/pr-subscribers-llvm-transforms

Author: Yingwei Zheng (dtcxzyw)

Changes

This patch passes SimplifyQuery to computeKnownBits directly in InstSimplify and InstCombine.
As the DomConditionCache in #73662 is only used in InstCombine, it is inconvenient to introduce a new argument DC to computeKnownBits.

#74242 will be fixed by this patch and #73662.


Full diff: https://github.com/llvm/llvm-project/pull/74246.diff

2 Files Affected:

  • (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+19-21)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp (+2-4)
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index cef9f6ec179ba..2a45acf63aa2c 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -811,7 +811,7 @@ static Value *simplifySubInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
     if (IsNUW)
       return Constant::getNullValue(Op0->getType());
 
-    KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    KnownBits Known = computeKnownBits(Op1, /* Depth */ 0, Q);
     if (Known.Zero.isMaxSignedValue()) {
       // Op1 is either 0 or the minimum signed value. If the sub is NSW, then
       // Op1 must be 0 because negating the minimum signed value is undefined.
@@ -1063,7 +1063,7 @@ static bool isDivZero(Value *X, Value *Y, const SimplifyQuery &Q,
   //      ("computeConstantRangeIncludingKnownBits")?
   const APInt *C;
   if (match(Y, m_APInt(C)) &&
-      computeKnownBits(X, Q.DL, 0, Q.AC, Q.CxtI, Q.DT).getMaxValue().ult(*C))
+      computeKnownBits(X, /* Depth */ 0, Q).getMaxValue().ult(*C))
     return true;
 
   // Try again for any divisor:
@@ -1125,8 +1125,7 @@ static Value *simplifyDivRem(Instruction::BinaryOps Opcode, Value *Op0,
   if (Op0 == Op1)
     return IsDiv ? ConstantInt::get(Ty, 1) : Constant::getNullValue(Ty);
 
-
-  KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+  KnownBits Known = computeKnownBits(Op1, /* Depth */ 0, Q);
   // X / 0 -> poison
   // X % 0 -> poison
   // If the divisor is known to be zero, just return poison. This can happen in
@@ -1195,7 +1194,7 @@ static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
   // less trailing zeros, then the result must be poison.
   const APInt *DivC;
   if (IsExact && match(Op1, m_APInt(DivC)) && DivC->countr_zero()) {
-    KnownBits KnownOp0 = computeKnownBits(Op0, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    KnownBits KnownOp0 = computeKnownBits(Op0, /* Depth */ 0, Q);
     if (KnownOp0.countMaxTrailingZeros() < DivC->countr_zero())
       return PoisonValue::get(Op0->getType());
   }
@@ -1355,7 +1354,7 @@ static Value *simplifyShift(Instruction::BinaryOps Opcode, Value *Op0,
 
   // If any bits in the shift amount make that value greater than or equal to
   // the number of bits in the type, the shift is undefined.
-  KnownBits KnownAmt = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+  KnownBits KnownAmt = computeKnownBits(Op1, /* Depth */ 0, Q);
   if (KnownAmt.getMinValue().uge(KnownAmt.getBitWidth()))
     return PoisonValue::get(Op0->getType());
 
@@ -1368,7 +1367,7 @@ static Value *simplifyShift(Instruction::BinaryOps Opcode, Value *Op0,
   // Check for nsw shl leading to a poison value.
   if (IsNSW) {
     assert(Opcode == Instruction::Shl && "Expected shl for nsw instruction");
-    KnownBits KnownVal = computeKnownBits(Op0, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    KnownBits KnownVal = computeKnownBits(Op0, /* Depth */ 0, Q);
     KnownBits KnownShl = KnownBits::shl(KnownVal, KnownAmt);
 
     if (KnownVal.Zero.isSignBitSet())
@@ -1404,8 +1403,7 @@ static Value *simplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0,
   // The low bit cannot be shifted out of an exact shift if it is set.
   // TODO: Generalize by counting trailing zeros (see fold for exact division).
   if (IsExact) {
-    KnownBits Op0Known =
-        computeKnownBits(Op0, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT);
+    KnownBits Op0Known = computeKnownBits(Op0, /* Depth */ 0, Q);
     if (Op0Known.One[0])
       return Op0;
   }
@@ -1477,7 +1475,7 @@ static Value *simplifyLShrInst(Value *Op0, Value *Op1, bool IsExact,
   if (Q.IIQ.UseInstrInfo && match(Op1, m_APInt(ShRAmt)) &&
       match(Op0, m_c_Or(m_NUWShl(m_Value(X), m_APInt(ShLAmt)), m_Value(Y))) &&
       *ShRAmt == *ShLAmt) {
-    const KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    const KnownBits YKnown = computeKnownBits(Y, /* Depth */ 0, Q);
     const unsigned EffWidthY = YKnown.countMaxActiveBits();
     if (ShRAmt->uge(EffWidthY))
       return X;
@@ -2105,7 +2103,7 @@ static Value *simplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
       match(Op0, m_Add(m_Value(Shift), m_AllOnes())) &&
       isKnownToBeAPowerOfTwo(Shift, Q.DL, /*OrZero*/ false, 0, Q.AC, Q.CxtI,
                              Q.DT)) {
-    KnownBits Known = computeKnownBits(Shift, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    KnownBits Known = computeKnownBits(Shift, /* Depth */ 0, Q);
     // Use getActiveBits() to make use of the additional power of two knowledge
     if (PowerC->getActiveBits() >= Known.getMaxValue().getActiveBits())
       return ConstantInt::getNullValue(Op1->getType());
@@ -2169,10 +2167,10 @@ static Value *simplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
                         m_Value(Y)))) {
     const unsigned Width = Op0->getType()->getScalarSizeInBits();
     const unsigned ShftCnt = ShAmt->getLimitedValue(Width);
-    const KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    const KnownBits YKnown = computeKnownBits(Y, /* Depth */ 0, Q);
     const unsigned EffWidthY = YKnown.countMaxActiveBits();
     if (EffWidthY <= ShftCnt) {
-      const KnownBits XKnown = computeKnownBits(X, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+      const KnownBits XKnown = computeKnownBits(X, /* Depth */ 0, Q);
       const unsigned EffWidthX = XKnown.countMaxActiveBits();
       const APInt EffBitsY = APInt::getLowBitsSet(Width, EffWidthY);
       const APInt EffBitsX = APInt::getLowBitsSet(Width, EffWidthX) << ShftCnt;
@@ -2968,7 +2966,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
       return getTrue(ITy);
     break;
   case ICmpInst::ICMP_SLT: {
-    KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    KnownBits LHSKnown = computeKnownBits(LHS, /* Depth */ 0, Q);
     if (LHSKnown.isNegative())
       return getTrue(ITy);
     if (LHSKnown.isNonNegative())
@@ -2976,7 +2974,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
     break;
   }
   case ICmpInst::ICMP_SLE: {
-    KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    KnownBits LHSKnown = computeKnownBits(LHS, /* Depth */ 0, Q);
     if (LHSKnown.isNegative())
       return getTrue(ITy);
     if (LHSKnown.isNonNegative() &&
@@ -2985,7 +2983,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
     break;
   }
   case ICmpInst::ICMP_SGE: {
-    KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    KnownBits LHSKnown = computeKnownBits(LHS, /* Depth */ 0, Q);
     if (LHSKnown.isNegative())
       return getFalse(ITy);
     if (LHSKnown.isNonNegative())
@@ -2993,7 +2991,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
     break;
   }
   case ICmpInst::ICMP_SGT: {
-    KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    KnownBits LHSKnown = computeKnownBits(LHS, /* Depth */ 0, Q);
     if (LHSKnown.isNegative())
       return getFalse(ITy);
     if (LHSKnown.isNonNegative() &&
@@ -3070,8 +3068,8 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
       return getTrue(ITy);
 
     if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) {
-      KnownBits RHSKnown = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
-      KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+      KnownBits RHSKnown = computeKnownBits(RHS, /* Depth */ 0, Q);
+      KnownBits YKnown = computeKnownBits(Y, /* Depth */ 0, Q);
       if (RHSKnown.isNonNegative() && YKnown.isNegative())
         return Pred == ICmpInst::ICMP_SLT ? getTrue(ITy) : getFalse(ITy);
       if (RHSKnown.isNegative() || YKnown.isNonNegative())
@@ -3094,7 +3092,7 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
       break;
     case ICmpInst::ICMP_SGT:
     case ICmpInst::ICMP_SGE: {
-      KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+      KnownBits Known = computeKnownBits(RHS, /* Depth */ 0, Q);
       if (!Known.isNonNegative())
         break;
       [[fallthrough]];
@@ -3105,7 +3103,7 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
       return getFalse(ITy);
     case ICmpInst::ICMP_SLT:
     case ICmpInst::ICMP_SLE: {
-      KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+      KnownBits Known = computeKnownBits(RHS, /* Depth */ 0, Q);
       if (!Known.isNonNegative())
         break;
       [[fallthrough]];
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index c72eb0f74de8e..b7958978c450c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -962,15 +962,13 @@ static bool setShiftFlags(BinaryOperator &I, const SimplifyQuery &Q) {
   }
 
   // Compute what we know about shift count.
-  KnownBits KnownCnt =
-      computeKnownBits(I.getOperand(1), Q.DL, /*Depth*/ 0, Q.AC, Q.CxtI, Q.DT);
+  KnownBits KnownCnt = computeKnownBits(I.getOperand(1), /* Depth */ 0, Q);
   unsigned BitWidth = KnownCnt.getBitWidth();
   // Since shift produces a poison value if RHS is equal to or larger than the
   // bit width, we can safely assume that RHS is less than the bit width.
   uint64_t MaxCnt = KnownCnt.getMaxValue().getLimitedValue(BitWidth - 1);
 
-  KnownBits KnownAmt =
-      computeKnownBits(I.getOperand(0), Q.DL, /*Depth*/ 0, Q.AC, Q.CxtI, Q.DT);
+  KnownBits KnownAmt = computeKnownBits(I.getOperand(0), /* Depth */ 0, Q);
   bool Changed = false;
 
   if (I.getOpcode() == Instruction::Shl) {

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@dtcxzyw dtcxzyw merged commit 741975d into llvm:main Dec 3, 2023
@dtcxzyw dtcxzyw deleted the perf/compute-known-bits branch December 3, 2023 18:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:analysis Includes value tracking, cost tables and constant folding llvm:transforms
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants