Skip to content

[NFC][LLVM] Refactor IRBuilder::Create{VScale,ElementCount,TypeSize}. #142803

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

paulwalker-arm
Copy link
Collaborator

CreateVScale took a scaling parameter that had a single use outside of IRBuilder with all other callers having to create a redundant ConstantInt. To work round this some code perferred to use CreateIntrinsic directly.

This patch simplifies CreateVScale to return a call to the llvm.vscale() intrinsic and nothing more. As well as simplifying the existing call sites I've also migrated the uses of CreateIntrinsic.

Whilst IRBuilder used CreateVScale's scaling parameter as part of the implementations of CreateElementCount and CreateTypeSize, I have follow-on work to switch them to the NUW varaiety and thus they would stop using CreateVScale's scaling as well. To prepare for this I have moved the multiplication and constant folding into the implementations of CreateElementCount and CreateTypeSize.

As a final step I have replaced some callers of CreateVScale with CreateElementCount where it's clear from the code they wanted the latter.

CreateVScale took a scaling parameter that had a single use
outside of IRBuilder with all other callers having to create a
redundant ConstantInt. To work round this some code perferred to use
CreateIntrinsic directly.

This patch simplifies CreateVScale to only return a call to the
llvm.vscale() intrinsic and nothing more. As well as simplifying the
existing call sites I've also ported the uses of CreateIntrinsic.

Whilst IRBuilder used CreateVScale's scaling parameter as part of
the implementations of CreateElementCount and CreateTypeSize, I have
follow-on work to switch them to the NUW varaiety and thus they would
stop using CreateVScale's scaling as well. To prepare for this I have
moved the multiplication and constant folding into the implementations
of CreateElementCount and CreateTypeSize.

As a final step I have replaced some callers of CreateVScale with
CreateElementCount where it's clear from the code they wanted the
latter.
@paulwalker-arm paulwalker-arm requested a review from nikic as a code owner June 4, 2025 16:10
@llvmbot llvmbot added clang Clang issues not falling into any other category backend:AArch64 clang:codegen IR generation bugs: mangling, exceptions, etc. llvm:codegen vectorizers llvm:instcombine llvm:ir llvm:transforms labels Jun 4, 2025
@llvmbot
Copy link
Member

llvmbot commented Jun 4, 2025

@llvm/pr-subscribers-vectorizers
@llvm/pr-subscribers-llvm-transforms
@llvm/pr-subscribers-clang-codegen

@llvm/pr-subscribers-backend-aarch64

Author: Paul Walker (paulwalker-arm)

Changes

CreateVScale took a scaling parameter that had a single use outside of IRBuilder with all other callers having to create a redundant ConstantInt. To work round this some code perferred to use CreateIntrinsic directly.

This patch simplifies CreateVScale to return a call to the llvm.vscale() intrinsic and nothing more. As well as simplifying the existing call sites I've also migrated the uses of CreateIntrinsic.

Whilst IRBuilder used CreateVScale's scaling parameter as part of the implementations of CreateElementCount and CreateTypeSize, I have follow-on work to switch them to the NUW varaiety and thus they would stop using CreateVScale's scaling as well. To prepare for this I have moved the multiplication and constant folding into the implementations of CreateElementCount and CreateTypeSize.

As a final step I have replaced some callers of CreateVScale with CreateElementCount where it's clear from the code they wanted the latter.


Full diff: https://github.com/llvm/llvm-project/pull/142803.diff

10 Files Affected:

  • (modified) clang/lib/CodeGen/TargetBuiltins/ARM.cpp (+1-5)
  • (modified) llvm/include/llvm/IR/IRBuilder.h (+4-5)
  • (modified) llvm/lib/CodeGen/ExpandVectorPredication.cpp (+1-2)
  • (modified) llvm/lib/IR/IRBuilder.cpp (+20-13)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+4-4)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp (+8-16)
  • (modified) llvm/lib/Transforms/Utils/LowerVectorIntrinsics.cpp (+3-13)
  • (modified) llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp (+1-1)
  • (modified) llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp (+1-1)
  • (modified) llvm/unittests/IR/IRBuilderTest.cpp (-8)
