Skip to content

Commit e5f4019

Browse files
committed
[AArch64] Add extending reduction costs for addlv and dot
This adds basic implementations of getExtendedReductionCost and getMulAccReductionCost to account for extending add reductions (uaddlv/saddlv) and mla reductions with dotprod.
1 parent 50581ef commit e5f4019

File tree

4 files changed

+1209
-1
lines changed

4 files changed

+1209
-1
lines changed

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

+48
Original file line numberDiff line numberDiff line change
@@ -4635,6 +4635,54 @@ AArch64TTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
46354635
return BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
46364636
}
46374637

4638+
/// Cost of an extending reduction over \p VecTy producing \p ResTy.
///
/// Only integer add reductions are given a target-specific cost here; every
/// other case is forwarded unchanged to the generic implementation.
InstructionCost AArch64TTIImpl::getExtendedReductionCost(
    unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *VecTy,
    FastMathFlags FMF, TTI::TargetCostKind CostKind) {
  EVT SrcVT = TLI->getValueType(DL, VecTy);
  EVT DstVT = TLI->getValueType(DL, ResTy);

  // Candidate: an add reduction between simple types whose source vector is
  // at least 64 bits wide.
  const bool IsCandidate = Opcode == Instruction::Add && SrcVT.isSimple() &&
                           DstVT.isSimple() && SrcVT.getSizeInBits() >= 64;
  if (IsCandidate) {
    std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(VecTy);

    // Widest result (in bits) that is legal for the legalized source type.
    // The legal cases are:
    //   UADDLV 8/16/32->32
    //   UADDLP 32->64
    unsigned MaxResBits = 0;
    switch (LT.second.SimpleTy) {
    case MVT::v8i8:
    case MVT::v16i8:
    case MVT::v4i16:
    case MVT::v8i16:
      MaxResBits = 32;
      break;
    case MVT::v2i32:
    case MVT::v4i32:
      MaxResBits = 64;
      break;
    default:
      // Not a legalized type covered by the cases above.
      break;
    }

    unsigned ResBits = DstVT.getSizeInBits();
    // NOTE(review): presumably a base cost of 2 for the first legalized part
    // plus 2 for every additional part - confirm the intended model.
    if (MaxResBits != 0 && ResBits <= MaxResBits)
      return (LT.first - 1) * 2 + 2;
  }

  return BaseT::getExtendedReductionCost(Opcode, IsUnsigned, ResTy, VecTy, FMF,
                                         CostKind);
}
4664+
4665+
/// Cost of a multiply-accumulate reduction over \p VecTy producing \p ResTy.
///
/// A target-specific cost is returned only when the subtarget has the dot
/// product feature and the reduction maps onto the i8 -> i32 dot-product
/// form; all other cases are forwarded to the generic implementation.
InstructionCost
AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
                                       VectorType *VecTy,
                                       TTI::TargetCostKind CostKind) {
  EVT SrcVT = TLI->getValueType(DL, VecTy);
  EVT DstVT = TLI->getValueType(DL, ResTy);

  const bool CanUseDot =
      ST->hasDotProd() && SrcVT.isSimple() && DstVT.isSimple();
  if (CanUseDot) {
    std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(VecTy);

    // The legal cases with dotprod are
    //   UDOT 8->32
    // Which requires an additional uaddv to sum the i32 values.
    const bool IsByteVector =
        LT.second == MVT::v8i8 || LT.second == MVT::v16i8;
    if (IsByteVector && DstVT == MVT::i32)
      return LT.first + 2;
  }

  return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, VecTy, CostKind);
}
4685+
46384686
InstructionCost AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index) {
46394687
static const CostTblEntry ShuffleTbl[] = {
46404688
{ TTI::SK_Splice, MVT::nxv16i8, 1 },

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

+9
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,15 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
425425
std::optional<FastMathFlags> FMF,
426426
TTI::TargetCostKind CostKind);
427427

428+
/// Cost of an extending reduction of \p VecTy into \p ResTy.
///
/// \param Opcode     Reduction opcode; only Instruction::Add gets a
///                   target-specific cost, other opcodes use the generic one.
/// \param IsUnsigned Forwarded to the generic implementation.
/// \param ResTy      Result type of the reduction.
/// \param VecTy      Vector type being reduced.
/// \param FMF        Forwarded to the generic implementation.
/// \param CostKind   Forwarded to the generic implementation.
InstructionCost getExtendedReductionCost(unsigned Opcode, bool IsUnsigned,
                                         Type *ResTy, VectorType *VecTy,
                                         FastMathFlags FMF,
                                         TTI::TargetCostKind CostKind);
432+
433+
/// Cost of a multiply-accumulate reduction of \p VecTy into \p ResTy.
///
/// \param IsUnsigned Forwarded to the generic implementation.
/// \param ResTy      Result type of the reduction.
/// \param VecTy      Vector type being reduced.
/// \param CostKind   Forwarded to the generic implementation; defaults to
///                   TTI::TCK_RecipThroughput.
InstructionCost getMulAccReductionCost(
    bool IsUnsigned, Type *ResTy, VectorType *VecTy,
    TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput);
436+
428437
InstructionCost getShuffleCost(TTI::ShuffleKind Kind, VectorType *Tp,
429438
ArrayRef<int> Mask,
430439
TTI::TargetCostKind CostKind, int Index,

llvm/test/Transforms/SLPVectorizer/AArch64/horizontal.ll

+1-1
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ for.end: ; preds = %for.end.loopexit, %
228228
; YAML-NEXT: Function: test_unrolled_select
229229
; YAML-NEXT: Args:
230230
; YAML-NEXT: - String: 'Vectorized horizontal reduction with cost '
231-
; YAML-NEXT: - Cost: '-41'
231+
; YAML-NEXT: - Cost: '-44'
232232
; YAML-NEXT: - String: ' and with tree size '
233233
; YAML-NEXT: - TreeSize: '10'
234234

0 commit comments

Comments
 (0)