Skip to content

Commit

Permalink
[APInt] Assert correct values in APInt constructor
Browse files Browse the repository at this point in the history
If the uint64_t constructor is used, assert that the value is
actuall a signed or unsigned N-bit integer depending on whether
the isSigned flag is set.

Currently, we allow values to be silently truncated, which is
a constant source of subtle bugs -- a particularly common mistake
is to create -1 values without setting the isSigned flag, which
will work fine for all common bit widths (<= 64-bit) and miscompile
for larger integers.
  • Loading branch information
nikic committed Jun 20, 2024
1 parent 7f09aa9 commit f5cc0bf
Show file tree
Hide file tree
Showing 37 changed files with 388 additions and 348 deletions.
4 changes: 3 additions & 1 deletion llvm/include/llvm/ADT/APFixedPoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ class APFixedPoint {
}

APFixedPoint(uint64_t Val, const FixedPointSemantics &Sema)
: APFixedPoint(APInt(Sema.getWidth(), Val, Sema.isSigned()), Sema) {}
: APFixedPoint(APInt(Sema.getWidth(), Val, Sema.isSigned(),
/*implicitTrunc*/ true),
Sema) {}

// Zero initialization.
APFixedPoint(const FixedPointSemantics &Sema) : APFixedPoint(0, Sema) {}
Expand Down
19 changes: 17 additions & 2 deletions llvm/include/llvm/ADT/APInt.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,26 @@ class [[nodiscard]] APInt {
/// \param numBits the bit width of the constructed APInt
/// \param val the initial value of the APInt
/// \param isSigned how to treat signedness of val
APInt(unsigned numBits, uint64_t val, bool isSigned = false)
/// \param implicitTrunc allow implicit truncation of non-zero/sign bits of
/// val beyond the range of numBits
APInt(unsigned numBits, uint64_t val, bool isSigned = false,
bool implicitTrunc = false)
: BitWidth(numBits) {
if (!implicitTrunc) {
if (BitWidth == 0) {
assert(val == 0 && "Value must be zero for 0-bit APInt");
} else if (isSigned) {
assert(llvm::isIntN(BitWidth, val) &&
"Value is not an N-bit signed value");
} else {
assert(llvm::isUIntN(BitWidth, val) &&
"Value is not an N-bit unsigned value");
}
}
if (isSingleWord()) {
U.VAL = val;
clearUnusedBits();
if (implicitTrunc || isSigned)
clearUnusedBits();
} else {
initSlowCase(val, isSigned);
}
Expand Down
8 changes: 5 additions & 3 deletions llvm/lib/Analysis/ConstantFolding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -889,7 +889,8 @@ Constant *SymbolicallyEvaluateGEP(const GEPOperator *GEP,
APInt Offset = APInt(
BitWidth,
DL.getIndexedOffsetInType(
SrcElemTy, ArrayRef((Value *const *)Ops.data() + 1, Ops.size() - 1)));
SrcElemTy, ArrayRef((Value *const *)Ops.data() + 1, Ops.size() - 1)),
/*isSigned*/ true, /*implicitTrunc*/ true);

std::optional<ConstantRange> InRange = GEP->getInRange();
if (InRange)
Expand Down Expand Up @@ -3401,8 +3402,9 @@ ConstantFoldScalarFrexpCall(Constant *Op, Type *IntTy) {

// The exponent is an "unspecified value" for inf/nan. We use zero to avoid
// using undef.
Constant *Result1 = FrexpMant.isFinite() ? ConstantInt::get(IntTy, FrexpExp)
: ConstantInt::getNullValue(IntTy);
Constant *Result1 = FrexpMant.isFinite()
? ConstantInt::getSigned(IntTy, FrexpExp)
: ConstantInt::getNullValue(IntTy);
return {Result0, Result1};
}

Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Analysis/MemoryBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ Value *llvm::lowerObjectSizeCall(
if (!MustSucceed)
return nullptr;

return ConstantInt::get(ResultType, MaxVal ? -1ULL : 0);
return ConstantInt::get(ResultType, MaxVal ? -1ULL : 0, true);
}

STATISTIC(ObjectVisitorArgument,
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Analysis/ScalarEvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1460,7 +1460,7 @@ bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,

APInt StartAI = StartC->getAPInt();

for (unsigned Delta : {-2, -1, 1, 2}) {
for (int Delta : {-2, -1, 1, 2}) {
const SCEV *PreStart = getConstant(StartAI - Delta);

FoldingSetNodeID ID;
Expand All @@ -1475,7 +1475,7 @@ bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
// Give up if we don't already have the add recurrence we need because
// actually constructing an add recurrence is relatively expensive.
if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
const SCEV *DeltaS = getConstant(StartC->getType(), Delta, true);
ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
DeltaS, &Pred, this);
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Analysis/ValueTracking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9303,7 +9303,7 @@ static ConstantRange getRangeForIntrinsic(const IntrinsicInst &II) {
case Intrinsic::cttz:
// Maximum of set/clear bits is the bit width.
return ConstantRange::getNonEmpty(APInt::getZero(Width),
APInt(Width, Width + 1));
APInt(Width, Width) + 1);
case Intrinsic::uadd_sat:
// uadd.sat(x, C) produces [C, UINT_MAX].
if (match(II.getOperand(0), m_APInt(C)) ||
Expand Down Expand Up @@ -9454,7 +9454,7 @@ static void setLimitForFPToI(const Instruction *I, APInt &Lower, APInt &Upper) {
if (!I->getOperand(0)->getType()->getScalarType()->isHalfTy())
return;
if (isa<FPToSIInst>(I) && BitWidth >= 17) {
Lower = APInt(BitWidth, -65504);
Lower = APInt(BitWidth, -65504, true);
Upper = APInt(BitWidth, 65505);
}

Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Bitcode/Reader/BitcodeReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3147,7 +3147,7 @@ Error BitcodeReader::parseConstants() {
case bitc::CST_CODE_INTEGER: // INTEGER: [intval]
if (!CurTy->isIntOrIntVectorTy() || Record.empty())
return error("Invalid integer const record");
V = ConstantInt::get(CurTy, decodeSignRotatedValue(Record[0]));
V = ConstantInt::getSigned(CurTy, decodeSignRotatedValue(Record[0]));
break;
case bitc::CST_CODE_WIDE_INTEGER: {// WIDE_INTEGER: [n x intval]
if (!CurTy->isIntOrIntVectorTy() || Record.empty())
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/CodeGenPrepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1643,7 +1643,7 @@ static bool matchUAddWithOverflowConstantEdgeCases(CmpInst *Cmp,
if (Pred == ICmpInst::ICMP_EQ && match(B, m_AllOnes()))
B = ConstantInt::get(B->getType(), 1);
else if (Pred == ICmpInst::ICMP_NE && match(B, m_ZeroInt()))
B = ConstantInt::get(B->getType(), -1);
B = Constant::getAllOnesValue(B->getType());
else
return false;

Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/ExpandMemCmp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ void MemCmpExpansion::emitMemCmpResultBlock() {
ResBlock.PhiSrc2);

Value *Res =
Builder.CreateSelect(Cmp, ConstantInt::get(Builder.getInt32Ty(), -1),
Builder.CreateSelect(Cmp, Constant::getAllOnesValue(Builder.getInt32Ty()),
ConstantInt::get(Builder.getInt32Ty(), 1));

PhiRes->addIncoming(Res, ResBlock.BB);
Expand Down
5 changes: 4 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1625,7 +1625,10 @@ SDValue SelectionDAG::getConstant(uint64_t Val, const SDLoc &DL, EVT VT,
assert((EltVT.getSizeInBits() >= 64 ||
(uint64_t)((int64_t)Val >> EltVT.getSizeInBits()) + 1 < 2) &&
"getConstant with a uint64_t value that doesn't fit in the type!");
return getConstant(APInt(EltVT.getSizeInBits(), Val), DL, VT, isT, isO);
// TODO: Avoid implicit trunc?
return getConstant(
APInt(EltVT.getSizeInBits(), Val, false, /*implicitTrunc*/ true), DL, VT,
isT, isO);
}

SDValue SelectionDAG::getConstant(const APInt &Val, const SDLoc &DL, EVT VT,
Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2171,7 +2171,9 @@ ScheduleDAGSDNodes *SelectionDAGISel::CreateScheduler() {
bool SelectionDAGISel::CheckAndMask(SDValue LHS, ConstantSDNode *RHS,
int64_t DesiredMaskS) const {
const APInt &ActualMask = RHS->getAPIntValue();
const APInt &DesiredMask = APInt(LHS.getValueSizeInBits(), DesiredMaskS);
// TODO: Avoid implicit trunc?
const APInt &DesiredMask = APInt(LHS.getValueSizeInBits(), DesiredMaskS,
false, /*implicitTrunc*/ true);

// If the actual mask exactly matches, success!
if (ActualMask == DesiredMask)
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4656,7 +4656,7 @@ OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
Builder.CreateCall(Fn, {KernelEnvironment, KernelLaunchEnvironment});

Value *ExecUserCode = Builder.CreateICmpEQ(
ThreadKind, ConstantInt::get(ThreadKind->getType(), -1),
ThreadKind, Constant::getAllOnesValue(ThreadKind->getType()),
"exec_user_code");

// ThreadKind = __kmpc_target_init(...)
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/FuzzMutate/OpDescriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ void fuzzerop::makeConstantsWithType(Type *T, std::vector<Constant *> &Cs) {
uint64_t W = IntTy->getBitWidth();
Cs.push_back(ConstantInt::get(IntTy, 0));
Cs.push_back(ConstantInt::get(IntTy, 1));
Cs.push_back(ConstantInt::get(IntTy, 42));
Cs.push_back(ConstantInt::get(
IntTy, APInt(W, 42, /*isSigned*/ false, /*implicitTrunc*/ true)));
Cs.push_back(ConstantInt::get(IntTy, APInt::getMaxValue(W)));
Cs.push_back(ConstantInt::get(IntTy, APInt::getMinValue(W)));
Cs.push_back(ConstantInt::get(IntTy, APInt::getSignedMaxValue(W)));
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/IR/ConstantRange.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1797,7 +1797,7 @@ ConstantRange ConstantRange::ctlz(bool ZeroIsPoison) const {
// Zero is either safe or not in the range. The output range is composed by
// the result of countLeadingZero of the two extremes.
return getNonEmpty(APInt(getBitWidth(), getUnsignedMax().countl_zero()),
APInt(getBitWidth(), getUnsignedMin().countl_zero() + 1));
APInt(getBitWidth(), getUnsignedMin().countl_zero()) + 1);
}

static ConstantRange getUnsignedCountTrailingZerosRange(const APInt &Lower,
Expand Down Expand Up @@ -1856,7 +1856,7 @@ ConstantRange ConstantRange::cttz(bool ZeroIsPoison) const {
}

if (isFullSet())
return getNonEmpty(Zero, APInt(BitWidth, BitWidth + 1));
return getNonEmpty(Zero, APInt(BitWidth, BitWidth) + 1);
if (!isWrappedSet())
return getUnsignedCountTrailingZerosRange(Lower, Upper);
// The range is wrapped. We decompose it into two ranges, [0, Upper) and
Expand Down Expand Up @@ -1901,7 +1901,7 @@ ConstantRange ConstantRange::ctpop() const {
unsigned BitWidth = getBitWidth();
APInt Zero = APInt::getZero(BitWidth);
if (isFullSet())
return getNonEmpty(Zero, APInt(BitWidth, BitWidth + 1));
return getNonEmpty(Zero, APInt(BitWidth, BitWidth) + 1);
if (!isWrappedSet())
return getUnsignedPopCountRange(Lower, Upper);
// The range is wrapped. We decompose it into two ranges, [0, Upper) and
Expand Down
14 changes: 9 additions & 5 deletions llvm/lib/Support/APInt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,8 @@ APInt& APInt::operator-=(uint64_t RHS) {
APInt APInt::operator*(const APInt& RHS) const {
assert(BitWidth == RHS.BitWidth && "Bit widths must be the same");
if (isSingleWord())
return APInt(BitWidth, U.VAL * RHS.U.VAL);
return APInt(BitWidth, U.VAL * RHS.U.VAL, /*isSigned*/ false,
/*implicitTrunc*/ true);

APInt Result(getMemory(getNumWords()), getBitWidth());
tcMultiply(Result.U.pVal, U.pVal, RHS.U.pVal, getNumWords());
Expand Down Expand Up @@ -455,15 +456,17 @@ APInt APInt::extractBits(unsigned numBits, unsigned bitPosition) const {
"Illegal bit extraction");

if (isSingleWord())
return APInt(numBits, U.VAL >> bitPosition);
return APInt(numBits, U.VAL >> bitPosition, /*isSigned*/ false,
/*implicitTrunc*/ true);

unsigned loBit = whichBit(bitPosition);
unsigned loWord = whichWord(bitPosition);
unsigned hiWord = whichWord(bitPosition + numBits - 1);

// Single word result extracting bits from a single word source.
if (loWord == hiWord)
return APInt(numBits, U.pVal[loWord] >> loBit);
return APInt(numBits, U.pVal[loWord] >> loBit, /*isSigned*/ false,
/*implicitTrunc*/ true);

// Extracting bits that start on a source word boundary can be done
// as a fast memory copy.
Expand Down Expand Up @@ -907,7 +910,8 @@ APInt APInt::trunc(unsigned width) const {
assert(width <= BitWidth && "Invalid APInt Truncate request");

if (width <= APINT_BITS_PER_WORD)
return APInt(width, getRawData()[0]);
return APInt(width, getRawData()[0], /*isSigned*/ false,
/*implicitTrunc*/ true);

if (width == BitWidth)
return *this;
Expand Down Expand Up @@ -955,7 +959,7 @@ APInt APInt::sext(unsigned Width) const {
assert(Width >= BitWidth && "Invalid APInt SignExtend request");

if (Width <= APINT_BITS_PER_WORD)
return APInt(Width, SignExtend64(U.VAL, BitWidth));
return APInt(Width, SignExtend64(U.VAL, BitWidth), /*isSigned*/ true);

if (Width == BitWidth)
return *this;
Expand Down
13 changes: 8 additions & 5 deletions llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1669,17 +1669,20 @@ Constant *DevirtModule::importConstant(VTableSlot Slot, ArrayRef<uint64_t> Args,
if (GV->hasMetadata(LLVMContext::MD_absolute_symbol))
return C;

auto SetAbsRange = [&](uint64_t Min, uint64_t Max) {
auto SetAbsRange = [&](const APInt &Min, const APInt &Max) {
auto *MinC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Min));
auto *MaxC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Max));
GV->setMetadata(LLVMContext::MD_absolute_symbol,
MDNode::get(M.getContext(), {MinC, MaxC}));
};
unsigned AbsWidth = IntTy->getBitWidth();
if (AbsWidth == IntPtrTy->getBitWidth())
SetAbsRange(~0ull, ~0ull); // Full set.
unsigned IntPtrWidth = IntPtrTy->getBitWidth();
if (AbsWidth == IntPtrWidth)
// Full set.
SetAbsRange(APInt::getAllOnes(IntPtrWidth), APInt::getAllOnes(IntPtrWidth));
else
SetAbsRange(0, 1ull << AbsWidth);
SetAbsRange(APInt::getZero(IntPtrWidth),
APInt::getOneBitSet(IntPtrWidth, AbsWidth));
return C;
}

Expand Down Expand Up @@ -1884,7 +1887,7 @@ bool DevirtModule::tryVirtualConstProp(
}

// Rewrite each call to a load from OffsetByte/OffsetBit.
Constant *ByteConst = ConstantInt::get(Int32Ty, OffsetByte);
Constant *ByteConst = ConstantInt::get(Int32Ty, OffsetByte, true);
Constant *BitConst = ConstantInt::get(Int8Ty, 1ULL << OffsetBit);
applyVirtualConstProp(CSByConstantArg.second,
TargetsForSlot[0].Fn->getName(), ByteConst, BitConst);
Expand Down
6 changes: 2 additions & 4 deletions llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,13 +260,11 @@ Instruction *InstCombinerImpl::SimplifyAnyMemSet(AnyMemSetInst *MI) {

// memset(s,c,n) -> store s, c (for n=1,2,4,8)
if (Len <= 8 && isPowerOf2_32((uint32_t)Len)) {
Type *ITy = IntegerType::get(MI->getContext(), Len*8); // n=1 -> i8.

Value *Dest = MI->getDest();

// Extract the fill value and store.
const uint64_t Fill = FillC->getZExtValue()*0x0101010101010101ULL;
Constant *FillVal = ConstantInt::get(ITy, Fill);
Constant *FillVal = ConstantInt::get(
MI->getContext(), APInt::getSplat(Len * 8, FillC->getValue()));
StoreInst *S = Builder.CreateStore(FillVal, Dest, MI->isVolatile());
S->copyMetadata(*MI, LLVMContext::MD_DIAssignID);
auto replaceOpForAssignmentMarkers = [FillC, FillVal](auto *DbgAssign) {
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal(
DL.getTypeAllocSize(Init->getType()->getArrayElementType());
auto MaskIdx = [&](Value *Idx) {
if (!GEP->isInBounds() && llvm::countr_zero(ElementSize) != 0) {
Value *Mask = ConstantInt::get(Idx->getType(), -1);
Value *Mask = Constant::getAllOnesValue(Idx->getType());
Mask = Builder.CreateLShr(Mask, llvm::countr_zero(ElementSize));
Idx = Builder.CreateAnd(Idx, Mask);
}
Expand Down
8 changes: 4 additions & 4 deletions llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -675,11 +675,11 @@ static Value *foldSelectICmpLshrAshr(const ICmpInst *IC, Value *TrueVal,
Value *X, *Y;
unsigned Bitwidth = CmpRHS->getType()->getScalarSizeInBits();
if ((Pred != ICmpInst::ICMP_SGT ||
!match(CmpRHS,
m_SpecificInt_ICMP(ICmpInst::ICMP_SGE, APInt(Bitwidth, -1)))) &&
!match(CmpRHS, m_SpecificInt_ICMP(ICmpInst::ICMP_SGE,
APInt::getAllOnes(Bitwidth)))) &&
(Pred != ICmpInst::ICMP_SLT ||
!match(CmpRHS,
m_SpecificInt_ICMP(ICmpInst::ICMP_SGE, APInt(Bitwidth, 0)))))
!match(CmpRHS, m_SpecificInt_ICMP(ICmpInst::ICMP_SGE,
APInt::getZero(Bitwidth)))))
return nullptr;

// Canonicalize so that ashr is in FalseVal.
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,7 @@ void ConstraintInfo::transferToOtherSystem(
addFact(CmpInst::ICMP_ULT, A, B, NumIn, NumOut, DFSInStack);
break;
case CmpInst::ICMP_SGT: {
if (doesHold(CmpInst::ICMP_SGE, B, ConstantInt::get(B->getType(), -1)))
if (doesHold(CmpInst::ICMP_SGE, B, Constant::getAllOnesValue(B->getType())))
addFact(CmpInst::ICMP_UGE, A, ConstantInt::get(B->getType(), 0), NumIn,
NumOut, DFSInStack);
if (IsKnownNonNegative(B))
Expand Down
14 changes: 7 additions & 7 deletions llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,19 +357,19 @@ bool IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) {
// Insert new integer induction variable.
PHINode *NewPHI =
PHINode::Create(Int32Ty, 2, PN->getName() + ".int", PN->getIterator());
NewPHI->addIncoming(ConstantInt::get(Int32Ty, InitValue),
NewPHI->addIncoming(ConstantInt::getSigned(Int32Ty, InitValue),
PN->getIncomingBlock(IncomingEdge));
NewPHI->setDebugLoc(PN->getDebugLoc());

Instruction *NewAdd =
BinaryOperator::CreateAdd(NewPHI, ConstantInt::get(Int32Ty, IncValue),
Incr->getName() + ".int", Incr->getIterator());
Instruction *NewAdd = BinaryOperator::CreateAdd(
NewPHI, ConstantInt::getSigned(Int32Ty, IncValue),
Incr->getName() + ".int", Incr->getIterator());
NewAdd->setDebugLoc(Incr->getDebugLoc());
NewPHI->addIncoming(NewAdd, PN->getIncomingBlock(BackEdge));

ICmpInst *NewCompare =
new ICmpInst(TheBr->getIterator(), NewPred, NewAdd,
ConstantInt::get(Int32Ty, ExitValue), Compare->getName());
ICmpInst *NewCompare = new ICmpInst(
TheBr->getIterator(), NewPred, NewAdd,
ConstantInt::getSigned(Int32Ty, ExitValue), Compare->getName());
NewCompare->setDebugLoc(Compare->getDebugLoc());

// In the following deletions, PN may become dead and may be deleted.
Expand Down
Loading

0 comments on commit f5cc0bf

Please sign in to comment.