diff --git a/clang/lib/CodeGen/TargetBuiltins/ARM.cpp b/clang/lib/CodeGen/TargetBuiltins/ARM.cpp
index 1cf8f6819b75a..9c77346389e04 100644
--- a/clang/lib/CodeGen/TargetBuiltins/ARM.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/ARM.cpp
@@ -4793,11 +4793,7 @@ Value *CodeGenFunction::EmitAArch64SVEBuiltinExpr(unsigned BuiltinID,
   case SVE::BI__builtin_sve_svlen_u64: {
     SVETypeFlags TF(Builtin->TypeModifier);
     auto VTy = cast<llvm::VectorType>(getSVEType(TF));
-    auto *NumEls =
-        llvm::ConstantInt::get(Ty, VTy->getElementCount().getKnownMinValue());
-
-    Function *F = CGM.getIntrinsic(Intrinsic::vscale, Ty);
-    return Builder.CreateMul(NumEls, Builder.CreateCall(F));
+    return Builder.CreateElementCount(Ty, VTy->getElementCount());
   }
 
   case SVE::BI__builtin_sve_svtbl2_u8:
diff --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h
index 0db5179c7a3e4..8ed10cb803a9c 100644
--- a/llvm/include/llvm/IR/IRBuilder.h
+++ b/llvm/include/llvm/IR/IRBuilder.h
@@ -945,17 +945,16 @@ class IRBuilderBase {
   LLVM_ABI CallInst *CreateGCGetPointerOffset(Value *DerivedPtr,
                                               const Twine &Name = "");
 
-  /// Create a call to llvm.vscale, multiplied by \p Scaling. The type of VScale
-  /// will be the same type as that of \p Scaling.
-  LLVM_ABI Value *CreateVScale(Constant *Scaling, const Twine &Name = "");
+  /// Create a call to llvm.vscale.<Ty>().
+  LLVM_ABI Value *CreateVScale(Type *Ty, const Twine &Name = "");
 
   /// Create an expression which evaluates to the number of elements in \p EC
   /// at runtime.
-  LLVM_ABI Value *CreateElementCount(Type *DstType, ElementCount EC);
+  LLVM_ABI Value *CreateElementCount(Type *Ty, ElementCount EC);
 
   /// Create an expression which evaluates to the number of units in \p Size
   /// at runtime.  This works for both units of bits and bytes.
-  LLVM_ABI Value *CreateTypeSize(Type *DstType, TypeSize Size);
+  LLVM_ABI Value *CreateTypeSize(Type *Ty, TypeSize Size);
 
   /// Creates a vector of type \p DstType with the linear sequence <0, 1, ...>
   LLVM_ABI Value *CreateStepVector(Type *DstType, const Twine &Name = "");
diff --git a/llvm/lib/CodeGen/ExpandVectorPredication.cpp b/llvm/lib/CodeGen/ExpandVectorPredication.cpp
index 1bb0763fcf57b..d8e3f5fbb31de 100644
--- a/llvm/lib/CodeGen/ExpandVectorPredication.cpp
+++ b/llvm/lib/CodeGen/ExpandVectorPredication.cpp
@@ -507,8 +507,7 @@ bool CachingVPExpander::discardEVLParameter(VPIntrinsic &VPI) {
     // TODO add caching
     IRBuilder<> Builder(VPI.getParent(), VPI.getIterator());
     Value *FactorConst = Builder.getInt32(StaticElemCount.getKnownMinValue());
-    Value *VScale = Builder.CreateIntrinsic(Intrinsic::vscale, Int32Ty, {},
-                                            /*FMFSource=*/nullptr, "vscale");
+    Value *VScale = Builder.CreateVScale(Int32Ty, "vscale");
     MaxEVL = Builder.CreateMul(VScale, FactorConst, "scalable_size",
                                /*NUW*/ true, /*NSW*/ false);
   } else {
diff --git a/llvm/lib/IR/IRBuilder.cpp b/llvm/lib/IR/IRBuilder.cpp
index 580b0af709337..868aa7a2cb799 100644
--- a/llvm/lib/IR/IRBuilder.cpp
+++ b/llvm/lib/IR/IRBuilder.cpp
@@ -120,23 +120,30 @@ IRBuilderBase::createCallHelper(Function *Callee, ArrayRef<Value *> Ops,
   return CI;
 }
 
-Value *IRBuilderBase::CreateVScale(Constant *Scaling, const Twine &Name) {
-  assert(isa<ConstantInt>(Scaling) && "Expected constant integer");
-  if (cast<ConstantInt>(Scaling)->isZero())
-    return Scaling;
-  CallInst *CI =
-      CreateIntrinsic(Intrinsic::vscale, {Scaling->getType()}, {}, {}, Name);
-  return cast<ConstantInt>(Scaling)->isOne() ? CI : CreateMul(CI, Scaling);
+Value *IRBuilderBase::CreateVScale(Type *Ty, const Twine &Name) {
+  return CreateIntrinsic(Intrinsic::vscale, {Ty}, {}, {}, Name);
 }
 
-Value *IRBuilderBase::CreateElementCount(Type *DstType, ElementCount EC) {
-  Constant *MinEC = ConstantInt::get(DstType, EC.getKnownMinValue());
-  return EC.isScalable() ? CreateVScale(MinEC) : MinEC;
+Value *IRBuilderBase::CreateElementCount(Type *Ty, ElementCount EC) {
+  if (EC.isFixed() || EC.isZero())
+    return ConstantInt::get(Ty, EC.getKnownMinValue());
+
+  Value *VScale = CreateVScale(Ty);
+  if (EC.getKnownMinValue() == 1)
+    return VScale;
+
+  return CreateMul(VScale, ConstantInt::get(Ty, EC.getKnownMinValue()));
 }
 
-Value *IRBuilderBase::CreateTypeSize(Type *DstType, TypeSize Size) {
-  Constant *MinSize = ConstantInt::get(DstType, Size.getKnownMinValue());
-  return Size.isScalable() ? CreateVScale(MinSize) : MinSize;
+Value *IRBuilderBase::CreateTypeSize(Type *Ty, TypeSize Size) {
+  if (Size.isFixed() || Size.isZero())
+    return ConstantInt::get(Ty, Size.getKnownMinValue());
+
+  Value *VScale = CreateVScale(Ty);
+  if (Size.getKnownMinValue() == 1)
+    return VScale;
+
+  return CreateMul(VScale, ConstantInt::get(Ty, Size.getKnownMinValue()));
 }
 
 Value *IRBuilderBase::CreateStepVector(Type *DstType, const Twine &Name) {
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 68aec80f07e1d..f6167e4333327 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -2051,10 +2051,10 @@ instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
   const auto Pattern = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue();
 
   if (Pattern == AArch64SVEPredPattern::all) {
-    Constant *StepVal = ConstantInt::get(II.getType(), NumElts);
-    auto *VScale = IC.Builder.CreateVScale(StepVal);
-    VScale->takeName(&II);
-    return IC.replaceInstUsesWith(II, VScale);
+    Value *Cnt = IC.Builder.CreateElementCount(
+        II.getType(), ElementCount::getScalable(NumElts));
+    Cnt->takeName(&II);
+    return IC.replaceInstUsesWith(II, Cnt);
   }
 
   unsigned MinNumElts = getNumElementsFromSVEPredPattern(Pattern);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index d234a0566e191..2db79228bf0e6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -935,12 +935,9 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
         Trunc.getFunction()->hasFnAttribute(Attribute::VScaleRange)) {
       Attribute Attr =
           Trunc.getFunction()->getFnAttribute(Attribute::VScaleRange);
-      if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) {
-        if (Log2_32(*MaxVScale) < DestWidth) {
-          Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1));
-          return replaceInstUsesWith(Trunc, VScale);
-        }
-      }
+      if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax())
+        if (Log2_32(*MaxVScale) < DestWidth)
+          return replaceInstUsesWith(Trunc, Builder.CreateVScale(DestTy));
     }
   }
 
@@ -1314,10 +1311,8 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) {
           Zext.getFunction()->getFnAttribute(Attribute::VScaleRange);
       if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) {
         unsigned TypeWidth = Src->getType()->getScalarSizeInBits();
-        if (Log2_32(*MaxVScale) < TypeWidth) {
-          Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1));
-          return replaceInstUsesWith(Zext, VScale);
-        }
+        if (Log2_32(*MaxVScale) < TypeWidth)
+          return replaceInstUsesWith(Zext, Builder.CreateVScale(DestTy));
       }
     }
   }
