Commit 4a1c53f

[SLP] Improve minbitwidth analysis for abs/smin/smax/umin/umax intrinsics.

https://alive2.llvm.org/ce/z/ivPZ26 for the abs transformations.

Reviewers: RKSimon
Reviewed By: RKSimon
Pull Request: #86135
1 parent ed1b24b commit 4a1c53f
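In short: when the minbitwidth analysis proves that a vectorized call to one of these intrinsics can execute at a narrower integer width, the call is now emitted directly at that width instead of being widened and truncated around the original type. A minimal IR sketch of the abs case (mirroring the store-abs-minbitwidth.ll test below; value names are illustrative):

```llvm
; Before: abs executed at the original 32-bit width, then truncated.
%diff32 = sub <4 x i32> %a32, %b32
%abs32  = call <4 x i32> @llvm.abs.v4i32(<4 x i32> %diff32, i1 true)
%res    = trunc <4 x i32> %abs32 to <4 x i16>

; After: abs demoted to 16 bits; the is_int_min_poison flag is cleared
; because the narrower type can now hit its own signed minimum.
%diff16 = sub <4 x i16> %a16, %b16
%res16  = call <4 x i16> @llvm.abs.v4i16(<4 x i16> %diff16, i1 false)
```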

4 files changed (+117, -27 lines changed)

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 105 additions & 14 deletions
@@ -7056,19 +7056,16 @@ bool BoUpSLP::areAllUsersVectorized(
 
 static std::pair<InstructionCost, InstructionCost>
 getVectorCallCosts(CallInst *CI, FixedVectorType *VecTy,
-                   TargetTransformInfo *TTI, TargetLibraryInfo *TLI) {
+                   TargetTransformInfo *TTI, TargetLibraryInfo *TLI,
+                   ArrayRef<Type *> ArgTys) {
   Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
 
   // Calculate the cost of the scalar and vector calls.
-  SmallVector<Type *, 4> VecTys;
-  for (Use &Arg : CI->args())
-    VecTys.push_back(
-        FixedVectorType::get(Arg->getType(), VecTy->getNumElements()));
   FastMathFlags FMF;
   if (auto *FPCI = dyn_cast<FPMathOperator>(CI))
     FMF = FPCI->getFastMathFlags();
   SmallVector<const Value *> Arguments(CI->args());
-  IntrinsicCostAttributes CostAttrs(ID, VecTy, Arguments, VecTys, FMF,
+  IntrinsicCostAttributes CostAttrs(ID, VecTy, Arguments, ArgTys, FMF,
                                     dyn_cast<IntrinsicInst>(CI));
   auto IntrinsicCost =
       TTI->getIntrinsicInstrCost(CostAttrs, TTI::TCK_RecipThroughput);
@@ -7081,8 +7078,8 @@ getVectorCallCosts(CallInst *CI, FixedVectorType *VecTy,
   if (!CI->isNoBuiltin() && VecFunc) {
     // Calculate the cost of the vector library call.
     // If the corresponding vector call is cheaper, return its cost.
-    LibCost = TTI->getCallInstrCost(nullptr, VecTy, VecTys,
-                                    TTI::TCK_RecipThroughput);
+    LibCost =
+        TTI->getCallInstrCost(nullptr, VecTy, ArgTys, TTI::TCK_RecipThroughput);
   }
   return {IntrinsicCost, LibCost};
 }
@@ -8508,6 +8505,30 @@ TTI::CastContextHint BoUpSLP::getCastContextHint(const TreeEntry &TE) const {
   return TTI::CastContextHint::None;
 }
 
+/// Builds the argument types vector for the given call instruction with the
+/// given \p ID for the specified vector factor.
+static SmallVector<Type *> buildIntrinsicArgTypes(const CallInst *CI,
+                                                  const Intrinsic::ID ID,
+                                                  const unsigned VF,
+                                                  unsigned MinBW) {
+  SmallVector<Type *> ArgTys;
+  for (auto [Idx, Arg] : enumerate(CI->args())) {
+    if (ID != Intrinsic::not_intrinsic) {
+      if (isVectorIntrinsicWithScalarOpAtArg(ID, Idx)) {
+        ArgTys.push_back(Arg->getType());
+        continue;
+      }
+      if (MinBW > 0) {
+        ArgTys.push_back(FixedVectorType::get(
+            IntegerType::get(CI->getContext(), MinBW), VF));
+        continue;
+      }
+    }
+    ArgTys.push_back(FixedVectorType::get(Arg->getType(), VF));
+  }
+  return ArgTys;
+}
+
 InstructionCost
 BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
                       SmallPtrSetImpl<Value *> &CheckedExtracts) {
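A hedged illustration of the argument-type vectors buildIntrinsicArgTypes produces, with the shapes inferred from the code above and VF = 4 assumed:

```llvm
; MinBW == 0: every vector operand keeps its scalar element type, widened per lane.
declare <4 x i32> @llvm.smax.v4i32(<4 x i32>, <4 x i32>)

; MinBW == 16: vector operands are demoted to i16, while the i1 flag of abs
; is reported by isVectorIntrinsicWithScalarOpAtArg and stays scalar.
declare <4 x i16> @llvm.abs.v4i16(<4 x i16>, i1 immarg)
```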
@@ -9074,7 +9095,11 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
     };
     auto GetVectorCost = [=](InstructionCost CommonCost) {
       auto *CI = cast<CallInst>(VL0);
-      auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI);
+      Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
+      SmallVector<Type *> ArgTys =
+          buildIntrinsicArgTypes(CI, ID, VecTy->getNumElements(),
+                                 It != MinBWs.end() ? It->second.first : 0);
+      auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI, ArgTys);
       return std::min(VecCallCosts.first, VecCallCosts.second) + CommonCost;
     };
     return GetCostDiff(GetScalarCost, GetVectorCost);
@@ -12548,7 +12573,10 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
 
       Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
 
-      auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI);
+      SmallVector<Type *> ArgTys =
+          buildIntrinsicArgTypes(CI, ID, VecTy->getNumElements(),
+                                 It != MinBWs.end() ? It->second.first : 0);
+      auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI, ArgTys);
       bool UseIntrinsic = ID != Intrinsic::not_intrinsic &&
                           VecCallCosts.first <= VecCallCosts.second;
 
@@ -12557,16 +12585,20 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       SmallVector<Type *, 2> TysForDecl;
       // Add return type if intrinsic is overloaded on it.
       if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
-        TysForDecl.push_back(
-            FixedVectorType::get(CI->getType(), E->Scalars.size()));
+        TysForDecl.push_back(VecTy);
       auto *CEI = cast<CallInst>(VL0);
       for (unsigned I : seq<unsigned>(0, CI->arg_size())) {
         ValueList OpVL;
         // Some intrinsics have scalar arguments. This argument should not be
         // vectorized.
         if (UseIntrinsic && isVectorIntrinsicWithScalarOpAtArg(ID, I)) {
           ScalarArg = CEI->getArgOperand(I);
-          OpVecs.push_back(CEI->getArgOperand(I));
+          // If we decided to reduce the bitwidth of abs, its second argument
+          // must be false (do not return poison if the value is signed min).
+          if (ID == Intrinsic::abs && It != MinBWs.end() &&
+              It->second.first < DL->getTypeSizeInBits(CEI->getType()))
+            ScalarArg = Builder.getFalse();
+          OpVecs.push_back(ScalarArg);
           if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
             TysForDecl.push_back(ScalarArg->getType());
           continue;
@@ -12579,10 +12611,13 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         }
         ScalarArg = CEI->getArgOperand(I);
         if (cast<VectorType>(OpVec->getType())->getElementType() !=
-            ScalarArg->getType()) {
+                ScalarArg->getType() &&
+            It == MinBWs.end()) {
           auto *CastTy = FixedVectorType::get(ScalarArg->getType(),
                                               VecTy->getNumElements());
           OpVec = Builder.CreateIntCast(OpVec, CastTy, GetOperandSignedness(I));
+        } else if (It != MinBWs.end()) {
+          OpVec = Builder.CreateIntCast(OpVec, VecTy, GetOperandSignedness(I));
         }
         LLVM_DEBUG(dbgs() << "SLP: OpVec[" << I << "]: " << *OpVec << "\n");
         OpVecs.push_back(OpVec);
@@ -14326,6 +14361,62 @@ bool BoUpSLP::collectValuesToDemote(
     return TryProcessInstruction(I, *ITE, BitWidth, Ops);
   }
 
+  case Instruction::Call: {
+    auto *IC = dyn_cast<IntrinsicInst>(I);
+    if (!IC)
+      break;
+    Intrinsic::ID ID = getVectorIntrinsicIDForCall(IC, TLI);
+    if (ID != Intrinsic::abs && ID != Intrinsic::smin &&
+        ID != Intrinsic::smax && ID != Intrinsic::umin && ID != Intrinsic::umax)
+      break;
+    SmallVector<Value *> Operands(1, I->getOperand(0));
+    function_ref<bool(unsigned, unsigned)> CallChecker;
+    auto CompChecker = [&](unsigned BitWidth, unsigned OrigBitWidth) {
+      assert(BitWidth <= OrigBitWidth && "Unexpected bitwidths!");
+      if (ID == Intrinsic::umin || ID == Intrinsic::umax) {
+        APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
+        return MaskedValueIsZero(I->getOperand(0), Mask, SimplifyQuery(*DL)) &&
+               MaskedValueIsZero(I->getOperand(1), Mask, SimplifyQuery(*DL));
+      }
+      assert((ID == Intrinsic::smin || ID == Intrinsic::smax) &&
+             "Expected min/max intrinsics only.");
+      unsigned SignBits = OrigBitWidth - BitWidth;
+      return SignBits <= ComputeNumSignBits(I->getOperand(0), *DL, 0, AC,
+                                            nullptr, DT) &&
+             SignBits <=
+                 ComputeNumSignBits(I->getOperand(1), *DL, 0, AC, nullptr, DT);
+    };
+    End = 1;
+    if (ID != Intrinsic::abs) {
+      Operands.push_back(I->getOperand(1));
+      End = 2;
+      CallChecker = CompChecker;
+    }
+    InstructionCost BestCost =
+        std::numeric_limits<InstructionCost::CostType>::max();
+    unsigned BestBitWidth = BitWidth;
+    unsigned VF = ITE->Scalars.size();
+    // Choose the best bitwidth based on cost estimations.
+    auto Checker = [&](unsigned BitWidth, unsigned) {
+      unsigned MinBW = PowerOf2Ceil(BitWidth);
+      SmallVector<Type *> ArgTys = buildIntrinsicArgTypes(IC, ID, VF, MinBW);
+      auto VecCallCosts = getVectorCallCosts(
+          IC,
+          FixedVectorType::get(IntegerType::get(IC->getContext(), MinBW), VF),
+          TTI, TLI, ArgTys);
+      InstructionCost Cost = std::min(VecCallCosts.first, VecCallCosts.second);
+      if (Cost < BestCost) {
+        BestCost = Cost;
+        BestBitWidth = BitWidth;
+      }
+      return false;
+    };
+    [[maybe_unused]] bool NeedToExit;
+    (void)AttemptCheckBitwidth(Checker, NeedToExit);
+    BitWidth = BestBitWidth;
+    return TryProcessInstruction(I, *ITE, BitWidth, Operands, CallChecker);
+  }
+
   // Otherwise, conservatively give up.
   default:
     break;

llvm/test/Transforms/SLPVectorizer/X86/call-arg-reduced-by-minbitwidth.ll

Lines changed: 1 addition & 3 deletions
@@ -11,9 +11,7 @@ define void @test(ptr %0, i8 %1, i1 %cmp12.i) {
 ; CHECK-NEXT:    [[TMP5:%.*]] = shufflevector <8 x i8> [[TMP4]], <8 x i8> poison, <8 x i32> zeroinitializer
 ; CHECK-NEXT:    br label [[PRE:%.*]]
 ; CHECK:       pre:
-; CHECK-NEXT:    [[TMP6:%.*]] = zext <8 x i8> [[TMP5]] to <8 x i32>
-; CHECK-NEXT:    [[TMP7:%.*]] = call <8 x i32> @llvm.umax.v8i32(<8 x i32> [[TMP6]], <8 x i32> <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>)
-; CHECK-NEXT:    [[TMP8:%.*]] = trunc <8 x i32> [[TMP7]] to <8 x i8>
+; CHECK-NEXT:    [[TMP8:%.*]] = call <8 x i8> @llvm.umax.v8i8(<8 x i8> [[TMP5]], <8 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>)
 ; CHECK-NEXT:    [[TMP9:%.*]] = add <8 x i8> [[TMP8]], <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>
 ; CHECK-NEXT:    [[TMP10:%.*]] = select <8 x i1> [[TMP3]], <8 x i8> [[TMP9]], <8 x i8> [[TMP5]]
 ; CHECK-NEXT:    store <8 x i8> [[TMP10]], ptr [[TMP0]], align 1

llvm/test/Transforms/SLPVectorizer/X86/cmp-after-intrinsic-call-minbitwidth.ll

Lines changed: 7 additions & 5 deletions
@@ -5,12 +5,14 @@ define void @test() {
 ; CHECK-LABEL: define void @test(
 ; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = call <2 x i32> @llvm.smin.v2i32(<2 x i32> zeroinitializer, <2 x i32> zeroinitializer)
-; CHECK-NEXT:    [[TMP1:%.*]] = select <2 x i1> zeroinitializer, <2 x i32> zeroinitializer, <2 x i32> [[TMP0]]
-; CHECK-NEXT:    [[TMP2:%.*]] = or <2 x i32> [[TMP1]], zeroinitializer
-; CHECK-NEXT:    [[ADD:%.*]] = extractelement <2 x i32> [[TMP2]], i32 1
+; CHECK-NEXT:    [[TMP0:%.*]] = call <2 x i2> @llvm.smin.v2i2(<2 x i2> zeroinitializer, <2 x i2> zeroinitializer)
+; CHECK-NEXT:    [[TMP1:%.*]] = select <2 x i1> zeroinitializer, <2 x i2> zeroinitializer, <2 x i2> [[TMP0]]
+; CHECK-NEXT:    [[TMP2:%.*]] = or <2 x i2> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <2 x i2> [[TMP2]], i32 1
+; CHECK-NEXT:    [[ADD:%.*]] = zext i2 [[TMP3]] to i32
 ; CHECK-NEXT:    [[SHR:%.*]] = ashr i32 [[ADD]], 0
-; CHECK-NEXT:    [[ADD45:%.*]] = extractelement <2 x i32> [[TMP2]], i32 0
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <2 x i2> [[TMP2]], i32 0
+; CHECK-NEXT:    [[ADD45:%.*]] = zext i2 [[TMP5]] to i32
 ; CHECK-NEXT:    [[ADD152:%.*]] = or i32 [[ADD45]], [[ADD]]
 ; CHECK-NEXT:    [[IDXPROM153:%.*]] = sext i32 [[ADD152]] to i64
 ; CHECK-NEXT:    [[ARRAYIDX154:%.*]] = getelementptr i8, ptr null, i64 [[IDXPROM153]]
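A note on the unusual i2 width (my reading of this test, not stated in the commit): with all-zero inputs only a couple of bits are needed, the cost-driven choice keeps the type at i2, and the scalar i32 users are rebuilt from the narrow lanes via zext, as in:

```llvm
%lane = extractelement <2 x i2> %v, i32 1
%wide = zext i2 %lane to i32
```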

llvm/test/Transforms/SLPVectorizer/X86/store-abs-minbitwidth.ll

Lines changed: 4 additions & 5 deletions
@@ -13,14 +13,13 @@ define i32 @test(ptr noalias %in, ptr noalias %inn, ptr %out) {
 ; CHECK-NEXT:    [[TMP5:%.*]] = shufflevector <2 x i8> [[TMP3]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP6:%.*]] = shufflevector <2 x i8> [[TMP2]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP7:%.*]] = shufflevector <4 x i8> [[TMP5]], <4 x i8> [[TMP6]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
-; CHECK-NEXT:    [[TMP8:%.*]] = sext <4 x i8> [[TMP7]] to <4 x i32>
+; CHECK-NEXT:    [[TMP8:%.*]] = sext <4 x i8> [[TMP7]] to <4 x i16>
 ; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <2 x i8> [[TMP1]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP10:%.*]] = shufflevector <2 x i8> [[TMP4]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP11:%.*]] = shufflevector <4 x i8> [[TMP9]], <4 x i8> [[TMP10]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
-; CHECK-NEXT:    [[TMP12:%.*]] = sext <4 x i8> [[TMP11]] to <4 x i32>
-; CHECK-NEXT:    [[TMP13:%.*]] = sub <4 x i32> [[TMP12]], [[TMP8]]
-; CHECK-NEXT:    [[TMP14:%.*]] = call <4 x i32> @llvm.abs.v4i32(<4 x i32> [[TMP13]], i1 true)
-; CHECK-NEXT:    [[TMP15:%.*]] = trunc <4 x i32> [[TMP14]] to <4 x i16>
+; CHECK-NEXT:    [[TMP12:%.*]] = sext <4 x i8> [[TMP11]] to <4 x i16>
+; CHECK-NEXT:    [[TMP13:%.*]] = sub <4 x i16> [[TMP12]], [[TMP8]]
+; CHECK-NEXT:    [[TMP15:%.*]] = call <4 x i16> @llvm.abs.v4i16(<4 x i16> [[TMP13]], i1 false)
 ; CHECK-NEXT:    store <4 x i16> [[TMP15]], ptr [[OUT:%.*]], align 2
 ; CHECK-NEXT:    ret i32 undef
 ;
