Skip to content

[NFC][LLVM] Refactor IRBuilder::Create{VScale,ElementCount,TypeSize}. #142803

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions clang/lib/CodeGen/TargetBuiltins/ARM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4792,12 +4792,7 @@ Value *CodeGenFunction::EmitAArch64SVEBuiltinExpr(unsigned BuiltinID,
case SVE::BI__builtin_sve_svlen_u32:
case SVE::BI__builtin_sve_svlen_u64: {
SVETypeFlags TF(Builtin->TypeModifier);
auto VTy = cast<llvm::VectorType>(getSVEType(TF));
auto *NumEls =
llvm::ConstantInt::get(Ty, VTy->getElementCount().getKnownMinValue());

Function *F = CGM.getIntrinsic(Intrinsic::vscale, Ty);
return Builder.CreateMul(NumEls, Builder.CreateCall(F));
return Builder.CreateElementCount(Ty, getSVEType(TF)->getElementCount());
}

case SVE::BI__builtin_sve_svtbl2_u8:
Expand All @@ -4813,8 +4808,7 @@ Value *CodeGenFunction::EmitAArch64SVEBuiltinExpr(unsigned BuiltinID,
case SVE::BI__builtin_sve_svtbl2_f32:
case SVE::BI__builtin_sve_svtbl2_f64: {
SVETypeFlags TF(Builtin->TypeModifier);
auto VTy = cast<llvm::ScalableVectorType>(getSVEType(TF));
Function *F = CGM.getIntrinsic(Intrinsic::aarch64_sve_tbl2, VTy);
Function *F = CGM.getIntrinsic(Intrinsic::aarch64_sve_tbl2, getSVEType(TF));
return Builder.CreateCall(F, Ops);
}

Expand Down
11 changes: 6 additions & 5 deletions llvm/include/llvm/IR/IRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -945,17 +945,18 @@ class IRBuilderBase {
LLVM_ABI CallInst *CreateGCGetPointerOffset(Value *DerivedPtr,
const Twine &Name = "");

/// Create a call to llvm.vscale, multiplied by \p Scaling. The type of VScale
/// will be the same type as that of \p Scaling.
LLVM_ABI Value *CreateVScale(Constant *Scaling, const Twine &Name = "");
/// Create a call to llvm.vscale.<Ty>().
LLVM_ABI Value *CreateVScale(Type *Ty, const Twine &Name = "") {
return CreateIntrinsic(Intrinsic::vscale, {Ty}, {}, {}, Name);
}

/// Create an expression which evaluates to the number of elements in \p EC
/// at runtime.
LLVM_ABI Value *CreateElementCount(Type *DstType, ElementCount EC);
LLVM_ABI Value *CreateElementCount(Type *Ty, ElementCount EC);

/// Create an expression which evaluates to the number of units in \p Size
/// at runtime. This works for both units of bits and bytes.
LLVM_ABI Value *CreateTypeSize(Type *DstType, TypeSize Size);
LLVM_ABI Value *CreateTypeSize(Type *Ty, TypeSize Size);

/// Creates a vector of type \p DstType with the linear sequence <0, 1, ...>
LLVM_ABI Value *CreateStepVector(Type *DstType, const Twine &Name = "");
Expand Down
3 changes: 1 addition & 2 deletions llvm/lib/CodeGen/ExpandVectorPredication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -507,8 +507,7 @@ bool CachingVPExpander::discardEVLParameter(VPIntrinsic &VPI) {
// TODO add caching
IRBuilder<> Builder(VPI.getParent(), VPI.getIterator());
Value *FactorConst = Builder.getInt32(StaticElemCount.getKnownMinValue());
Value *VScale = Builder.CreateIntrinsic(Intrinsic::vscale, Int32Ty, {},
/*FMFSource=*/nullptr, "vscale");
Value *VScale = Builder.CreateVScale(Int32Ty, "vscale");
MaxEVL = Builder.CreateMul(VScale, FactorConst, "scalable_size",
/*NUW*/ true, /*NSW*/ false);
} else {
Expand Down
29 changes: 16 additions & 13 deletions llvm/lib/IR/IRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,23 +120,26 @@ IRBuilderBase::createCallHelper(Function *Callee, ArrayRef<Value *> Ops,
return CI;
}

Value *IRBuilderBase::CreateVScale(Constant *Scaling, const Twine &Name) {
assert(isa<ConstantInt>(Scaling) && "Expected constant integer");
if (cast<ConstantInt>(Scaling)->isZero())
return Scaling;
CallInst *CI =
CreateIntrinsic(Intrinsic::vscale, {Scaling->getType()}, {}, {}, Name);
return cast<ConstantInt>(Scaling)->isOne() ? CI : CreateMul(CI, Scaling);
/// Emit `vscale * Scale` as a value of type \p Ty through builder \p B.
///
/// The llvm.vscale call is always created. The multiply is only created when
/// \p Scale differs from 1; otherwise the bare vscale value is returned.
static Value *CreateVScaleMultiple(IRBuilderBase &B, Type *Ty, uint64_t Scale) {
  Value *Result = B.CreateVScale(Ty);
  if (Scale != 1)
    Result = B.CreateMul(Result, ConstantInt::get(Ty, Scale));
  return Result;
}

Value *IRBuilderBase::CreateElementCount(Type *DstType, ElementCount EC) {
Constant *MinEC = ConstantInt::get(DstType, EC.getKnownMinValue());
return EC.isScalable() ? CreateVScale(MinEC) : MinEC;
/// Materialise the element count \p EC as a value of type \p Ty.
///
/// Fixed counts, and counts whose known minimum is zero, are returned as a
/// plain constant. Every other count is emitted as a multiple of vscale via
/// CreateVScaleMultiple.
Value *IRBuilderBase::CreateElementCount(Type *Ty, ElementCount EC) {
  const auto MinCount = EC.getKnownMinValue();
  const bool IsPlainConstant = EC.isFixed() || EC.isZero();
  return IsPlainConstant ? ConstantInt::get(Ty, MinCount)
                         : CreateVScaleMultiple(*this, Ty, MinCount);
}

Value *IRBuilderBase::CreateTypeSize(Type *DstType, TypeSize Size) {
Constant *MinSize = ConstantInt::get(DstType, Size.getKnownMinValue());
return Size.isScalable() ? CreateVScale(MinSize) : MinSize;
Value *IRBuilderBase::CreateTypeSize(Type *Ty, TypeSize Size) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just a suggestion, so please feel free to ignore it, but given how similar the two functions are, would it be useful to write a common helper such as

  Value *CreateVScaleMultiple(Type *Ty, uint64_t Scale) {
    Value *VScale = CreateVScale(Ty);
    if (Scale == 1)
      return VScale;

    return CreateMul(VScale, ConstantInt::get(Ty, Scale));
  }

and reuse it in both CreateTypeSize and CreateElementCount?

if (Size.isFixed() || Size.isZero())
return ConstantInt::get(Ty, Size.getKnownMinValue());

return CreateVScaleMultiple(*this, Ty, Size.getKnownMinValue());
}

Value *IRBuilderBase::CreateStepVector(Type *DstType, const Twine &Name) {
Expand Down
8 changes: 4 additions & 4 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2051,10 +2051,10 @@ instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
const auto Pattern = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue();

if (Pattern == AArch64SVEPredPattern::all) {
Constant *StepVal = ConstantInt::get(II.getType(), NumElts);
auto *VScale = IC.Builder.CreateVScale(StepVal);
VScale->takeName(&II);
return IC.replaceInstUsesWith(II, VScale);
Value *Cnt = IC.Builder.CreateElementCount(
II.getType(), ElementCount::getScalable(NumElts));
Cnt->takeName(&II);
return IC.replaceInstUsesWith(II, Cnt);
}

unsigned MinNumElts = getNumElementsFromSVEPredPattern(Pattern);
Expand Down
24 changes: 8 additions & 16 deletions llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -935,12 +935,9 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
Trunc.getFunction()->hasFnAttribute(Attribute::VScaleRange)) {
Attribute Attr =
Trunc.getFunction()->getFnAttribute(Attribute::VScaleRange);
if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) {
if (Log2_32(*MaxVScale) < DestWidth) {
Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1));
return replaceInstUsesWith(Trunc, VScale);
}
}
if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax())
if (Log2_32(*MaxVScale) < DestWidth)
return replaceInstUsesWith(Trunc, Builder.CreateVScale(DestTy));
}
}

Expand Down Expand Up @@ -1314,10 +1311,8 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) {
Zext.getFunction()->getFnAttribute(Attribute::VScaleRange);
if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) {
unsigned TypeWidth = Src->getType()->getScalarSizeInBits();
if (Log2_32(*MaxVScale) < TypeWidth) {
Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1));
return replaceInstUsesWith(Zext, VScale);
}
if (Log2_32(*MaxVScale) < TypeWidth)
return replaceInstUsesWith(Zext, Builder.CreateVScale(DestTy));
}
}
}
Expand Down Expand Up @@ -1604,12 +1599,9 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) {
Sext.getFunction()->hasFnAttribute(Attribute::VScaleRange)) {
Attribute Attr =
Sext.getFunction()->getFnAttribute(Attribute::VScaleRange);
if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) {
if (Log2_32(*MaxVScale) < (SrcBitSize - 1)) {
Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1));
return replaceInstUsesWith(Sext, VScale);
}
}
if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax())
if (Log2_32(*MaxVScale) < (SrcBitSize - 1))
return replaceInstUsesWith(Sext, Builder.CreateVScale(DestTy));
}
}