@@ -1604,12 +1599,9 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) {
         Sext.getFunction()->hasFnAttribute(Attribute::VScaleRange)) {
       Attribute Attr =
           Sext.getFunction()->getFnAttribute(Attribute::VScaleRange);
-      if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) {
-        if (Log2_32(*MaxVScale) < (SrcBitSize - 1)) {
-          Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1));
-          return replaceInstUsesWith(Sext, VScale);
-        }
-      }
+      if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax())
+        if (Log2_32(*MaxVScale) < (SrcBitSize - 1))
+          return replaceInstUsesWith(Sext, Builder.CreateVScale(DestTy));
     }
   }
 
diff --git a/llvm/lib/Transforms/Utils/LowerVectorIntrinsics.cpp b/llvm/lib/Transforms/Utils/LowerVectorIntrinsics.cpp
index a34ef260dc244..71c10f5b157c7 100644
--- a/llvm/lib/Transforms/Utils/LowerVectorIntrinsics.cpp
+++ b/llvm/lib/Transforms/Utils/LowerVectorIntrinsics.cpp
@@ -21,6 +21,7 @@ bool llvm::lowerUnaryVectorIntrinsicAsLoop(Module &M, CallInst *CI) {
   BasicBlock *PostLoopBB = nullptr;
   Function *ParentFunc = PreLoopBB->getParent();
   LLVMContext &Ctx = PreLoopBB->getContext();
+  Type *Int64Ty = IntegerType::get(Ctx, 64);
 
   PostLoopBB = PreLoopBB->splitBasicBlock(CI);
   BasicBlock *LoopBB = BasicBlock::Create(Ctx, "", ParentFunc, PostLoopBB);
@@ -28,22 +29,11 @@ bool llvm::lowerUnaryVectorIntrinsicAsLoop(Module &M, CallInst *CI) {
 
   // Loop preheader
   IRBuilder<> PreLoopBuilder(PreLoopBB->getTerminator());
-  Value *LoopEnd = nullptr;
-  if (auto *ScalableVecTy = dyn_cast<ScalableVectorType>(VecTy)) {
-    Value *VScale = PreLoopBuilder.CreateVScale(
-        ConstantInt::get(PreLoopBuilder.getInt64Ty(), 1));
-    Value *N = ConstantInt::get(PreLoopBuilder.getInt64Ty(),
-                                ScalableVecTy->getMinNumElements());
-    LoopEnd = PreLoopBuilder.CreateMul(VScale, N);
-  } else {
-    FixedVectorType *FixedVecTy = cast<FixedVectorType>(VecTy);
-    LoopEnd = ConstantInt::get(PreLoopBuilder.getInt64Ty(),
-                               FixedVecTy->getNumElements());
-  }
+  Value *LoopEnd =
+      PreLoopBuilder.CreateElementCount(Int64Ty, VecTy->getElementCount());
 
   // Loop body
   IRBuilder<> LoopBuilder(LoopBB);
-  Type *Int64Ty = LoopBuilder.getInt64Ty();
 
   PHINode *LoopIndex = LoopBuilder.CreatePHI(Int64Ty, 2);
   LoopIndex->addIncoming(ConstantInt::get(Int64Ty, 0U), PreLoopBB);
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index 88d2eca36ca51..70afd4133df7c 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -1440,7 +1440,7 @@ Value *SCEVExpander::visitSequentialUMinExpr(const SCEVSequentialUMinExpr *S) {
 }
 
 Value *SCEVExpander::visitVScale(const SCEVVScale *S) {
-  return Builder.CreateVScale(ConstantInt::get(S->getType(), 1));
+  return Builder.CreateVScale(S->getType());
 }
 
 Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty,
diff --git a/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
index 02ffc67c774dd..491f0b76f4ae0 100644
--- a/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
@@ -437,7 +437,7 @@ Value *LoopIdiomVectorize::createMaskedFindMismatch(
   Value *InitialPred = Builder.CreateIntrinsic(
       Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
 
-  Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {});
+  Value *VecLen = Builder.CreateVScale(I64Type);
   VecLen =
       Builder.CreateMul(VecLen, ConstantInt::get(I64Type, ByteCompareVF), "",
                         /*HasNUW=*/true, /*HasNSW=*/true);
diff --git a/llvm/unittests/IR/IRBuilderTest.cpp b/llvm/unittests/IR/IRBuilderTest.cpp
index b7eb0af728331..3a7ba924792ef 100644
--- a/llvm/unittests/IR/IRBuilderTest.cpp
+++ b/llvm/unittests/IR/IRBuilderTest.cpp
@@ -212,14 +212,6 @@ TEST_F(IRBuilderTest, IntrinsicsWithScalableVectors) {
     EXPECT_EQ(FTy->getParamType(i), Args[i]->getType());
 }
 
-TEST_F(IRBuilderTest, CreateVScale) {
-  IRBuilder<> Builder(BB);
-
-  Constant *Zero = Builder.getInt32(0);
-  Value *VScale = Builder.CreateVScale(Zero);
-  EXPECT_TRUE(isa<ConstantInt>(VScale) && cast<ConstantInt>(VScale)->isZero());
-}
-
 TEST_F(IRBuilderTest, CreateStepVector) {
   IRBuilder<> Builder(BB);
 

@llvmbot
Copy link
Member

llvmbot commented Jun 4, 2025

@llvm/pr-subscribers-clang

Author: Paul Walker (paulwalker-arm)

Changes

CreateVScale took a scaling parameter that had a single use outside of IRBuilder with all other callers having to create a redundant ConstantInt. To work round this some code perferred to use CreateIntrinsic directly.

This patch simplifies CreateVScale to return a call to the llvm.vscale() intrinsic and nothing more. As well as simplifying the existing call sites I've also migrated the uses of CreateIntrinsic.

Whilst IRBuilder used CreateVScale's scaling parameter as part of the implementations of CreateElementCount and CreateTypeSize, I have follow-on work to switch them to the NUW varaiety and thus they would stop using CreateVScale's scaling as well. To prepare for this I have moved the multiplication and constant folding into the implementations of CreateElementCount and CreateTypeSize.

As a final step I have replaced some callers of CreateVScale with CreateElementCount where it's clear from the code they wanted the latter.


Full diff: https://github.com/llvm/llvm-project/pull/142803.diff

10 Files Affected:

  • (modified) clang/lib/CodeGen/TargetBuiltins/ARM.cpp (+1-5)
  • (modified) llvm/include/llvm/IR/IRBuilder.h (+4-5)
  • (modified) llvm/lib/CodeGen/ExpandVectorPredication.cpp (+1-2)
  • (modified) llvm/lib/IR/IRBuilder.cpp (+20-13)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+4-4)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp (+8-16)
  • (modified) llvm/lib/Transforms/Utils/LowerVectorIntrinsics.cpp (+3-13)
  • (modified) llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp (+1-1)
  • (modified) llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp (+1-1)
  • (modified) llvm/unittests/IR/IRBuilderTest.cpp (-8)
diff --git a/clang/lib/CodeGen/TargetBuiltins/ARM.cpp b/clang/lib/CodeGen/TargetBuiltins/ARM.cpp
index 1cf8f6819b75a..9c77346389e04 100644
--- a/clang/lib/CodeGen/TargetBuiltins/ARM.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/ARM.cpp
@@ -4793,11 +4793,7 @@ Value *CodeGenFunction::EmitAArch64SVEBuiltinExpr(unsigned BuiltinID,
   case SVE::BI__builtin_sve_svlen_u64: {
     SVETypeFlags TF(Builtin->TypeModifier);
     auto VTy = cast<llvm::VectorType>(getSVEType(TF));
-    auto *NumEls =
-        llvm::ConstantInt::get(Ty, VTy->getElementCount().getKnownMinValue());
-
-    Function *F = CGM.getIntrinsic(Intrinsic::vscale, Ty);
-    return Builder.CreateMul(NumEls, Builder.CreateCall(F));
+    return Builder.CreateElementCount(Ty, VTy->getElementCount());
   }
 
   case SVE::BI__builtin_sve_svtbl2_u8:
diff --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h
index 0db5179c7a3e4..8ed10cb803a9c 100644
--- a/llvm/include/llvm/IR/IRBuilder.h
+++ b/llvm/include/llvm/IR/IRBuilder.h
@@ -945,17 +945,16 @@ class IRBuilderBase {
   LLVM_ABI CallInst *CreateGCGetPointerOffset(Value *DerivedPtr,
                                               const Twine &Name = "");
 
-  /// Create a call to llvm.vscale, multiplied by \p Scaling. The type of VScale
-  /// will be the same type as that of \p Scaling.
-  LLVM_ABI Value *CreateVScale(Constant *Scaling, const Twine &Name = "");
+  /// Create a call to llvm.vscale.<Ty>().
+  LLVM_ABI Value *CreateVScale(Type *Ty, const Twine &Name = "");
 
   /// Create an expression which evaluates to the number of elements in \p EC
   /// at runtime.
-  LLVM_ABI Value *CreateElementCount(Type *DstType, ElementCount EC);
+  LLVM_ABI Value *CreateElementCount(Type *Ty, ElementCount EC);
 
   /// Create an expression which evaluates to the number of units in \p Size
   /// at runtime.  This works for both units of bits and bytes.
-  LLVM_ABI Value *CreateTypeSize(Type *DstType, TypeSize Size);
+  LLVM_ABI Value *CreateTypeSize(Type *Ty, TypeSize Size);
 
   /// Creates a vector of type \p DstType with the linear sequence <0, 1, ...>
   LLVM_ABI Value *CreateStepVector(Type *DstType, const Twine &Name = "");
diff --git a/llvm/lib/CodeGen/ExpandVectorPredication.cpp b/llvm/lib/CodeGen/ExpandVectorPredication.cpp
index 1bb0763fcf57b..d8e3f5fbb31de 100644
--- a/llvm/lib/CodeGen/ExpandVectorPredication.cpp
+++ b/llvm/lib/CodeGen/ExpandVectorPredication.cpp
@@ -507,8 +507,7 @@ bool CachingVPExpander::discardEVLParameter(VPIntrinsic &VPI) {
     // TODO add caching
     IRBuilder<> Builder(VPI.getParent(), VPI.getIterator());
     Value *FactorConst = Builder.getInt32(StaticElemCount.getKnownMinValue());
-    Value *VScale = Builder.CreateIntrinsic(Intrinsic::vscale, Int32Ty, {},
-                                            /*FMFSource=*/nullptr, "vscale");
+    Value *VScale = Builder.CreateVScale(Int32Ty, "vscale");
     MaxEVL = Builder.CreateMul(VScale, FactorConst, "scalable_size",
                                /*NUW*/ true, /*NSW*/ false);
   } else {
diff --git a/llvm/lib/IR/IRBuilder.cpp b/llvm/lib/IR/IRBuilder.cpp
index 580b0af709337..868aa7a2cb799 100644
--- a/llvm/lib/IR/IRBuilder.cpp
+++ b/llvm/lib/IR/IRBuilder.cpp
@@ -120,23 +120,30 @@ IRBuilderBase::createCallHelper(Function *Callee, ArrayRef<Value *> Ops,
   return CI;
 }
 
-Value *IRBuilderBase::CreateVScale(Constant *Scaling, const Twine &Name) {
-  assert(isa<ConstantInt>(Scaling) && "Expected constant integer");
-  if (cast<ConstantInt>(Scaling)->isZero())
-    return Scaling;
-  CallInst *CI =
-      CreateIntrinsic(Intrinsic::vscale, {Scaling->getType()}, {}, {}, Name);
-  return cast<ConstantInt>(Scaling)->isOne() ? CI : CreateMul(CI, Scaling);
+Value *IRBuilderBase::CreateVScale(Type *Ty, const Twine &Name) {
+  return CreateIntrinsic(Intrinsic::vscale, {Ty}, {}, {}, Name);
 }
 
-Value *IRBuilderBase::CreateElementCount(Type *DstType, ElementCount EC) {
-  Constant *MinEC = ConstantInt::get(DstType, EC.getKnownMinValue());
-  return EC.isScalable() ? CreateVScale(MinEC) : MinEC;
+Value *IRBuilderBase::CreateElementCount(Type *Ty, ElementCount EC) {
+  if (EC.isFixed() || EC.isZero())
+    return ConstantInt::get(Ty, EC.getKnownMinValue());
+
+  Value *VScale = CreateVScale(Ty);
+  if (EC.getKnownMinValue() == 1)
+    return VScale;
+
+  return CreateMul(VScale, ConstantInt::get(Ty, EC.getKnownMinValue()));
 }
 
-Value *IRBuilderBase::CreateTypeSize(Type *DstType, TypeSize Size) {
-  Constant *MinSize = ConstantInt::get(DstType, Size.getKnownMinValue());
-  return Size.isScalable() ? CreateVScale(MinSize) : MinSize;
+Value *IRBuilderBase::CreateTypeSize(Type *Ty, TypeSize Size) {
+  if (Size.isFixed() || Size.isZero())
+    return ConstantInt::get(Ty, Size.getKnownMinValue());
+
+  Value *VScale = CreateVScale(Ty);
+  if (Size.getKnownMinValue() == 1)
+    return VScale;
+
+  return CreateMul(VScale, ConstantInt::get(Ty, Size.getKnownMinValue()));
 }
 
 Value *IRBuilderBase::CreateStepVector(Type *DstType, const Twine &Name) {
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 68aec80f07e1d..f6167e4333327 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -2051,10 +2051,10 @@ instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
   const auto Pattern = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue();
 
   if (Pattern == AArch64SVEPredPattern::all) {
-    Constant *StepVal = ConstantInt::get(II.getType(), NumElts);
-    auto *VScale = IC.Builder.CreateVScale(StepVal);
-    VScale->takeName(&II);
-    return IC.replaceInstUsesWith(II, VScale);
+    Value *Cnt = IC.Builder.CreateElementCount(
+        II.getType(), ElementCount::getScalable(NumElts));
+    Cnt->takeName(&II);
+    return IC.replaceInstUsesWith(II, Cnt);
   }
 
   unsigned MinNumElts = getNumElementsFromSVEPredPattern(Pattern);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index d234a0566e191..2db79228bf0e6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -935,12 +935,9 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
         Trunc.getFunction()->hasFnAttribute(Attribute::VScaleRange)) {
       Attribute Attr =
           Trunc.getFunction()->getFnAttribute(Attribute::VScaleRange);
-      if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) {
-        if (Log2_32(*MaxVScale) < DestWidth) {
-          Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1));
-          return replaceInstUsesWith(Trunc, VScale);
-        }
-      }
+      if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax())
+        if (Log2_32(*MaxVScale) < DestWidth)
+          return replaceInstUsesWith(Trunc, Builder.CreateVScale(DestTy));
     }
   }
 
@@ -1314,10 +1311,8 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) {
           Zext.getFunction()->getFnAttribute(Attribute::VScaleRange);
       if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) {
         unsigned TypeWidth = Src->getType()->getScalarSizeInBits();
-        if (Log2_32(*MaxVScale) < TypeWidth) {
-          Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1));
-          return replaceInstUsesWith(Zext, VScale);
-        }
+        if (Log2_32(*MaxVScale) < TypeWidth)
+          return replaceInstUsesWith(Zext, Builder.CreateVScale(DestTy));
       }
     }
   }
