@@ -7531,6 +7531,10 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
7531
7531
}
7532
7532
continue;
7533
7533
}
7534
+ // The VPlan-based cost model is more accurate for partial reduction and
7535
+ // comparing against the legacy cost isn't desirable.
7536
+ if (isa<VPPartialReductionRecipe>(&R))
7537
+ return true;
7534
7538
if (Instruction *UI = GetInstructionForCost(&R))
7535
7539
SeenInstrs.insert(UI);
7536
7540
}
@@ -8751,6 +8755,105 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
8751
8755
return Recipe;
8752
8756
}
8753
8757
8758
+ /// Find all possible partial reductions in the loop and track all of those that
8759
+ /// are valid so recipes can be formed later.
8760
+ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
8761
+ // Find all possible partial reductions.
8762
+ SmallVector<std::pair<PartialReductionChain, unsigned>, 1>
8763
+ PartialReductionChains;
8764
+ for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
8765
+ if (std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
8766
+ getScaledReduction(Phi, RdxDesc, Range))
8767
+ PartialReductionChains.push_back(*Pair);
8768
+
8769
+ // A partial reduction is invalid if any of its extends are used by
8770
+ // something that isn't another partial reduction. This is because the
8771
+ // extends are intended to be lowered along with the reduction itself.
8772
+
8773
+ // Build up a set of partial reduction bin ops for efficient use checking.
8774
+ SmallSet<User *, 4> PartialReductionBinOps;
8775
+ for (const auto &[PartialRdx, _] : PartialReductionChains)
8776
+ PartialReductionBinOps.insert(PartialRdx.BinOp);
8777
+
8778
+ auto ExtendIsOnlyUsedByPartialReductions =
8779
+ [&PartialReductionBinOps](Instruction *Extend) {
8780
+ return all_of(Extend->users(), [&](const User *U) {
8781
+ return PartialReductionBinOps.contains(U);
8782
+ });
8783
+ };
8784
+
8785
+ // Check if each use of a chain's two extends is a partial reduction
8786
+ // and only add those that don't have non-partial reduction users.
8787
+ for (auto Pair : PartialReductionChains) {
8788
+ PartialReductionChain Chain = Pair.first;
8789
+ if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
8790
+ ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB))
8791
+ ScaledReductionExitInstrs.insert(std::make_pair(Chain.Reduction, Pair));
8792
+ }
8793
+ }
8794
+
8795
+ std::optional<std::pair<PartialReductionChain, unsigned>>
8796
+ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8797
+ const RecurrenceDescriptor &Rdx,
8798
+ VFRange &Range) {
8799
+ // TODO: Allow scaling reductions when predicating. The select at
8800
+ // the end of the loop chooses between the phi value and most recent
8801
+ // reduction result, both of which have different VFs to the active lane
8802
+ // mask when scaling.
8803
+ if (CM.blockNeedsPredicationForAnyReason(Rdx.getLoopExitInstr()->getParent()))
8804
+ return std::nullopt;
8805
+
8806
+ auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr());
8807
+ if (!Update)
8808
+ return std::nullopt;
8809
+
8810
+ Value *Op = Update->getOperand(0);
8811
+ Value *PhiOp = Update->getOperand(1);
8812
+ if (Op == PHI) {
8813
+ Op = Update->getOperand(1);
8814
+ PhiOp = Update->getOperand(0);
8815
+ }
8816
+ if (PhiOp != PHI)
8817
+ return std::nullopt;
8818
+
8819
+ auto *BinOp = dyn_cast<BinaryOperator>(Op);
8820
+ if (!BinOp || !BinOp->hasOneUse())
8821
+ return std::nullopt;
8822
+
8823
+ using namespace llvm::PatternMatch;
8824
+ Value *A, *B;
8825
+ if (!match(BinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
8826
+ !match(BinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
8827
+ return std::nullopt;
8828
+
8829
+ Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
8830
+ Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));
8831
+
8832
+ TTI::PartialReductionExtendKind OpAExtend =
8833
+ TargetTransformInfo::getPartialReductionExtendKind(ExtA);
8834
+ TTI::PartialReductionExtendKind OpBExtend =
8835
+ TargetTransformInfo::getPartialReductionExtendKind(ExtB);
8836
+
8837
+ PartialReductionChain Chain(Rdx.getLoopExitInstr(), ExtA, ExtB, BinOp);
8838
+
8839
+ unsigned TargetScaleFactor =
8840
+ PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
8841
+ A->getType()->getPrimitiveSizeInBits());
8842
+
8843
+ if (LoopVectorizationPlanner::getDecisionAndClampRange(
8844
+ [&](ElementCount VF) {
8845
+ InstructionCost Cost = TTI->getPartialReductionCost(
8846
+ Update->getOpcode(), A->getType(), B->getType(), PHI->getType(),
8847
+ VF, OpAExtend, OpBExtend,
8848
+ std::make_optional(BinOp->getOpcode()));
8849
+ return Cost.isValid();
8850
+ },
8851
+ Range))
8852
+ return std::make_pair(Chain, TargetScaleFactor);
8853
+
8854
+ return std::nullopt;
8855
+ }
8856
+
8754
8857
VPRecipeBase *
8755
8858
VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
8756
8859
ArrayRef<VPValue *> Operands,
@@ -8775,9 +8878,14 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
8775
8878
Legal->getReductionVars().find(Phi)->second;
8776
8879
assert(RdxDesc.getRecurrenceStartValue() ==
8777
8880
Phi->getIncomingValueForBlock(OrigLoop->getLoopPreheader()));
8778
- PhiRecipe = new VPReductionPHIRecipe(Phi, RdxDesc, *StartV,
8779
- CM.isInLoopReduction(Phi),
8780
- CM.useOrderedReductions(RdxDesc));
8881
+
8882
+ // If the PHI is used by a partial reduction, set the scale factor.
8883
+ std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
8884
+ getScaledReductionForInstr(RdxDesc.getLoopExitInstr());
8885
+ unsigned ScaleFactor = Pair ? Pair->second : 1;
8886
+ PhiRecipe = new VPReductionPHIRecipe(
8887
+ Phi, RdxDesc, *StartV, CM.isInLoopReduction(Phi),
8888
+ CM.useOrderedReductions(RdxDesc), ScaleFactor);
8781
8889
} else {
8782
8890
// TODO: Currently fixed-order recurrences are modeled as chains of
8783
8891
// first-order recurrences. If there are no users of the intermediate
@@ -8809,6 +8917,9 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
8809
8917
if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr))
8810
8918
return tryToWidenMemory(Instr, Operands, Range);
8811
8919
8920
+ if (getScaledReductionForInstr(Instr))
8921
+ return tryToCreatePartialReduction(Instr, Operands);
8922
+
8812
8923
if (!shouldWiden(Instr, Range))
8813
8924
return nullptr;
8814
8925
@@ -8829,6 +8940,21 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
8829
8940
return tryToWiden(Instr, Operands, VPBB);
8830
8941
}
8831
8942
8943
+ VPRecipeBase *
8944
+ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
8945
+ ArrayRef<VPValue *> Operands) {
8946
+ assert(Operands.size() == 2 &&
8947
+ "Unexpected number of operands for partial reduction");
8948
+
8949
+ VPValue *BinOp = Operands[0];
8950
+ VPValue *Phi = Operands[1];
8951
+ if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe()))
8952
+ std::swap(BinOp, Phi);
8953
+
8954
+ return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp, Phi,
8955
+ Reduction);
8956
+ }
8957
+
8832
8958
void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
8833
8959
ElementCount MaxVF) {
8834
8960
assert(OrigLoop->isInnermost() && "Inner loop expected.");
@@ -9252,7 +9378,8 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
9252
9378
bool HasNUW = !IVUpdateMayOverflow || Style == TailFoldingStyle::None;
9253
9379
addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), HasNUW, DL);
9254
9380
9255
- VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder);
9381
+ VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
9382
+ Builder);
9256
9383
9257
9384
// ---------------------------------------------------------------------------
9258
9385
// Pre-construction: record ingredients whose recipes we'll need to further
@@ -9298,6 +9425,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
9298
9425
bool NeedsBlends = BB != HeaderBB && !BB->phis().empty();
9299
9426
return Legal->blockNeedsPredication(BB) || NeedsBlends;
9300
9427
});
9428
+
9429
+ RecipeBuilder.collectScaledReductions(Range);
9430
+
9301
9431
auto *MiddleVPBB = Plan->getMiddleBlock();
9302
9432
VPBasicBlock::iterator MBIP = MiddleVPBB->getFirstNonPhi();
9303
9433
for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) {
@@ -9521,7 +9651,8 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
9521
9651
9522
9652
// Collect mapping of IR header phis to header phi recipes, to be used in
9523
9653
// addScalarResumePhis.
9524
- VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder);
9654
+ VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
9655
+ Builder);
9525
9656
for (auto &R : Plan->getVectorLoopRegion()->getEntryBasicBlock()->phis()) {
9526
9657
if (isa<VPCanonicalIVPHIRecipe>(&R))
9527
9658
continue;
0 commit comments