@@ -4635,6 +4635,54 @@ AArch64TTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
4635
4635
return BaseT::getArithmeticReductionCost (Opcode, ValTy, FMF, CostKind);
4636
4636
}
4637
4637
4638
+ InstructionCost AArch64TTIImpl::getExtendedReductionCost (
4639
+ unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *VecTy,
4640
+ FastMathFlags FMF, TTI::TargetCostKind CostKind) {
4641
+ EVT VecVT = TLI->getValueType (DL, VecTy);
4642
+ EVT ResVT = TLI->getValueType (DL, ResTy);
4643
+
4644
+ if (Opcode == Instruction::Add && VecVT.isSimple () && ResVT.isSimple () &&
4645
+ VecVT.getSizeInBits () >= 64 ) {
4646
+ std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost (VecTy);
4647
+
4648
+ // The legal cases are:
4649
+ // UADDLV 8/16/32->32
4650
+ // UADDLP 32->64
4651
+ unsigned RevVTSize = ResVT.getSizeInBits ();
4652
+ if (((LT.second == MVT::v8i8 || LT.second == MVT::v16i8) &&
4653
+ RevVTSize <= 32 ) ||
4654
+ ((LT.second == MVT::v4i16 || LT.second == MVT::v8i16) &&
4655
+ RevVTSize <= 32 ) ||
4656
+ ((LT.second == MVT::v2i32 || LT.second == MVT::v4i32) &&
4657
+ RevVTSize <= 64 ))
4658
+ return (LT.first - 1 ) * 2 + 2 ;
4659
+ }
4660
+
4661
+ return BaseT::getExtendedReductionCost (Opcode, IsUnsigned, ResTy, VecTy, FMF,
4662
+ CostKind);
4663
+ }
4664
+
4665
+ InstructionCost
4666
+ AArch64TTIImpl::getMulAccReductionCost (bool IsUnsigned, Type *ResTy,
4667
+ VectorType *VecTy,
4668
+ TTI::TargetCostKind CostKind) {
4669
+ EVT VecVT = TLI->getValueType (DL, VecTy);
4670
+ EVT ResVT = TLI->getValueType (DL, ResTy);
4671
+
4672
+ if (ST->hasDotProd () && VecVT.isSimple () && ResVT.isSimple ()) {
4673
+ std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost (VecTy);
4674
+
4675
+ // The legal cases with dotprod are
4676
+ // UDOT 8->32
4677
+ // Which requires an additional uaddv to sum the i32 values.
4678
+ if ((LT.second == MVT::v8i8 || LT.second == MVT::v16i8) &&
4679
+ ResVT == MVT::i32)
4680
+ return LT.first + 2 ;
4681
+ }
4682
+
4683
+ return BaseT::getMulAccReductionCost (IsUnsigned, ResTy, VecTy, CostKind);
4684
+ }
4685
+
4638
4686
InstructionCost AArch64TTIImpl::getSpliceCost (VectorType *Tp, int Index) {
4639
4687
static const CostTblEntry ShuffleTbl[] = {
4640
4688
{ TTI::SK_Splice, MVT::nxv16i8, 1 },
0 commit comments