@@ -114,6 +114,7 @@ class VectorCombine {
114114 bool scalarizeBinopOrCmp (Instruction &I);
115115 bool scalarizeVPIntrinsic (Instruction &I);
116116 bool foldExtractedCmps (Instruction &I);
117+ bool foldBinopOfReductions (Instruction &I);
117118 bool foldSingleElementStore (Instruction &I);
118119 bool scalarizeLoadExtract (Instruction &I);
119120 bool foldConcatOfBoolMasks (Instruction &I);
@@ -1242,6 +1243,121 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) {
12421243 return true ;
12431244}
12441245
1246+ static void analyzeCostOfVecReduction (const IntrinsicInst &II,
1247+ TTI::TargetCostKind CostKind,
1248+ const TargetTransformInfo &TTI,
1249+ InstructionCost &CostBeforeReduction,
1250+ InstructionCost &CostAfterReduction) {
1251+ Instruction *Op0, *Op1;
1252+ auto *RedOp = dyn_cast<Instruction>(II.getOperand (0 ));
1253+ auto *VecRedTy = cast<VectorType>(II.getOperand (0 )->getType ());
1254+ unsigned ReductionOpc =
1255+ getArithmeticReductionInstruction (II.getIntrinsicID ());
1256+ if (RedOp && match (RedOp, m_ZExtOrSExt (m_Value ()))) {
1257+ bool IsUnsigned = isa<ZExtInst>(RedOp);
1258+ auto *ExtType = cast<VectorType>(RedOp->getOperand (0 )->getType ());
1259+
1260+ CostBeforeReduction =
1261+ TTI.getCastInstrCost (RedOp->getOpcode (), VecRedTy, ExtType,
1262+ TTI::CastContextHint::None, CostKind, RedOp);
1263+ CostAfterReduction =
1264+ TTI.getExtendedReductionCost (ReductionOpc, IsUnsigned, II.getType (),
1265+ ExtType, FastMathFlags (), CostKind);
1266+ return ;
1267+ }
1268+ if (RedOp && II.getIntrinsicID () == Intrinsic::vector_reduce_add &&
1269+ match (RedOp,
1270+ m_ZExtOrSExt (m_Mul (m_Instruction (Op0), m_Instruction (Op1)))) &&
1271+ match (Op0, m_ZExtOrSExt (m_Value ())) &&
1272+ Op0->getOpcode () == Op1->getOpcode () &&
1273+ Op0->getOperand (0 )->getType () == Op1->getOperand (0 )->getType () &&
1274+ (Op0->getOpcode () == RedOp->getOpcode () || Op0 == Op1)) {
1275+ // Matched reduce.add(ext(mul(ext(A), ext(B)))
1276+ bool IsUnsigned = isa<ZExtInst>(Op0);
1277+ auto *ExtType = cast<VectorType>(Op0->getOperand (0 )->getType ());
1278+ VectorType *MulType = VectorType::get (Op0->getType (), VecRedTy);
1279+
1280+ InstructionCost ExtCost =
1281+ TTI.getCastInstrCost (Op0->getOpcode (), MulType, ExtType,
1282+ TTI::CastContextHint::None, CostKind, Op0);
1283+ InstructionCost MulCost =
1284+ TTI.getArithmeticInstrCost (Instruction::Mul, MulType, CostKind);
1285+ InstructionCost Ext2Cost =
1286+ TTI.getCastInstrCost (RedOp->getOpcode (), VecRedTy, MulType,
1287+ TTI::CastContextHint::None, CostKind, RedOp);
1288+
1289+ CostBeforeReduction = ExtCost * 2 + MulCost + Ext2Cost;
1290+ CostAfterReduction =
1291+ TTI.getMulAccReductionCost (IsUnsigned, II.getType (), ExtType, CostKind);
1292+ return ;
1293+ }
1294+ CostAfterReduction = TTI.getArithmeticReductionCost (ReductionOpc, VecRedTy,
1295+ std::nullopt , CostKind);
1296+ return ;
1297+ }
1298+
1299+ bool VectorCombine::foldBinopOfReductions (Instruction &I) {
1300+ Instruction::BinaryOps BinOpOpc = cast<BinaryOperator>(&I)->getOpcode ();
1301+ Intrinsic::ID ReductionIID = getReductionForBinop (BinOpOpc);
1302+ if (BinOpOpc == Instruction::Sub)
1303+ ReductionIID = Intrinsic::vector_reduce_add;
1304+ if (ReductionIID == Intrinsic::not_intrinsic)
1305+ return false ;
1306+
1307+ auto checkIntrinsicAndGetItsArgument = [](Value *V,
1308+ Intrinsic::ID IID) -> Value * {
1309+ auto *II = dyn_cast<IntrinsicInst>(V);
1310+ if (!II)
1311+ return nullptr ;
1312+ if (II->getIntrinsicID () == IID && II->hasOneUse ())
1313+ return II->getArgOperand (0 );
1314+ return nullptr ;
1315+ };
1316+
1317+ Value *V0 = checkIntrinsicAndGetItsArgument (I.getOperand (0 ), ReductionIID);
1318+ if (!V0)
1319+ return false ;
1320+ Value *V1 = checkIntrinsicAndGetItsArgument (I.getOperand (1 ), ReductionIID);
1321+ if (!V1)
1322+ return false ;
1323+
1324+ auto *VTy = cast<VectorType>(V0->getType ());
1325+ if (V1->getType () != VTy)
1326+ return false ;
1327+ const auto &II0 = *cast<IntrinsicInst>(I.getOperand (0 ));
1328+ const auto &II1 = *cast<IntrinsicInst>(I.getOperand (1 ));
1329+ unsigned ReductionOpc =
1330+ getArithmeticReductionInstruction (II0.getIntrinsicID ());
1331+
1332+ InstructionCost OldCost = 0 ;
1333+ InstructionCost NewCost = 0 ;
1334+ InstructionCost CostOfRedOperand0 = 0 ;
1335+ InstructionCost CostOfRed0 = 0 ;
1336+ InstructionCost CostOfRedOperand1 = 0 ;
1337+ InstructionCost CostOfRed1 = 0 ;
1338+ analyzeCostOfVecReduction (II0, CostKind, TTI, CostOfRedOperand0, CostOfRed0);
1339+ analyzeCostOfVecReduction (II1, CostKind, TTI, CostOfRedOperand1, CostOfRed1);
1340+ OldCost = CostOfRed0 + CostOfRed1 + TTI.getInstructionCost (&I, CostKind);
1341+ NewCost =
1342+ CostOfRedOperand0 + CostOfRedOperand1 +
1343+ TTI.getArithmeticInstrCost (BinOpOpc, VTy, CostKind) +
1344+ TTI.getArithmeticReductionCost (ReductionOpc, VTy, std::nullopt , CostKind);
1345+ if (NewCost >= OldCost || !NewCost.isValid ())
1346+ return false ;
1347+
1348+ LLVM_DEBUG (dbgs () << " Found two mergeable reductions: " << I
1349+ << " \n OldCost: " << OldCost << " vs NewCost: " << NewCost
1350+ << " \n " );
1351+ Value *VectorBO = Builder.CreateBinOp (BinOpOpc, V0, V1);
1352+ if (auto *PDInst = dyn_cast<PossiblyDisjointInst>(&I))
1353+ if (auto *PDVectorBO = dyn_cast<PossiblyDisjointInst>(VectorBO))
1354+ PDVectorBO->setIsDisjoint (PDInst->isDisjoint ());
1355+
1356+ Instruction *Rdx = Builder.CreateIntrinsic (ReductionIID, {VTy}, {VectorBO});
1357+ replaceValue (I, *Rdx);
1358+ return true ;
1359+ }
1360+
12451361// Check if memory loc modified between two instrs in the same BB
12461362static bool isMemModifiedBetween (BasicBlock::iterator Begin,
12471363 BasicBlock::iterator End,
@@ -3380,6 +3496,7 @@ bool VectorCombine::run() {
33803496 if (Instruction::isBinaryOp (Opcode)) {
33813497 MadeChange |= foldExtractExtract (I);
33823498 MadeChange |= foldExtractedCmps (I);
3499+ MadeChange |= foldBinopOfReductions (I);
33833500 }
33843501 break ;
33853502 }
0 commit comments