Skip to content

[TLI] Pass replace-with-veclib works with Scalable Vectors. #73642

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llvm/lib/Analysis/VFABIDemangling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ static ParseRet tryParseLinearTokenWithRuntimeStep(StringRef &ParseString,
return ParseRet::None;
}

/// The function looks for the following stringt at the beginning of
/// The function looks for the following string at the beginning of
/// the input string `ParseString`:
///
/// <token> <number>
Expand Down
210 changes: 105 additions & 105 deletions llvm/lib/CodeGen/ReplaceWithVeclib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,19 @@
#include "llvm/CodeGen/ReplaceWithVeclib.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/DemandedBits.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/Support/TypeSize.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
#include <optional>

using namespace llvm;

Expand All @@ -38,138 +42,135 @@ STATISTIC(NumTLIFuncDeclAdded,
STATISTIC(NumFuncUsedAdded,
"Number of functions added to `llvm.compiler.used`");

static bool replaceWithTLIFunction(CallInst &CI, const StringRef TLIName) {
Module *M = CI.getModule();

Function *OldFunc = CI.getCalledFunction();

// Check if the vector library function is already declared in this module,
// otherwise insert it.
/// Returns the vector Function named \p TLIName, declaring it in the Module
/// \p M if it is not already present. When \p ScalarFunc is not null, its
/// attributes are copied to the newly created Function.
Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
Function *ScalarFunc, const StringRef TLIName) {
Function *TLIFunc = M->getFunction(TLIName);
if (!TLIFunc) {
TLIFunc = Function::Create(OldFunc->getFunctionType(),
Function::ExternalLinkage, TLIName, *M);
TLIFunc->copyAttributesFrom(OldFunc);
TLIFunc =
Function::Create(VectorFTy, Function::ExternalLinkage, TLIName, *M);
if (ScalarFunc)
TLIFunc->copyAttributesFrom(ScalarFunc);

LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `"
<< TLIName << "` of type `" << *(TLIFunc->getType())
<< "` to module.\n");

++NumTLIFuncDeclAdded;

// Add the freshly created function to llvm.compiler.used,
// similar to as it is done in InjectTLIMappings
// Add the freshly created function to llvm.compiler.used, as is done in
// InjectTLIMappings.
appendToCompilerUsed(*M, {TLIFunc});

LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Adding `" << TLIName
<< "` to `@llvm.compiler.used`.\n");
++NumFuncUsedAdded;
}
return TLIFunc;
}

// Replace the call to the vector intrinsic with a call
// to the corresponding function from the vector library.
IRBuilder<> IRBuilder(&CI);
SmallVector<Value *> Args(CI.args());
// Preserve the operand bundles.
SmallVector<OperandBundleDef, 1> OpBundles;
CI.getOperandBundlesAsDefs(OpBundles);
CallInst *Replacement = IRBuilder.CreateCall(TLIFunc, Args, OpBundles);
assert(OldFunc->getFunctionType() == TLIFunc->getFunctionType() &&
"Expecting function types to be identical");
CI.replaceAllUsesWith(Replacement);
if (isa<FPMathOperator>(Replacement)) {
// Preserve fast math flags for FP math.
Replacement->copyFastMathFlags(&CI);
/// Replace the call to the vector intrinsic ( \p CalltoReplace ) with a call to
/// the corresponding function from the vector library ( \p TLIVecFunc ).
static void replaceWithTLIFunction(CallInst &CalltoReplace, VFInfo &Info,
Function *TLIVecFunc) {
IRBuilder<> IRBuilder(&CalltoReplace);
SmallVector<Value *> Args(CalltoReplace.args());
if (auto OptMaskpos = Info.getParamIndexForOptionalMask()) {
auto *MaskTy = VectorType::get(Type::getInt1Ty(CalltoReplace.getContext()),
Info.Shape.VF);
Args.insert(Args.begin() + OptMaskpos.value(),
Constant::getAllOnesValue(MaskTy));
}

LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `"
<< OldFunc->getName() << "` with call to `" << TLIName
<< "`.\n");
++NumCallsReplaced;
return true;
// Preserve the operand bundles.
SmallVector<OperandBundleDef, 1> OpBundles;
CalltoReplace.getOperandBundlesAsDefs(OpBundles);
CallInst *Replacement = IRBuilder.CreateCall(TLIVecFunc, Args, OpBundles);
CalltoReplace.replaceAllUsesWith(Replacement);
// Preserve fast math flags for FP math.
if (isa<FPMathOperator>(Replacement))
Replacement->copyFastMathFlags(&CalltoReplace);
}

/// Returns true when \p CallToReplace was successfully replaced with a suitable
/// function taking vector arguments, based on available mappings in the \p TLI.
/// Currently only works when \p CallToReplace is a call to a vectorized
/// intrinsic.
static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
CallInst &CI) {
if (!CI.getCalledFunction()) {
CallInst &CallToReplace) {
if (!CallToReplace.getCalledFunction())
return false;
}

auto IntrinsicID = CI.getCalledFunction()->getIntrinsicID();
if (IntrinsicID == Intrinsic::not_intrinsic) {
// Replacement is only performed for intrinsic functions
auto IntrinsicID = CallToReplace.getCalledFunction()->getIntrinsicID();
// Replacement is only performed for intrinsic functions.
if (IntrinsicID == Intrinsic::not_intrinsic)
return false;
}

// Convert vector arguments to scalar type and check that
// all vector operands have identical vector width.
// Compute the argument types of the corresponding scalar call. Additionally,
// check that all vector operands of the vector call have the same EC.
ElementCount VF = ElementCount::getFixed(0);
SmallVector<Type *> ScalarTypes;
for (auto Arg : enumerate(CI.args())) {
auto *ArgType = Arg.value()->getType();
// Vector calls to intrinsics can still have
// scalar operands for specific arguments.
SmallVector<Type *> ScalarArgTypes;
for (auto Arg : enumerate(CallToReplace.args())) {
auto *ArgTy = Arg.value()->getType();
if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, Arg.index())) {
ScalarTypes.push_back(ArgType);
} else {
// The argument in this place should be a vector if
// this is a call to a vector intrinsic.
auto *VectorArgTy = dyn_cast<VectorType>(ArgType);
if (!VectorArgTy) {
// The argument is not a vector, do not perform
// the replacement.
return false;
}
ElementCount NumElements = VectorArgTy->getElementCount();
if (NumElements.isScalable()) {
// The current implementation does not support
// scalable vectors.
ScalarArgTypes.push_back(ArgTy);
} else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
ScalarArgTypes.push_back(ArgTy->getScalarType());
// Disallow vector arguments with different VFs. When processing the first
// vector argument, store its VF, and for the rest ensure that they match
// it.
if (VF.isZero())
VF = VectorArgTy->getElementCount();
else if (VF != VectorArgTy->getElementCount())
return false;
}
if (VF.isNonZero() && VF != NumElements) {
// The different arguments differ in vector size.
return false;
} else {
VF = NumElements;
}
ScalarTypes.push_back(VectorArgTy->getElementType());
}
} else
// Exit when it is supposed to be a vector argument but it isn't.
return false;
}