@@ -1604,12 +1599,9 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) {
         Sext.getFunction()->hasFnAttribute(Attribute::VScaleRange)) {
       Attribute Attr =
           Sext.getFunction()->getFnAttribute(Attribute::VScaleRange);
-      if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) {
-        if (Log2_32(*MaxVScale) < (SrcBitSize - 1)) {
-          Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1));
-          return replaceInstUsesWith(Sext, VScale);
-        }
-      }
+      if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax())
+        if (Log2_32(*MaxVScale) < (SrcBitSize - 1))
+          return replaceInstUsesWith(Sext, Builder.CreateVScale(DestTy));
     }
   }
 
diff --git a/llvm/lib/Transforms/Utils/LowerVectorIntrinsics.cpp b/llvm/lib/Transforms/Utils/LowerVectorIntrinsics.cpp
index a34ef260dc244..71c10f5b157c7 100644
--- a/llvm/lib/Transforms/Utils/LowerVectorIntrinsics.cpp
+++ b/llvm/lib/Transforms/Utils/LowerVectorIntrinsics.cpp
@@ -21,6 +21,7 @@ bool llvm::lowerUnaryVectorIntrinsicAsLoop(Module &M, CallInst *CI) {
   BasicBlock *PostLoopBB = nullptr;
   Function *ParentFunc = PreLoopBB->getParent();
   LLVMContext &Ctx = PreLoopBB->getContext();
+  Type *Int64Ty = IntegerType::get(Ctx, 64);
 
   PostLoopBB = PreLoopBB->splitBasicBlock(CI);
   BasicBlock *LoopBB = BasicBlock::Create(Ctx, "", ParentFunc, PostLoopBB);
@@ -28,22 +29,11 @@ bool llvm::lowerUnaryVectorIntrinsicAsLoop(Module &M, CallInst *CI) {
 
   // Loop preheader
   IRBuilder<> PreLoopBuilder(PreLoopBB->getTerminator());
-  Value *LoopEnd = nullptr;
-  if (auto *ScalableVecTy = dyn_cast<ScalableVectorType>(VecTy)) {
-    Value *VScale = PreLoopBuilder.CreateVScale(
-        ConstantInt::get(PreLoopBuilder.getInt64Ty(), 1));
-    Value *N = ConstantInt::get(PreLoopBuilder.getInt64Ty(),
-                                ScalableVecTy->getMinNumElements());
-    LoopEnd = PreLoopBuilder.CreateMul(VScale, N);
-  } else {
-    FixedVectorType *FixedVecTy = cast<FixedVectorType>(VecTy);
-    LoopEnd = ConstantInt::get(PreLoopBuilder.getInt64Ty(),
-                               FixedVecTy->getNumElements());
-  }
+  Value *LoopEnd =
+      PreLoopBuilder.CreateElementCount(Int64Ty, VecTy->getElementCount());
 
   // Loop body
   IRBuilder<> LoopBuilder(LoopBB);
-  Type *Int64Ty = LoopBuilder.getInt64Ty();
 
   PHINode *LoopIndex = LoopBuilder.CreatePHI(Int64Ty, 2);
   LoopIndex->addIncoming(ConstantInt::get(Int64Ty, 0U), PreLoopBB);
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index 88d2eca36ca51..70afd4133df7c 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -1440,7 +1440,7 @@ Value *SCEVExpander::visitSequentialUMinExpr(const SCEVSequentialUMinExpr *S) {
 }
 
 Value *SCEVExpander::visitVScale(const SCEVVScale *S) {
-  return Builder.CreateVScale(ConstantInt::get(S->getType(), 1));
+  return Builder.CreateVScale(S->getType());
 }
 
 Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty,
diff --git a/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
index 02ffc67c774dd..491f0b76f4ae0 100644
--- a/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
@@ -437,7 +437,7 @@ Value *LoopIdiomVectorize::createMaskedFindMismatch(
   Value *InitialPred = Builder.CreateIntrinsic(
       Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
 
-  Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {});
+  Value *VecLen = Builder.CreateVScale(I64Type);
   VecLen =
       Builder.CreateMul(VecLen, ConstantInt::get(I64Type, ByteCompareVF), "",
                         /*HasNUW=*/true, /*HasNSW=*/true);
diff --git a/llvm/unittests/IR/IRBuilderTest.cpp b/llvm/unittests/IR/IRBuilderTest.cpp
index b7eb0af728331..3a7ba924792ef 100644
--- a/llvm/unittests/IR/IRBuilderTest.cpp
+++ b/llvm/unittests/IR/IRBuilderTest.cpp
@@ -212,14 +212,6 @@ TEST_F(IRBuilderTest, IntrinsicsWithScalableVectors) {
     EXPECT_EQ(FTy->getParamType(i), Args[i]->getType());
 }
 
-TEST_F(IRBuilderTest, CreateVScale) {
-  IRBuilder<> Builder(BB);
-
-  Constant *Zero = Builder.getInt32(0);
-  Value *VScale = Builder.CreateVScale(Zero);
-  EXPECT_TRUE(isa<ConstantInt>(VScale) && cast<ConstantInt>(VScale)->isZero());
-}
-
 TEST_F(IRBuilderTest, CreateStepVector) {
   IRBuilder<> Builder(BB);
 

@llvmbot
Copy link
Member

llvmbot commented Jun 4, 2025

@llvm/pr-subscribers-llvm-ir

Author: Paul Walker (paulwalker-arm)

Changes

CreateVScale took a scaling parameter that had a single use outside of IRBuilder with all other callers having to create a redundant ConstantInt. To work round this some code perferred to use CreateIntrinsic directly.

This patch simplifies CreateVScale to return a call to the llvm.vscale() intrinsic and nothing more. As well as simplifying the existing call sites I've also migrated the uses of CreateIntrinsic.

Whilst IRBuilder used CreateVScale's scaling parameter as part of the implementations of CreateElementCount and CreateTypeSize, I have follow-on work to switch them to the NUW varaiety and thus they would stop using CreateVScale's scaling as well. To prepare for this I have moved the multiplication and constant folding into the implementations of CreateElementCount and CreateTypeSize.

As a final step I have replaced some callers of CreateVScale with CreateElementCount where it's clear from the code they wanted the latter.


Full diff: https://github.com/llvm/llvm-project/pull/142803.diff

10 Files Affected:

  • (modified) clang/lib/CodeGen/TargetBuiltins/ARM.cpp (+1-5)
  • (modified) llvm/include/llvm/IR/IRBuilder.h (+4-5)
  • (modified) llvm/lib/CodeGen/ExpandVectorPredication.cpp (+1-2)
  • (modified) llvm/lib/IR/IRBuilder.cpp (+20-13)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+4-4)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp (+8-16)
  • (modified) llvm/lib/Transforms/Utils/LowerVectorIntrinsics.cpp (+3-13)
  • (modified) llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp (+1-1)
  • (modified) llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp (+1-1)
  • (modified) llvm/unittests/IR/IRBuilderTest.cpp (-8)
diff --git a/clang/lib/CodeGen/TargetBuiltins/ARM.cpp b/clang/lib/CodeGen/TargetBuiltins/ARM.cpp
index 1cf8f6819b75a..9c77346389e04 100644
--- a/clang/lib/CodeGen/TargetBuiltins/ARM.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/ARM.cpp
@@ -4793,11 +4793,7 @@ Value *CodeGenFunction::EmitAArch64SVEBuiltinExpr(unsigned BuiltinID,
   case SVE::BI__builtin_sve_svlen_u64: {
     SVETypeFlags TF(Builtin->TypeModifier);
     auto VTy = cast<llvm::VectorType>(getSVEType(TF));
-    auto *NumEls =
-        llvm::ConstantInt::get(Ty, VTy->getElementCount().getKnownMinValue());
-
-    Function *F = CGM.getIntrinsic(Intrinsic::vscale, Ty);
-    return Builder.CreateMul(NumEls, Builder.CreateCall(F));
+    return Builder.CreateElementCount(Ty, VTy->getElementCount());
   }
 
   case SVE::BI__builtin_sve_svtbl2_u8:
diff --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h
index 0db5179c7a3e4..8ed10cb803a9c 100644
--- a/llvm/include/llvm/IR/IRBuilder.h
+++ b/llvm/include/llvm/IR/IRBuilder.h
@@ -945,17 +945,16 @@ class IRBuilderBase {
   LLVM_ABI CallInst *CreateGCGetPointerOffset(Value *DerivedPtr,
                                               const Twine &Name = "");
 
-  /// Create a call to llvm.vscale, multiplied by \p Scaling. The type of VScale
-  /// will be the same type as that of \p Scaling.
-  LLVM_ABI Value *CreateVScale(Constant *Scaling, const Twine &Name = "");
+  /// Create a call to llvm.vscale.<Ty>().
+  LLVM_ABI Value *CreateVScale(Type *Ty, const Twine &Name = "");
 
   /// Create an expression which evaluates to the number of elements in \p EC
   /// at runtime.
-  LLVM_ABI Value *CreateElementCount(Type *DstType, ElementCount EC);
+  LLVM_ABI Value *CreateElementCount(Type *Ty, ElementCount EC);
 
   /// Create an expression which evaluates to the number of units in \p Size
   /// at runtime.  This works for both units of bits and bytes.
-  LLVM_ABI Value *CreateTypeSize(Type *DstType, TypeSize Size);
+  LLVM_ABI Value *CreateTypeSize(Type *Ty, TypeSize Size);
 
   /// Creates a vector of type \p DstType with the linear sequence <0, 1, ...>
   LLVM_ABI Value *CreateStepVector(Type *DstType, const Twine &Name = "");
diff --git a/llvm/lib/CodeGen/ExpandVectorPredication.cpp b/llvm/lib/CodeGen/ExpandVectorPredication.cpp
index 1bb0763fcf57b..d8e3f5fbb31de 100644
--- a/llvm/lib/CodeGen/ExpandVectorPredication.cpp
+++ b/llvm/lib/CodeGen/ExpandVectorPredication.cpp
@@ -507,8 +507,7 @@ bool CachingVPExpander::discardEVLParameter(VPIntrinsic &VPI) {
     // TODO add caching
     IRBuilder<> Builder(VPI.getParent(), VPI.getIterator());
     Value *FactorConst = Builder.getInt32(StaticElemCount.getKnownMinValue());
-    Value *VScale = Builder.CreateIntrinsic(Intrinsic::vscale, Int32Ty, {},
-                                            /*FMFSource=*/nullptr, "vscale");
+    Value *VScale = Builder.CreateVScale(Int32Ty, "vscale");
     MaxEVL = Builder.CreateMul(VScale, FactorConst, "scalable_size",
                                /*NUW*/ true, /*NSW*/ false);
   } else {
diff --git a/llvm/lib/IR/IRBuilder.cpp b/llvm/lib/IR/IRBuilder.cpp
index 580b0af709337..868aa7a2cb799 100644
--- a/llvm/lib/IR/IRBuilder.cpp
+++ b/llvm/lib/IR/IRBuilder.cpp
@@ -120,23 +120,30 @@ IRBuilderBase::createCallHelper(Function *Callee, ArrayRef<Value *> Ops,
   return CI;
 }
 
-Value *IRBuilderBase::CreateVScale(Constant *Scaling, const Twine &Name) {
-  assert(isa<ConstantInt>(Scaling) && "Expected constant integer");
-  if (cast<ConstantInt>(Scaling)->isZero())
-    return Scaling;
-  CallInst *CI =
-      CreateIntrinsic(Intrinsic::vscale, {Scaling->getType()}, {}, {}, Name);
-  return cast<ConstantInt>(Scaling)->isOne() ? CI : CreateMul(CI, Scaling);
+Value *IRBuilderBase::CreateVScale(Type *Ty, const Twine &Name) {
+  return CreateIntrinsic(Intrinsic::vscale, {Ty}, {}, {}, Name);
 }
 
-Value *IRBuilderBase::CreateElementCount(Type *DstType, ElementCount EC) {
-  Constant *MinEC = ConstantInt::get(DstType, EC.getKnownMinValue());
-  return EC.isScalable() ? CreateVScale(MinEC) : MinEC;
+Value *IRBuilderBase::CreateElementCount(Type *Ty, ElementCount EC) {
+  if (EC.isFixed() || EC.isZero())
+    return ConstantInt::get(Ty, EC.getKnownMinValue());
+
+  Value *VScale = CreateVScale(Ty);
+  if (EC.getKnownMinValue() == 1)
+    return VScale;
+
+  return CreateMul(VScale, ConstantInt::get(Ty, EC.getKnownMinValue()));
 }
 
-Value *IRBuilderBase::CreateTypeSize(Type *DstType, TypeSize Size) {
-  Constant *MinSize = ConstantInt::get(DstType, Size.getKnownMinValue());
-  return Size.isScalable() ? CreateVScale(MinSize) : MinSize;
+Value *IRBuilderBase::CreateTypeSize(Type *Ty, TypeSize Size) {
+  if (Size.isFixed() || Size.isZero())
+    return ConstantInt::get(Ty, Size.getKnownMinValue());
+
+  Value *VScale = CreateVScale(Ty);
+  if (Size.getKnownMinValue() == 1)
+    return VScale;
+
+  return CreateMul(VScale, ConstantInt::get(Ty, Size.getKnownMinValue()));
 }
 
 Value *IRBuilderBase::CreateStepVector(Type *DstType, const Twine &Name) {
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 68aec80f07e1d..f6167e4333327 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -2051,10 +2051,10 @@ instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
   const auto Pattern = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue();
 
   if (Pattern == AArch64SVEPredPattern::all) {
-    Constant *StepVal = ConstantInt::get(II.getType(), NumElts);
-    auto *VScale = IC.Builder.CreateVScale(StepVal);
-    VScale->takeName(&II);
-    return IC.replaceInstUsesWith(II, VScale);
+    Value *Cnt = IC.Builder.CreateElementCount(
+        II.getType(), ElementCount::getScalable(NumElts));
+    Cnt->takeName(&II);
+    return IC.replaceInstUsesWith(II, Cnt);
   }
 
   unsigned MinNumElts = getNumElementsFromSVEPredPattern(Pattern);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index d234a0566e191..2db79228bf0e6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -935,12 +935,9 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
         Trunc.getFunction()->hasFnAttribute(Attribute::VScaleRange)) {
       Attribute Attr =
           Trunc.getFunction()->getFnAttribute(Attribute::VScaleRange);
-      if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) {
-        if (Log2_32(*MaxVScale) < DestWidth) {
-          Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1));
-          return replaceInstUsesWith(Trunc, VScale);
-        }
-      }
+      if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax())
+        if (Log2_32(*MaxVScale) < DestWidth)
+          return replaceInstUsesWith(Trunc, Builder.CreateVScale(DestTy));
     }
   }
 
@@ -1314,10 +1311,8 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) {
           Zext.getFunction()->getFnAttribute(Attribute::VScaleRange);
       if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) {
         unsigned TypeWidth = Src->getType()->getScalarSizeInBits();
-        if (Log2_32(*MaxVScale) < TypeWidth) {
-          Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1));
-          return replaceInstUsesWith(Zext, VScale);
-        }
+        if (Log2_32(*MaxVScale) < TypeWidth)
+          return replaceInstUsesWith(Zext, Builder.CreateVScale(DestTy));
       }
     }
   }
