@@ -7531,6 +7531,10 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
7531
7531
}
7532
7532
continue ;
7533
7533
}
7534
+ // The VPlan-based cost model is more accurate for partial reduction and
7535
+ // comparing against the legacy cost isn't desirable.
7536
+ if (isa<VPPartialReductionRecipe>(&R))
7537
+ return true ;
7534
7538
if (Instruction *UI = GetInstructionForCost (&R))
7535
7539
SeenInstrs.insert (UI);
7536
7540
}
@@ -8751,6 +8755,105 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
8751
8755
return Recipe;
8752
8756
}
8753
8757
8758
+ // / Find all possible partial reductions in the loop and track all of those that
8759
+ // / are valid so recipes can be formed later.
8760
+ void VPRecipeBuilder::collectScaledReductions (VFRange &Range) {
8761
+ // Find all possible partial reductions.
8762
+ SmallVector<std::pair<PartialReductionChain, unsigned >, 1 >
8763
+ PartialReductionChains;
8764
+ for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ())
8765
+ if (std::optional<std::pair<PartialReductionChain, unsigned >> Pair =
8766
+ getScaledReduction (Phi, RdxDesc, Range))
8767
+ PartialReductionChains.push_back (*Pair);
8768
+
8769
+ // A partial reduction is invalid if any of its extends are used by
8770
+ // something that isn't another partial reduction. This is because the
8771
+ // extends are intended to be lowered along with the reduction itself.
8772
+
8773
+ // Build up a set of partial reduction bin ops for efficient use checking.
8774
+ SmallSet<User *, 4 > PartialReductionBinOps;
8775
+ for (const auto &[PartialRdx, _] : PartialReductionChains)
8776
+ PartialReductionBinOps.insert (PartialRdx.BinOp );
8777
+
8778
+ auto ExtendIsOnlyUsedByPartialReductions =
8779
+ [&PartialReductionBinOps](Instruction *Extend) {
8780
+ return all_of (Extend->users (), [&](const User *U) {
8781
+ return PartialReductionBinOps.contains (U);
8782
+ });
8783
+ };
8784
+
8785
+ // Check if each use of a chain's two extends is a partial reduction
8786
+ // and only add those that don't have non-partial reduction users.
8787
+ for (auto Pair : PartialReductionChains) {
8788
+ PartialReductionChain Chain = Pair.first ;
8789
+ if (ExtendIsOnlyUsedByPartialReductions (Chain.ExtendA ) &&
8790
+ ExtendIsOnlyUsedByPartialReductions (Chain.ExtendB ))
8791
+ ScaledReductionExitInstrs.insert (std::make_pair (Chain.Reduction , Pair));
8792
+ }
8793
+ }
8794
+
8795
+ std::optional<std::pair<PartialReductionChain, unsigned >>
8796
+ VPRecipeBuilder::getScaledReduction (PHINode *PHI,
8797
+ const RecurrenceDescriptor &Rdx,
8798
+ VFRange &Range) {
8799
+ // TODO: Allow scaling reductions when predicating. The select at
8800
+ // the end of the loop chooses between the phi value and most recent
8801
+ // reduction result, both of which have different VFs to the active lane
8802
+ // mask when scaling.
8803
+ if (CM.blockNeedsPredicationForAnyReason (Rdx.getLoopExitInstr ()->getParent ()))
8804
+ return std::nullopt;
8805
+
8806
+ auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr ());
8807
+ if (!Update)
8808
+ return std::nullopt;
8809
+
8810
+ Value *Op = Update->getOperand (0 );
8811
+ Value *PhiOp = Update->getOperand (1 );
8812
+ if (Op == PHI) {
8813
+ Op = Update->getOperand (1 );
8814
+ PhiOp = Update->getOperand (0 );
8815
+ }
8816
+ if (PhiOp != PHI)
8817
+ return std::nullopt;
8818
+
8819
+ auto *BinOp = dyn_cast<BinaryOperator>(Op);
8820
+ if (!BinOp || !BinOp->hasOneUse ())
8821
+ return std::nullopt;
8822
+
8823
+ using namespace llvm ::PatternMatch;
8824
+ Value *A, *B;
8825
+ if (!match (BinOp->getOperand (0 ), m_ZExtOrSExt (m_Value (A))) ||
8826
+ !match (BinOp->getOperand (1 ), m_ZExtOrSExt (m_Value (B))))
8827
+ return std::nullopt;
8828
+
8829
+ Instruction *ExtA = cast<Instruction>(BinOp->getOperand (0 ));
8830
+ Instruction *ExtB = cast<Instruction>(BinOp->getOperand (1 ));
8831
+
8832
+ TTI::PartialReductionExtendKind OpAExtend =
8833
+ TargetTransformInfo::getPartialReductionExtendKind (ExtA);
8834
+ TTI::PartialReductionExtendKind OpBExtend =
8835
+ TargetTransformInfo::getPartialReductionExtendKind (ExtB);
8836
+
8837
+ PartialReductionChain Chain (Rdx.getLoopExitInstr (), ExtA, ExtB, BinOp);
8838
+
8839
+ unsigned TargetScaleFactor =
8840
+ PHI->getType ()->getPrimitiveSizeInBits ().getKnownScalarFactor (
8841
+ A->getType ()->getPrimitiveSizeInBits ());
8842
+
8843
+ if (LoopVectorizationPlanner::getDecisionAndClampRange (
8844
+ [&](ElementCount VF) {
8845
+ InstructionCost Cost = TTI->getPartialReductionCost (
8846
+ Update->getOpcode (), A->getType (), B->getType (), PHI->getType (),
8847
+ VF, OpAExtend, OpBExtend,
8848
+ std::make_optional (BinOp->getOpcode ()));
8849
+ return Cost.isValid ();
8850
+ },
8851
+ Range))
8852
+ return std::make_pair (Chain, TargetScaleFactor);
8853
+
8854
+ return std::nullopt;
8855
+ }
8856
+
8754
8857
VPRecipeBase *
8755
8858
VPRecipeBuilder::tryToCreateWidenRecipe (Instruction *Instr,
8756
8859
ArrayRef<VPValue *> Operands,
@@ -8775,9 +8878,14 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
8775
8878
Legal->getReductionVars ().find (Phi)->second ;
8776
8879
assert (RdxDesc.getRecurrenceStartValue () ==
8777
8880
Phi->getIncomingValueForBlock (OrigLoop->getLoopPreheader ()));
8778
- PhiRecipe = new VPReductionPHIRecipe (Phi, RdxDesc, *StartV,
8779
- CM.isInLoopReduction (Phi),
8780
- CM.useOrderedReductions (RdxDesc));
8881
+
8882
+ // If the PHI is used by a partial reduction, set the scale factor.
8883
+ std::optional<std::pair<PartialReductionChain, unsigned >> Pair =
8884
+ getScaledReductionForInstr (RdxDesc.getLoopExitInstr ());
8885
+ unsigned ScaleFactor = Pair ? Pair->second : 1 ;
8886
+ PhiRecipe = new VPReductionPHIRecipe (
8887
+ Phi, RdxDesc, *StartV, CM.isInLoopReduction (Phi),
8888
+ CM.useOrderedReductions (RdxDesc), ScaleFactor);
8781
8889
} else {
8782
8890
// TODO: Currently fixed-order recurrences are modeled as chains of
8783
8891
// first-order recurrences. If there are no users of the intermediate
@@ -8809,6 +8917,9 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
8809
8917
if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr))
8810
8918
return tryToWidenMemory (Instr, Operands, Range);
8811
8919
8920
+ if (getScaledReductionForInstr (Instr))
8921
+ return tryToCreatePartialReduction (Instr, Operands);
8922
+
8812
8923
if (!shouldWiden (Instr, Range))
8813
8924
return nullptr ;
8814
8925
@@ -8829,6 +8940,21 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
8829
8940
return tryToWiden (Instr, Operands, VPBB);
8830
8941
}
8831
8942
8943
+ VPRecipeBase *
8944
+ VPRecipeBuilder::tryToCreatePartialReduction (Instruction *Reduction,
8945
+ ArrayRef<VPValue *> Operands) {
8946
+ assert (Operands.size () == 2 &&
8947
+ " Unexpected number of operands for partial reduction" );
8948
+
8949
+ VPValue *BinOp = Operands[0 ];
8950
+ VPValue *Phi = Operands[1 ];
8951
+ if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe ()))
8952
+ std::swap (BinOp, Phi);
8953
+
8954
+ return new VPPartialReductionRecipe (Reduction->getOpcode (), BinOp, Phi,
8955
+ Reduction);
8956
+ }
8957
+
8832
8958
void LoopVectorizationPlanner::buildVPlansWithVPRecipes (ElementCount MinVF,
8833
8959
ElementCount MaxVF) {
8834
8960
assert (OrigLoop->isInnermost () && " Inner loop expected." );
@@ -9252,7 +9378,8 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
9252
9378
bool HasNUW = !IVUpdateMayOverflow || Style == TailFoldingStyle::None;
9253
9379
addCanonicalIVRecipes (*Plan, Legal->getWidestInductionType (), HasNUW, DL);
9254
9380
9255
- VPRecipeBuilder RecipeBuilder (*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder);
9381
+ VPRecipeBuilder RecipeBuilder (*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
9382
+ Builder);
9256
9383
9257
9384
// ---------------------------------------------------------------------------
9258
9385
// Pre-construction: record ingredients whose recipes we'll need to further
@@ -9298,6 +9425,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
9298
9425
bool NeedsBlends = BB != HeaderBB && !BB->phis ().empty ();
9299
9426
return Legal->blockNeedsPredication (BB) || NeedsBlends;
9300
9427
});
9428
+
9429
+ RecipeBuilder.collectScaledReductions (Range);
9430
+
9301
9431
auto *MiddleVPBB = Plan->getMiddleBlock ();
9302
9432
VPBasicBlock::iterator MBIP = MiddleVPBB->getFirstNonPhi ();
9303
9433
for (BasicBlock *BB : make_range (DFS.beginRPO (), DFS.endRPO ())) {
@@ -9521,7 +9651,8 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
9521
9651
9522
9652
// Collect mapping of IR header phis to header phi recipes, to be used in
9523
9653
// addScalarResumePhis.
9524
- VPRecipeBuilder RecipeBuilder (*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder);
9654
+ VPRecipeBuilder RecipeBuilder (*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
9655
+ Builder);
9525
9656
for (auto &R : Plan->getVectorLoopRegion ()->getEntryBasicBlock ()->phis ()) {
9526
9657
if (isa<VPCanonicalIVPHIRecipe>(&R))
9527
9658
continue ;
0 commit comments