// Try to reconstruct the name for the scalar version of this
// intrinsic using the intrinsic ID and the argument types
// converted to scalar above.
std::string ScalarName;
if (Intrinsic::isOverloaded(IntrinsicID)) {
ScalarName = Intrinsic::getName(IntrinsicID, ScalarTypes, CI.getModule());
} else {
ScalarName = Intrinsic::getName(IntrinsicID).str();
}
// Try to reconstruct the name for the scalar version of this intrinsic using
// the intrinsic ID and the argument types converted to scalar above.
std::string ScalarName =
(Intrinsic::isOverloaded(IntrinsicID)
? Intrinsic::getName(IntrinsicID, ScalarArgTypes,
CallToReplace.getModule())
: Intrinsic::getName(IntrinsicID).str());

// Try to find the mapping for the scalar version of this intrinsic and the
// exact vector width of the call operands in the TargetLibraryInfo. First,
// check with a non-masked variant, and if that fails try with a masked one.
const VecDesc *VD = TLI.getVectorMappingInfo(ScalarName, VF, false);
if (!VD && !(VD = TLI.getVectorMappingInfo(ScalarName, VF, true)))
return false;

if (!TLI.isFunctionVectorizable(ScalarName)) {
// The TargetLibraryInfo does not contain a vectorized version of
// the scalar function.
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI mapping from: `" << ScalarName
<< "` and vector width " << VF << " to: `"
<< VD->getVectorFnName() << "`.\n");

// Replace the call to the intrinsic with a call to the vector library
// function.
Type *ScalarRetTy = CallToReplace.getType()->getScalarType();
FunctionType *ScalarFTy =
FunctionType::get(ScalarRetTy, ScalarArgTypes, /*isVarArg*/ false);
const std::string MangledName = VD->getVectorFunctionABIVariantString();
auto OptInfo = VFABI::tryDemangleForVFABI(MangledName, ScalarFTy);
if (!OptInfo)
return false;
}

// Try to find the mapping for the scalar version of this intrinsic
// and the exact vector width of the call operands in the
// TargetLibraryInfo.
StringRef TLIName = TLI.getVectorizedFunction(ScalarName, VF);

LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Looking up TLI mapping for `"
<< ScalarName << "` and vector width " << VF << ".\n");

if (!TLIName.empty()) {
// Found the correct mapping in the TargetLibraryInfo,
// replace the call to the intrinsic with a call to
// the vector library function.
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI function `" << TLIName
<< "`.\n");
return replaceWithTLIFunction(CI, TLIName);
}
FunctionType *VectorFTy = VFABI::createFunctionType(*OptInfo, ScalarFTy);
if (!VectorFTy)
return false;

Function *FuncToReplace = CallToReplace.getCalledFunction();
Function *TLIFunc = getTLIFunction(CallToReplace.getModule(), VectorFTy,
FuncToReplace, VD->getVectorFnName());
replaceWithTLIFunction(CallToReplace, *OptInfo, TLIFunc);

return false;
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `"
<< FuncToReplace->getName() << "` with call to `"
<< TLIFunc->getName() << "`.\n");
++NumCallsReplaced;
return true;
}

static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
Expand All @@ -185,9 +186,8 @@ static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
}
// Erase the calls to the intrinsics that have been replaced
// with calls to the vector library.
for (auto *CI : ReplacedCalls) {
for (auto *CI : ReplacedCalls)
CI->eraseFromParent();
}
return Changed;
}

Expand All @@ -207,10 +207,10 @@ PreservedAnalyses ReplaceWithVeclib::run(Function &F,
PA.preserve<DemandedBitsAnalysis>();
PA.preserve<OptimizationRemarkEmitterAnalysis>();
return PA;
} else {
// The pass did not replace any calls, hence it preserves all analyses.
return PreservedAnalyses::all();
}

// The pass did not replace any calls, hence it preserves all analyses.
return PreservedAnalyses::all();
}

////////////////////////////////////////////////////////////////////////////////
Expand Down
Loading