@@ -1604,12 +1599,9 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) {
         Sext.getFunction()->hasFnAttribute(Attribute::VScaleRange)) {
       Attribute Attr =
           Sext.getFunction()->getFnAttribute(Attribute::VScaleRange);
-      if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) {
-        if (Log2_32(*MaxVScale) < (SrcBitSize - 1)) {
-          Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1));
-          return replaceInstUsesWith(Sext, VScale);
-        }
-      }
+      if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax())
+        if (Log2_32(*MaxVScale) < (SrcBitSize - 1))
+          return replaceInstUsesWith(Sext, Builder.CreateVScale(DestTy));
     }
   }
 
diff --git a/llvm/lib/Transforms/Utils/LowerVectorIntrinsics.cpp b/llvm/lib/Transforms/Utils/LowerVectorIntrinsics.cpp
index a34ef260dc244..71c10f5b157c7 100644
--- a/llvm/lib/Transforms/Utils/LowerVectorIntrinsics.cpp
+++ b/llvm/lib/Transforms/Utils/LowerVectorIntrinsics.cpp
@@ -21,6 +21,7 @@ bool llvm::lowerUnaryVectorIntrinsicAsLoop(Module &M, CallInst *CI) {
   BasicBlock *PostLoopBB = nullptr;
   Function *ParentFunc = PreLoopBB->getParent();
   LLVMContext &Ctx = PreLoopBB->getContext();
+  Type *Int64Ty = IntegerType::get(Ctx, 64);
 
   PostLoopBB = PreLoopBB->splitBasicBlock(CI);
   BasicBlock *LoopBB = BasicBlock::Create(Ctx, "", ParentFunc, PostLoopBB);
@@ -28,22 +29,11 @@ bool llvm::lowerUnaryVectorIntrinsicAsLoop(Module &M, CallInst *CI) {
 
   // Loop preheader
   IRBuilder<> PreLoopBuilder(PreLoopBB->getTerminator());
-  Value *LoopEnd = nullptr;
-  if (auto *ScalableVecTy = dyn_cast<ScalableVectorType>(VecTy)) {
-    Value *VScale = PreLoopBuilder.CreateVScale(
-        ConstantInt::get(PreLoopBuilder.getInt64Ty(), 1));
-    Value *N = ConstantInt::get(PreLoopBuilder.getInt64Ty(),
-                                ScalableVecTy->getMinNumElements());
-    LoopEnd = PreLoopBuilder.CreateMul(VScale, N);
-  } else {
-    FixedVectorType *FixedVecTy = cast<FixedVectorType>(VecTy);
-    LoopEnd = ConstantInt::get(PreLoopBuilder.getInt64Ty(),
-                               FixedVecTy->getNumElements());
-  }
+  Value *LoopEnd =
+      PreLoopBuilder.CreateElementCount(Int64Ty, VecTy->getElementCount());
 
   // Loop body
   IRBuilder<> LoopBuilder(LoopBB);
-  Type *Int64Ty = LoopBuilder.getInt64Ty();
 
   PHINode *LoopIndex = LoopBuilder.CreatePHI(Int64Ty, 2);
   LoopIndex->addIncoming(ConstantInt::get(Int64Ty, 0U), PreLoopBB);
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index 88d2eca36ca51..70afd4133df7c 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -1440,7 +1440,7 @@ Value *SCEVExpander::visitSequentialUMinExpr(const SCEVSequentialUMinExpr *S) {
 }
 
 Value *SCEVExpander::visitVScale(const SCEVVScale *S) {
-  return Builder.CreateVScale(ConstantInt::get(S->getType(), 1));
+  return Builder.CreateVScale(S->getType());
 }
 
 Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty,
diff --git a/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
index 02ffc67c774dd..491f0b76f4ae0 100644
--- a/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
@@ -437,7 +437,7 @@ Value *LoopIdiomVectorize::createMaskedFindMismatch(
   Value *InitialPred = Builder.CreateIntrinsic(
       Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
 
-  Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {});
+  Value *VecLen = Builder.CreateVScale(I64Type);
   VecLen =
       Builder.CreateMul(VecLen, ConstantInt::get(I64Type, ByteCompareVF), "",
                         /*HasNUW=*/true, /*HasNSW=*/true);
diff --git a/llvm/unittests/IR/IRBuilderTest.cpp b/llvm/unittests/IR/IRBuilderTest.cpp
index b7eb0af728331..3a7ba924792ef 100644
--- a/llvm/unittests/IR/IRBuilderTest.cpp
+++ b/llvm/unittests/IR/IRBuilderTest.cpp
@@ -212,14 +212,6 @@ TEST_F(IRBuilderTest, IntrinsicsWithScalableVectors) {
     EXPECT_EQ(FTy->getParamType(i), Args[i]->getType());
 }
 
-TEST_F(IRBuilderTest, CreateVScale) {
-  IRBuilder<> Builder(BB);
-
-  Constant *Zero = Builder.getInt32(0);
-  Value *VScale = Builder.CreateVScale(Zero);
-  EXPECT_TRUE(isa<ConstantInt>(VScale) && cast<ConstantInt>(VScale)->isZero());
-}
-
 TEST_F(IRBuilderTest, CreateStepVector) {
   IRBuilder<> Builder(BB);
 

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

CallInst *CI =
CreateIntrinsic(Intrinsic::vscale, {Scaling->getType()}, {}, {}, Name);
return cast<ConstantInt>(Scaling)->isOne() ? CI : CreateMul(CI, Scaling);
Value *IRBuilderBase::CreateVScale(Type *Ty, const Twine &Name) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Is it worth inlining this into the header definition?

Value *IRBuilderBase::CreateTypeSize(Type *DstType, TypeSize Size) {
Constant *MinSize = ConstantInt::get(DstType, Size.getKnownMinValue());
return Size.isScalable() ? CreateVScale(MinSize) : MinSize;
Value *IRBuilderBase::CreateTypeSize(Type *Ty, TypeSize Size) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just suggestion, so please feel free to ignore it, but given how similar the two functions are, would it be useful to write a common helper such as

  Value *CreateVScaleMultiple(Type *Ty, uint64_t Scale) {
    Value *VScale = CreateVScale(Ty);
    if (Scale == 1)
      return VScale;

    return CreateMul(VScale, ConstantInt::get(Ty, Scale));
  }

and reusing this in both CreateTypeSize and CreateElementCount?

@@ -4793,11 +4793,7 @@ Value *CodeGenFunction::EmitAArch64SVEBuiltinExpr(unsigned BuiltinID,
case SVE::BI__builtin_sve_svlen_u64: {
SVETypeFlags TF(Builtin->TypeModifier);
auto VTy = cast<llvm::VectorType>(getSVEType(TF));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Is it worth strengthening the cast here, given it should always be a ScalableVectorType?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It turns out getSVEType() already returns the correct type to I've remove the redundant cast here, plus the one a few lines lower.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:AArch64 clang:codegen IR generation bugs: mangling, exceptions, etc. clang Clang issues not falling into any other category llvm:codegen llvm:instcombine llvm:ir llvm:transforms vectorizers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants