Skip to content

Commit f56c096

Browse files
committed
[APInt] Assert correct values in APInt constructor
If the uint64_t constructor is used, assert that the value is actuall a signed or unsigned N-bit integer depending on whether the isSigned flag is set. Currently, we allow values to be silently truncated, which is a constant source of subtle bugs -- a particularly common mistake is to create -1 values without setting the isSigned flag, which will work fine for all common bit widths (<= 64-bit) and miscompile for larger integers.
1 parent 62ae7d9 commit f56c096

File tree

5 files changed

+24
-12
lines changed

5 files changed

+24
-12
lines changed

llvm/include/llvm/ADT/APInt.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,19 @@ class [[nodiscard]] APInt {
108108
/// \param isSigned how to treat signedness of val
109109
APInt(unsigned numBits, uint64_t val, bool isSigned = false)
110110
: BitWidth(numBits) {
111+
if (BitWidth == 0) {
112+
assert(val == 0 && "Value must be zero for 0-bit APInt");
113+
} else if (isSigned) {
114+
assert(llvm::isIntN(BitWidth, val) &&
115+
"Value is not an N-bit signed value");
116+
} else {
117+
assert(llvm::isUIntN(BitWidth, val) &&
118+
"Value is not an N-bit unsigned value");
119+
}
111120
if (isSingleWord()) {
112121
U.VAL = val;
113-
clearUnusedBits();
122+
if (isSigned)
123+
clearUnusedBits();
114124
} else {
115125
initSlowCase(val, isSigned);
116126
}

llvm/lib/Support/APInt.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ APInt& APInt::operator-=(uint64_t RHS) {
234234
APInt APInt::operator*(const APInt& RHS) const {
235235
assert(BitWidth == RHS.BitWidth && "Bit widths must be the same");
236236
if (isSingleWord())
237-
return APInt(BitWidth, U.VAL * RHS.U.VAL);
237+
return APInt(BitWidth, (U.VAL * RHS.U.VAL) & maxUIntN(BitWidth));
238238

239239
APInt Result(getMemory(getNumWords()), getBitWidth());
240240
tcMultiply(Result.U.pVal, U.pVal, RHS.U.pVal, getNumWords());
@@ -907,7 +907,7 @@ APInt APInt::trunc(unsigned width) const {
907907
assert(width <= BitWidth && "Invalid APInt Truncate request");
908908

909909
if (width <= APINT_BITS_PER_WORD)
910-
return APInt(width, getRawData()[0]);
910+
return APInt(width, getRawData()[0] & maxUIntN(width));
911911

912912
if (width == BitWidth)
913913
return *this;
@@ -955,7 +955,7 @@ APInt APInt::sext(unsigned Width) const {
955955
assert(Width >= BitWidth && "Invalid APInt SignExtend request");
956956

957957
if (Width <= APINT_BITS_PER_WORD)
958-
return APInt(Width, SignExtend64(U.VAL, BitWidth));
958+
return APInt(Width, SignExtend64(U.VAL, BitWidth), /*isSigned*/ true);
959959

960960
if (Width == BitWidth)
961961
return *this;

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal(
308308
DL.getTypeAllocSize(Init->getType()->getArrayElementType());
309309
auto MaskIdx = [&](Value *Idx) {
310310
if (!GEP->isInBounds() && llvm::countr_zero(ElementSize) != 0) {
311-
Value *Mask = ConstantInt::get(Idx->getType(), -1);
311+
Value *Mask = Constant::getAllOnesValue(Idx->getType());
312312
Mask = Builder.CreateLShr(Mask, llvm::countr_zero(ElementSize));
313313
Idx = Builder.CreateAnd(Idx, Mask);
314314
}

llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -665,11 +665,11 @@ static Value *foldSelectICmpLshrAshr(const ICmpInst *IC, Value *TrueVal,
665665
Value *X, *Y;
666666
unsigned Bitwidth = CmpRHS->getType()->getScalarSizeInBits();
667667
if ((Pred != ICmpInst::ICMP_SGT ||
668-
!match(CmpRHS,
669-
m_SpecificInt_ICMP(ICmpInst::ICMP_SGE, APInt(Bitwidth, -1)))) &&
668+
!match(CmpRHS, m_SpecificInt_ICMP(ICmpInst::ICMP_SGE,
669+
APInt::getAllOnes(Bitwidth)))) &&
670670
(Pred != ICmpInst::ICMP_SLT ||
671-
!match(CmpRHS,
672-
m_SpecificInt_ICMP(ICmpInst::ICMP_SGE, APInt(Bitwidth, 0)))))
671+
!match(CmpRHS, m_SpecificInt_ICMP(ICmpInst::ICMP_SGE,
672+
APInt::getZero(Bitwidth)))))
673673
return nullptr;
674674

675675
// Canonicalize so that ashr is in FalseVal.

llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,8 @@ Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilderBase &B) {
553553
// strcmp(x, y) -> cnst (if both x and y are constant strings)
554554
if (HasStr1 && HasStr2)
555555
return ConstantInt::get(CI->getType(),
556-
std::clamp(Str1.compare(Str2), -1, 1));
556+
std::clamp(Str1.compare(Str2), -1, 1),
557+
/*isSigned*/ true);
557558

558559
if (HasStr1 && Str1.empty()) // strcmp("", x) -> -*x
559560
return B.CreateNeg(B.CreateZExt(
@@ -638,7 +639,8 @@ Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilderBase &B) {
638639
StringRef SubStr1 = substr(Str1, Length);
639640
StringRef SubStr2 = substr(Str2, Length);
640641
return ConstantInt::get(CI->getType(),
641-
std::clamp(SubStr1.compare(SubStr2), -1, 1));
642+
std::clamp(SubStr1.compare(SubStr2), -1, 1),
643+
/*isSigned*/ true);
642644
}
643645

644646
if (HasStr1 && Str1.empty()) // strncmp("", x, n) -> -*x
@@ -1518,7 +1520,7 @@ static Value *optimizeMemCmpVarSize(CallInst *CI, Value *LHS, Value *RHS,
15181520
int IRes = UChar(LStr[Pos]) < UChar(RStr[Pos]) ? -1 : 1;
15191521
Value *MaxSize = ConstantInt::get(Size->getType(), Pos);
15201522
Value *Cmp = B.CreateICmp(ICmpInst::ICMP_ULE, Size, MaxSize);
1521-
Value *Res = ConstantInt::get(CI->getType(), IRes);
1523+
Value *Res = ConstantInt::get(CI->getType(), IRes, /*isSigned*/ true);
15221524
return B.CreateSelect(Cmp, Zero, Res);
15231525
}
15241526

0 commit comments

Comments
 (0)