Expand Down
16 changes: 3 additions & 13 deletions llvm/lib/Transforms/Utils/LowerVectorIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,19 @@ bool llvm::lowerUnaryVectorIntrinsicAsLoop(Module &M, CallInst *CI) {
BasicBlock *PostLoopBB = nullptr;
Function *ParentFunc = PreLoopBB->getParent();
LLVMContext &Ctx = PreLoopBB->getContext();
Type *Int64Ty = IntegerType::get(Ctx, 64);

PostLoopBB = PreLoopBB->splitBasicBlock(CI);
BasicBlock *LoopBB = BasicBlock::Create(Ctx, "", ParentFunc, PostLoopBB);
PreLoopBB->getTerminator()->setSuccessor(0, LoopBB);

// Loop preheader
IRBuilder<> PreLoopBuilder(PreLoopBB->getTerminator());
Value *LoopEnd = nullptr;
if (auto *ScalableVecTy = dyn_cast<ScalableVectorType>(VecTy)) {
Value *VScale = PreLoopBuilder.CreateVScale(
ConstantInt::get(PreLoopBuilder.getInt64Ty(), 1));
Value *N = ConstantInt::get(PreLoopBuilder.getInt64Ty(),
ScalableVecTy->getMinNumElements());
LoopEnd = PreLoopBuilder.CreateMul(VScale, N);
} else {
FixedVectorType *FixedVecTy = cast<FixedVectorType>(VecTy);
LoopEnd = ConstantInt::get(PreLoopBuilder.getInt64Ty(),
FixedVecTy->getNumElements());
}
Value *LoopEnd =
PreLoopBuilder.CreateElementCount(Int64Ty, VecTy->getElementCount());

// Loop body
IRBuilder<> LoopBuilder(LoopBB);
Type *Int64Ty = LoopBuilder.getInt64Ty();

PHINode *LoopIndex = LoopBuilder.CreatePHI(Int64Ty, 2);
LoopIndex->addIncoming(ConstantInt::get(Int64Ty, 0U), PreLoopBB);
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1440,7 +1440,7 @@ Value *SCEVExpander::visitSequentialUMinExpr(const SCEVSequentialUMinExpr *S) {
}

Value *SCEVExpander::visitVScale(const SCEVVScale *S) {
return Builder.CreateVScale(ConstantInt::get(S->getType(), 1));
return Builder.CreateVScale(S->getType());
}

Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty,
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ Value *LoopIdiomVectorize::createMaskedFindMismatch(
Value *InitialPred = Builder.CreateIntrinsic(
Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});

Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {});
Value *VecLen = Builder.CreateVScale(I64Type);
VecLen =
Builder.CreateMul(VecLen, ConstantInt::get(I64Type, ByteCompareVF), "",
/*HasNUW=*/true, /*HasNSW=*/true);
Expand Down
8 changes: 0 additions & 8 deletions llvm/unittests/IR/IRBuilderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,14 +212,6 @@ TEST_F(IRBuilderTest, IntrinsicsWithScalableVectors) {
EXPECT_EQ(FTy->getParamType(i), Args[i]->getType());
}

TEST_F(IRBuilderTest, CreateVScale) {
IRBuilder<> Builder(BB);

Constant *Zero = Builder.getInt32(0);
Value *VScale = Builder.CreateVScale(Zero);
EXPECT_TRUE(isa<ConstantInt>(VScale) && cast<ConstantInt>(VScale)->isZero());
}

TEST_F(IRBuilderTest, CreateStepVector) {
IRBuilder<> Builder(BB);

Expand Down