@@ -8268,6 +8268,105 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
8268
8268
return Recipe;
8269
8269
}
8270
8270
8271
+ /// Find all possible partial reductions in the loop and track all of those that
8272
+ /// are valid so recipes can be formed later.
8273
+ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
8274
+ // Find all possible partial reductions.
8275
+ SmallVector<std::pair<PartialReductionChain, unsigned>, 1>
8276
+ PartialReductionChains;
8277
+ for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
8278
+ if (std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
8279
+ getScaledReduction(Phi, RdxDesc, Range))
8280
+ PartialReductionChains.push_back(*Pair);
8281
+
8282
+ // A partial reduction is invalid if any of its extends are used by
8283
+ // something that isn't another partial reduction. This is because the
8284
+ // extends are intended to be lowered along with the reduction itself.
8285
+
8286
+ // Build up a set of partial reduction bin ops for efficient use checking.
8287
+ SmallSet<User *, 4> PartialReductionBinOps;
8288
+ for (const auto &[PartialRdx, _] : PartialReductionChains)
8289
+ PartialReductionBinOps.insert(PartialRdx.BinOp);
8290
+
8291
+ auto ExtendIsOnlyUsedByPartialReductions =
8292
+ [&PartialReductionBinOps](Instruction *Extend) {
8293
+ return all_of(Extend->users(), [&](const User *U) {
8294
+ return PartialReductionBinOps.contains(U);
8295
+ });
8296
+ };
8297
+
8298
+ // Check if each use of a chain's two extends is a partial reduction
8299
+ // and only add those that don't have non-partial reduction users.
8300
+ for (auto Pair : PartialReductionChains) {
8301
+ PartialReductionChain Chain = Pair.first;
8302
+ if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
8303
+ ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB))
8304
+ ScaledReductionExitInstrs.insert(std::make_pair(Chain.Reduction, Pair));
8305
+ }
8306
+ }
8307
+
8308
+ std::optional<std::pair<PartialReductionChain, unsigned>>
8309
+ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8310
+ const RecurrenceDescriptor &Rdx,
8311
+ VFRange &Range) {
8312
+ // TODO: Allow scaling reductions when predicating. The select at
8313
+ // the end of the loop chooses between the phi value and most recent
8314
+ // reduction result, both of which have different VFs to the active lane
8315
+ // mask when scaling.
8316
+ if (CM.blockNeedsPredicationForAnyReason(Rdx.getLoopExitInstr()->getParent()))
8317
+ return std::nullopt;
8318
+
8319
+ auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr());
8320
+ if (!Update)
8321
+ return std::nullopt;
8322
+
8323
+ Value *Op = Update->getOperand(0);
8324
+ Value *PhiOp = Update->getOperand(1);
8325
+ if (Op == PHI) {
8326
+ Op = Update->getOperand(1);
8327
+ PhiOp = Update->getOperand(0);
8328
+ }
8329
+ if (PhiOp != PHI)
8330
+ return std::nullopt;
8331
+
8332
+ auto *BinOp = dyn_cast<BinaryOperator>(Op);
8333
+ if (!BinOp || !BinOp->hasOneUse())
8334
+ return std::nullopt;
8335
+
8336
+ using namespace llvm::PatternMatch;
8337
+ Value *A, *B;
8338
+ if (!match(BinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
8339
+ !match(BinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
8340
+ return std::nullopt;
8341
+
8342
+ Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
8343
+ Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));
8344
+
8345
+ TTI::PartialReductionExtendKind OpAExtend =
8346
+ TargetTransformInfo::getPartialReductionExtendKind(ExtA);
8347
+ TTI::PartialReductionExtendKind OpBExtend =
8348
+ TargetTransformInfo::getPartialReductionExtendKind(ExtB);
8349
+
8350
+ PartialReductionChain Chain(Rdx.getLoopExitInstr(), ExtA, ExtB, BinOp);
8351
+
8352
+ unsigned TargetScaleFactor =
8353
+ PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
8354
+ A->getType()->getPrimitiveSizeInBits());
8355
+
8356
+ if (LoopVectorizationPlanner::getDecisionAndClampRange(
8357
+ [&](ElementCount VF) {
8358
+ InstructionCost Cost = TTI->getPartialReductionCost(
8359
+ Update->getOpcode(), A->getType(), B->getType(), PHI->getType(),
8360
+ VF, OpAExtend, OpBExtend,
8361
+ std::make_optional(BinOp->getOpcode()));
8362
+ return Cost.isValid();
8363
+ },
8364
+ Range))
8365
+ return std::make_pair(Chain, TargetScaleFactor);
8366
+
8367
+ return std::nullopt;
8368
+ }
8369
+
8271
8370
VPRecipeBase *
8272
8371
VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
8273
8372
ArrayRef<VPValue *> Operands,
@@ -8292,9 +8391,14 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
8292
8391
Legal->getReductionVars().find(Phi)->second;
8293
8392
assert(RdxDesc.getRecurrenceStartValue() ==
8294
8393
Phi->getIncomingValueForBlock(OrigLoop->getLoopPreheader()));
8295
- PhiRecipe = new VPReductionPHIRecipe(Phi, RdxDesc, *StartV,
8296
- CM.isInLoopReduction(Phi),
8297
- CM.useOrderedReductions(RdxDesc));
8394
+
8395
+ // If the PHI is used by a partial reduction, set the scale factor.
8396
+ std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
8397
+ getScaledReductionForInstr(RdxDesc.getLoopExitInstr());
8398
+ unsigned ScaleFactor = Pair ? Pair->second : 1;
8399
+ PhiRecipe = new VPReductionPHIRecipe(
8400
+ Phi, RdxDesc, *StartV, CM.isInLoopReduction(Phi),
8401
+ CM.useOrderedReductions(RdxDesc), ScaleFactor);
8298
8402
} else {
8299
8403
// TODO: Currently fixed-order recurrences are modeled as chains of
8300
8404
// first-order recurrences. If there are no users of the intermediate
@@ -8322,6 +8426,9 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
8322
8426
if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr))
8323
8427
return tryToWidenMemory(Instr, Operands, Range);
8324
8428
8429
+ if (getScaledReductionForInstr(Instr))
8430
+ return tryToCreatePartialReduction(Instr, Operands);
8431
+
8325
8432
if (!shouldWiden(Instr, Range))
8326
8433
return nullptr;
8327
8434
@@ -8342,6 +8449,21 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
8342
8449
return tryToWiden(Instr, Operands, VPBB);
8343
8450
}
8344
8451
8452
+ VPRecipeBase *
8453
+ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
8454
+ ArrayRef<VPValue *> Operands) {
8455
+ assert(Operands.size() == 2 &&
8456
+ "Unexpected number of operands for partial reduction");
8457
+
8458
+ VPValue *BinOp = Operands[0];
8459
+ VPValue *Phi = Operands[1];
8460
+ if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe()))
8461
+ std::swap(BinOp, Phi);
8462
+
8463
+ return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp, Phi,
8464
+ Reduction);
8465
+ }
8466
+
8345
8467
void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
8346
8468
ElementCount MaxVF) {
8347
8469
assert(OrigLoop->isInnermost() && "Inner loop expected.");
@@ -8514,7 +8636,8 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
8514
8636
bool HasNUW = Style == TailFoldingStyle::None;
8515
8637
addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), HasNUW, DL);
8516
8638
8517
- VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder);
8639
+ VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
8640
+ Builder);
8518
8641
8519
8642
// ---------------------------------------------------------------------------
8520
8643
// Pre-construction: record ingredients whose recipes we'll need to further
@@ -8560,6 +8683,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
8560
8683
bool NeedsBlends = BB != HeaderBB && !BB->phis().empty();
8561
8684
return Legal->blockNeedsPredication(BB) || NeedsBlends;
8562
8685
});
8686
+
8687
+ RecipeBuilder.collectScaledReductions(Range);
8688
+
8563
8689
for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) {
8564
8690
// Relevant instructions from basic block BB will be grouped into VPRecipe
8565
8691
// ingredients and fill a new VPBasicBlock.
@@ -8770,7 +8896,7 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
8770
8896
bool HasNUW = true;
8771
8897
addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), HasNUW,
8772
8898
DebugLoc());
8773
- assert(verifyVPlanIsValid(*Plan) && "VPlan is invalid");
8899
+ assert(verifyVPlanIsValid(*Plan) && "VPlan is invalid");
8774
8900
return Plan;
8775
8901
}
8776
8902
0 commit comments