diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h index 6f2f25548cc84b..3a9fb6032f4cb2 100644 --- a/llvm/include/llvm/ADT/APInt.h +++ b/llvm/include/llvm/ADT/APInt.h @@ -108,9 +108,19 @@ class [[nodiscard]] APInt { /// \param isSigned how to treat signedness of val APInt(unsigned numBits, uint64_t val, bool isSigned = false) : BitWidth(numBits) { + if (BitWidth == 0) { + assert(val == 0 && "Value must be zero for 0-bit APInt"); + } else if (isSigned) { + assert(llvm::isIntN(BitWidth, val) && + "Value is not an N-bit signed value"); + } else { + assert(llvm::isUIntN(BitWidth, val) && + "Value is not an N-bit unsigned value"); + } if (isSingleWord()) { U.VAL = val; - clearUnusedBits(); + if (isSigned) + clearUnusedBits(); } else { initSlowCase(val, isSigned); } diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp index 05b1526da95ff7..313c874f9caf1d 100644 --- a/llvm/lib/Support/APInt.cpp +++ b/llvm/lib/Support/APInt.cpp @@ -234,7 +234,7 @@ APInt& APInt::operator-=(uint64_t RHS) { APInt APInt::operator*(const APInt& RHS) const { assert(BitWidth == RHS.BitWidth && "Bit widths must be the same"); if (isSingleWord()) - return APInt(BitWidth, U.VAL * RHS.U.VAL); + return APInt(BitWidth, (U.VAL * RHS.U.VAL) & maxUIntN(BitWidth)); APInt Result(getMemory(getNumWords()), getBitWidth()); tcMultiply(Result.U.pVal, U.pVal, RHS.U.pVal, getNumWords()); @@ -907,7 +907,7 @@ APInt APInt::trunc(unsigned width) const { assert(width <= BitWidth && "Invalid APInt Truncate request"); if (width <= APINT_BITS_PER_WORD) - return APInt(width, getRawData()[0]); + return APInt(width, getRawData()[0] & maxUIntN(width)); if (width == BitWidth) return *this; @@ -955,7 +955,7 @@ APInt APInt::sext(unsigned Width) const { assert(Width >= BitWidth && "Invalid APInt SignExtend request"); if (Width <= APINT_BITS_PER_WORD) - return APInt(Width, SignExtend64(U.VAL, BitWidth)); + return APInt(Width, SignExtend64(U.VAL, BitWidth), /*isSigned*/ true); if (Width == BitWidth) return *this; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index d295853798b808..b3b425cc1ac42b 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -308,7 +308,7 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal( DL.getTypeAllocSize(Init->getType()->getArrayElementType()); auto MaskIdx = [&](Value *Idx) { if (!GEP->isInBounds() && llvm::countr_zero(ElementSize) != 0) { - Value *Mask = ConstantInt::get(Idx->getType(), -1); + Value *Mask = Constant::getAllOnesValue(Idx->getType()); Mask = Builder.CreateLShr(Mask, llvm::countr_zero(ElementSize)); Idx = Builder.CreateAnd(Idx, Mask); } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 453e4d788705f3..efc3b99e836060 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -665,11 +665,11 @@ static Value *foldSelectICmpLshrAshr(const ICmpInst *IC, Value *TrueVal, Value *X, *Y; unsigned Bitwidth = CmpRHS->getType()->getScalarSizeInBits(); if ((Pred != ICmpInst::ICMP_SGT || - !match(CmpRHS, - m_SpecificInt_ICMP(ICmpInst::ICMP_SGE, APInt(Bitwidth, -1)))) && + !match(CmpRHS, m_SpecificInt_ICMP(ICmpInst::ICMP_SGE, + APInt::getAllOnes(Bitwidth)))) && (Pred != ICmpInst::ICMP_SLT || - !match(CmpRHS, - m_SpecificInt_ICMP(ICmpInst::ICMP_SGE, APInt(Bitwidth, 0))))) + !match(CmpRHS, m_SpecificInt_ICMP(ICmpInst::ICMP_SGE, + APInt::getZero(Bitwidth))))) return nullptr; // Canonicalize so that ashr is in FalseVal. diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp index 52eef9ab58a4d9..f24d7a1295c10b 100644 --- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -553,7 +553,8 @@ Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilderBase &B) { // strcmp(x, y) -> cnst (if both x and y are constant strings) if (HasStr1 && HasStr2) return ConstantInt::get(CI->getType(), - std::clamp(Str1.compare(Str2), -1, 1)); + std::clamp(Str1.compare(Str2), -1, 1), + /*isSigned*/ true); if (HasStr1 && Str1.empty()) // strcmp("", x) -> -*x return B.CreateNeg(B.CreateZExt( @@ -638,7 +639,8 @@ Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilderBase &B) { StringRef SubStr1 = substr(Str1, Length); StringRef SubStr2 = substr(Str2, Length); return ConstantInt::get(CI->getType(), - std::clamp(SubStr1.compare(SubStr2), -1, 1)); + std::clamp(SubStr1.compare(SubStr2), -1, 1), + /*isSigned*/ true); } if (HasStr1 && Str1.empty()) // strncmp("", x, n) -> -*x @@ -1518,7 +1520,7 @@ static Value *optimizeMemCmpVarSize(CallInst *CI, Value *LHS, Value *RHS, int IRes = UChar(LStr[Pos]) < UChar(RStr[Pos]) ? -1 : 1; Value *MaxSize = ConstantInt::get(Size->getType(), Pos); Value *Cmp = B.CreateICmp(ICmpInst::ICMP_ULE, Size, MaxSize); - Value *Res = ConstantInt::get(CI->getType(), IRes); + Value *Res = ConstantInt::get(CI->getType(), IRes, /*isSigned*/ true); return B.CreateSelect(Cmp, Zero, Res); }