Skip to content

Commit 795e35a

Browse files
Reland "[LoopVectorizer] Add support for partial reductions" with non-phi operand fix. (llvm#121744)
This relands the reverted llvm#120721 with a fix for cases where neither reduction operand is the reduction phi. Only 6311423 and 6311423 are new on top of the reverted PR. --------- Co-authored-by: Nicholas Guy <nicholas.guy@arm.com>
1 parent 171d3ed commit 795e35a

17 files changed

+4588
-31
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

+44
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,12 @@ typedef TargetTransformInfo TTI;
211211
/// for IR-level transformations.
212212
class TargetTransformInfo {
213213
public:
214+
enum PartialReductionExtendKind { PR_None, PR_SignExtend, PR_ZeroExtend };
215+
216+
/// Get the kind of extension that an instruction represents.
217+
static PartialReductionExtendKind
218+
getPartialReductionExtendKind(Instruction *I);
219+
214220
/// Construct a TTI object using a type implementing the \c Concept
215221
/// API below.
216222
///
@@ -1280,6 +1286,20 @@ class TargetTransformInfo {
12801286
/// \return if target want to issue a prefetch in address space \p AS.
12811287
bool shouldPrefetchAddressSpace(unsigned AS) const;
12821288

1289+
/// \return The cost of a partial reduction, which is a reduction from a
1290+
/// vector to another vector with fewer elements of larger size. They are
1291+
/// represented by the llvm.experimental.partial.reduce.add intrinsic, which
1292+
/// takes an accumulator and a binary operation operand that itself is fed by
1293+
/// two extends. An example of an operation that uses a partial reduction is a
1294+
/// dot product, which reduces two vectors to another of 4 times fewer and 4
1295+
/// times larger elements.
1296+
InstructionCost
1297+
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
1298+
Type *AccumType, ElementCount VF,
1299+
PartialReductionExtendKind OpAExtend,
1300+
PartialReductionExtendKind OpBExtend,
1301+
std::optional<unsigned> BinOp = std::nullopt) const;
1302+
12831303
/// \return The maximum interleave factor that any transform should try to
12841304
/// perform for this target. This number depends on the level of parallelism
12851305
/// and the number of execution units in the CPU.
@@ -2107,6 +2127,20 @@ class TargetTransformInfo::Concept {
21072127
/// \return if target want to issue a prefetch in address space \p AS.
21082128
virtual bool shouldPrefetchAddressSpace(unsigned AS) const = 0;
21092129

2130+
/// \return The cost of a partial reduction, which is a reduction from a
2131+
/// vector to another vector with fewer elements of larger size. They are
2132+
/// represented by the llvm.experimental.partial.reduce.add intrinsic, which
2133+
/// takes an accumulator and a binary operation operand that itself is fed by
2134+
/// two extends. An example of an operation that uses a partial reduction is a
2135+
/// dot product, which reduces two vectors to another of 4 times fewer and 4
2136+
/// times larger elements.
2137+
virtual InstructionCost
2138+
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
2139+
Type *AccumType, ElementCount VF,
2140+
PartialReductionExtendKind OpAExtend,
2141+
PartialReductionExtendKind OpBExtend,
2142+
std::optional<unsigned> BinOp) const = 0;
2143+
21102144
virtual unsigned getMaxInterleaveFactor(ElementCount VF) = 0;
21112145
virtual InstructionCost getArithmeticInstrCost(
21122146
unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
@@ -2786,6 +2820,16 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
27862820
return Impl.shouldPrefetchAddressSpace(AS);
27872821
}
27882822

2823+
/// Forwarder from the type-erased Model to the underlying implementation.
/// \see TargetTransformInfo::getPartialReductionCost for the contract.
InstructionCost getPartialReductionCost(
    unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
    ElementCount VF, PartialReductionExtendKind OpAExtend,
    PartialReductionExtendKind OpBExtend,
    std::optional<unsigned> BinOp) const override {
  // Note: no default argument for BinOp here. Default arguments on virtual
  // overrides are bound statically, so a default on this override would never
  // be used by calls made through the Concept interface; the public
  // TargetTransformInfo entry point already supplies the std::nullopt default.
  return Impl.getPartialReductionCost(Opcode, InputTypeA, InputTypeB,
                                      AccumType, VF, OpAExtend, OpBExtend,
                                      BinOp);
}
2832+
27892833
unsigned getMaxInterleaveFactor(ElementCount VF) override {
27902834
return Impl.getMaxInterleaveFactor(VF);
27912835
}

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

+9
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,15 @@ class TargetTransformInfoImplBase {
585585
bool enableWritePrefetching() const { return false; }
586586
bool shouldPrefetchAddressSpace(unsigned AS) const { return !AS; }
587587

588+
/// Default implementation of the partial-reduction cost hook.
/// Targets that can profitably lower llvm.experimental.partial.reduce.add
/// override this in their TTI implementation; the base implementation
/// reports an invalid cost, which disables the transform.
InstructionCost
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
                        Type *AccumType, ElementCount VF,
                        TTI::PartialReductionExtendKind OpAExtend,
                        TTI::PartialReductionExtendKind OpBExtend,
                        std::optional<unsigned> BinOp = std::nullopt) const {
  // Invalid (not merely expensive): callers treat an invalid cost as
  // "partial reductions unsupported for these types/VF".
  return InstructionCost::getInvalid();
}
596+
588597
unsigned getMaxInterleaveFactor(ElementCount VF) const { return 1; }
589598

590599
InstructionCost getArithmeticInstrCost(

llvm/lib/Analysis/TargetTransformInfo.cpp

+18
Original file line numberDiff line numberDiff line change
@@ -863,6 +863,15 @@ bool TargetTransformInfo::shouldPrefetchAddressSpace(unsigned AS) const {
863863
return TTIImpl->shouldPrefetchAddressSpace(AS);
864864
}
865865

866+
/// Public entry point for the partial-reduction cost query; simply
/// delegates to the active target implementation.
InstructionCost TargetTransformInfo::getPartialReductionCost(
    unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
    ElementCount VF, PartialReductionExtendKind OpAExtend,
    PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp) const {
  InstructionCost Cost = TTIImpl->getPartialReductionCost(
      Opcode, InputTypeA, InputTypeB, AccumType, VF, OpAExtend, OpBExtend,
      BinOp);
  return Cost;
}
874+
866875
unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
867876
return TTIImpl->getMaxInterleaveFactor(VF);
868877
}
@@ -974,6 +983,15 @@ InstructionCost TargetTransformInfo::getShuffleCost(
974983
return Cost;
975984
}
976985

986+
/// Classify the extension performed by \p I for partial-reduction costing:
/// sext -> PR_SignExtend, zext -> PR_ZeroExtend, anything else -> PR_None.
TargetTransformInfo::PartialReductionExtendKind
TargetTransformInfo::getPartialReductionExtendKind(Instruction *I) {
  return isa<SExtInst>(I)   ? PR_SignExtend
         : isa<ZExtInst>(I) ? PR_ZeroExtend
                            : PR_None;
}
994+
977995
TTI::CastContextHint
978996
TargetTransformInfo::getCastContextHint(const Instruction *I) {
979997
if (!I)

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

+63
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "llvm/CodeGen/BasicTTIImpl.h"
2424
#include "llvm/IR/Function.h"
2525
#include "llvm/IR/Intrinsics.h"
26+
#include "llvm/Support/InstructionCost.h"
2627
#include <cstdint>
2728
#include <optional>
2829

@@ -357,6 +358,68 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
357358
return BaseT::isLegalNTLoad(DataType, Alignment);
358359
}
359360

361+
/// AArch64 cost model for partial reductions (dot-product style lowering).
/// Returns a valid (basic, possibly doubled) cost only for the exact
/// add-of-mul-of-extends shapes the backend can lower; everything else is
/// rejected with an invalid cost.
InstructionCost
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
                        Type *AccumType, ElementCount VF,
                        TTI::PartialReductionExtendKind OpAExtend,
                        TTI::PartialReductionExtendKind OpBExtend,
                        std::optional<unsigned> BinOp) const {

  InstructionCost Invalid = InstructionCost::getInvalid();
  // Start from a basic cost; doubled below for shapes that need two
  // instructions.
  InstructionCost Cost(TTI::TCC_Basic);

  // Only additive reductions are supported.
  if (Opcode != Instruction::Add)
    return Invalid;

  // Both multiplicand inputs must have the same element type.
  if (InputTypeA != InputTypeB)
    return Invalid;

  EVT InputEVT = EVT::getEVT(InputTypeA);
  EVT AccumEVT = EVT::getEVT(AccumType);

  // Scalable VFs need SVE (or streaming SVE); fixed VFs need NEON with the
  // dot-product extension.
  if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
    return Invalid;
  if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
    return Invalid;

  if (InputEVT == MVT::i8) {
    switch (VF.getKnownMinValue()) {
    default:
      return Invalid;
    case 8:
      // i8 x 8 accumulating into i32 costs double; into i64 is the basic
      // cost; any other accumulator is unsupported.
      if (AccumEVT == MVT::i32)
        Cost *= 2;
      else if (AccumEVT != MVT::i64)
        return Invalid;
      break;
    case 16:
      // i8 x 16 accumulating into i64 costs double; into i32 is the basic
      // cost; any other accumulator is unsupported.
      if (AccumEVT == MVT::i64)
        Cost *= 2;
      else if (AccumEVT != MVT::i32)
        return Invalid;
      break;
    }
  } else if (InputEVT == MVT::i16) {
    // FIXME: Allow i32 accumulator but increase cost, as we would extend
    // it to i64.
    if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64)
      return Invalid;
  } else
    return Invalid;

  // AArch64 supports lowering mixed extensions to a usdot but only if the
  // i8mm or sve/streaming features are available.
  if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
      (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
       !ST->isSVEorStreamingSVEAvailable()))
    return Invalid;

  // The inner binary operation feeding the reduction must be a multiply.
  if (!BinOp || *BinOp != Instruction::Mul)
    return Invalid;

  return Cost;
}
422+
360423
bool enableOrderedReductions() const { return true; }
361424

362425
InstructionCost getInterleavedMemoryOpCost(

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

+136-5
Original file line numberDiff line numberDiff line change
@@ -7531,6 +7531,10 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
75317531
}
75327532
continue;
75337533
}
7534+
// The VPlan-based cost model is more accurate for partial reduction and
7535+
// comparing against the legacy cost isn't desirable.
7536+
if (isa<VPPartialReductionRecipe>(&R))
7537+
return true;
75347538
if (Instruction *UI = GetInstructionForCost(&R))
75357539
SeenInstrs.insert(UI);
75367540
}
@@ -8751,6 +8755,105 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
87518755
return Recipe;
87528756
}
87538757

8758+
/// Find all possible partial reductions in the loop and track all of those that
/// are valid so recipes can be formed later.
void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
  // Find all possible partial reductions: one candidate chain (plus its
  // scale factor) per reduction phi that matches the extend-mul-accumulate
  // pattern and has a valid target cost.
  SmallVector<std::pair<PartialReductionChain, unsigned>, 1>
      PartialReductionChains;
  for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
    if (std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
            getScaledReduction(Phi, RdxDesc, Range))
      PartialReductionChains.push_back(*Pair);

  // A partial reduction is invalid if any of its extends are used by
  // something that isn't another partial reduction. This is because the
  // extends are intended to be lowered along with the reduction itself.

  // Build up a set of partial reduction bin ops for efficient use checking.
  SmallSet<User *, 4> PartialReductionBinOps;
  for (const auto &[PartialRdx, _] : PartialReductionChains)
    PartialReductionBinOps.insert(PartialRdx.BinOp);

  // True iff every user of \p Extend is the binary op of some candidate
  // partial-reduction chain.
  auto ExtendIsOnlyUsedByPartialReductions =
      [&PartialReductionBinOps](Instruction *Extend) {
        return all_of(Extend->users(), [&](const User *U) {
          return PartialReductionBinOps.contains(U);
        });
      };

  // Check if each use of a chain's two extends is a partial reduction
  // and only add those that don't have non-partial reduction users.
  for (auto Pair : PartialReductionChains) {
    PartialReductionChain Chain = Pair.first;
    if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
        ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB))
      ScaledReductionExitInstrs.insert(std::make_pair(Chain.Reduction, Pair));
  }
}
8794+
8795+
/// Try to recognize \p PHI's reduction as a partial ("scaled") reduction:
/// an update of the form PHI + (ext(A) binop ext(B)). Returns the matched
/// chain together with the scale factor (ratio of accumulator element width
/// to input element width), or std::nullopt if the pattern does not match
/// or the target reports an invalid cost for every VF in \p Range.
std::optional<std::pair<PartialReductionChain, unsigned>>
VPRecipeBuilder::getScaledReduction(PHINode *PHI,
                                    const RecurrenceDescriptor &Rdx,
                                    VFRange &Range) {
  // TODO: Allow scaling reductions when predicating. The select at
  // the end of the loop chooses between the phi value and most recent
  // reduction result, both of which have different VFs to the active lane
  // mask when scaling.
  if (CM.blockNeedsPredicationForAnyReason(Rdx.getLoopExitInstr()->getParent()))
    return std::nullopt;

  // The reduction update must itself be a binary operator.
  auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr());
  if (!Update)
    return std::nullopt;

  // Identify which operand of the update is the reduction phi; the other
  // operand (Op) is the value being accumulated. Bail out if neither
  // operand is the phi.
  Value *Op = Update->getOperand(0);
  Value *PhiOp = Update->getOperand(1);
  if (Op == PHI) {
    Op = Update->getOperand(1);
    PhiOp = Update->getOperand(0);
  }
  if (PhiOp != PHI)
    return std::nullopt;

  // The accumulated value must be a single-use binary op so it can be
  // absorbed into the partial-reduction recipe.
  auto *BinOp = dyn_cast<BinaryOperator>(Op);
  if (!BinOp || !BinOp->hasOneUse())
    return std::nullopt;

  // Both of the binary op's operands must be (sign or zero) extends.
  using namespace llvm::PatternMatch;
  Value *A, *B;
  if (!match(BinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
      !match(BinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
    return std::nullopt;

  Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
  Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));

  TTI::PartialReductionExtendKind OpAExtend =
      TargetTransformInfo::getPartialReductionExtendKind(ExtA);
  TTI::PartialReductionExtendKind OpBExtend =
      TargetTransformInfo::getPartialReductionExtendKind(ExtB);

  PartialReductionChain Chain(Rdx.getLoopExitInstr(), ExtA, ExtB, BinOp);

  // Scale factor = accumulator element width / pre-extend input element
  // width, e.g. 4 for an i8 -> i32 dot product.
  unsigned TargetScaleFactor =
      PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
          A->getType()->getPrimitiveSizeInBits());

  // Clamp Range to the VFs for which the target reports a valid
  // partial-reduction cost; succeed only if any remain.
  if (LoopVectorizationPlanner::getDecisionAndClampRange(
          [&](ElementCount VF) {
            InstructionCost Cost = TTI->getPartialReductionCost(
                Update->getOpcode(), A->getType(), B->getType(), PHI->getType(),
                VF, OpAExtend, OpBExtend,
                std::make_optional(BinOp->getOpcode()));
            return Cost.isValid();
          },
          Range))
    return std::make_pair(Chain, TargetScaleFactor);

  return std::nullopt;
}
8856+
87548857
VPRecipeBase *
87558858
VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
87568859
ArrayRef<VPValue *> Operands,
@@ -8775,9 +8878,14 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
87758878
Legal->getReductionVars().find(Phi)->second;
87768879
assert(RdxDesc.getRecurrenceStartValue() ==
87778880
Phi->getIncomingValueForBlock(OrigLoop->getLoopPreheader()));
8778-
PhiRecipe = new VPReductionPHIRecipe(Phi, RdxDesc, *StartV,
8779-
CM.isInLoopReduction(Phi),
8780-
CM.useOrderedReductions(RdxDesc));
8881+
8882+
// If the PHI is used by a partial reduction, set the scale factor.
8883+
std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
8884+
getScaledReductionForInstr(RdxDesc.getLoopExitInstr());
8885+
unsigned ScaleFactor = Pair ? Pair->second : 1;
8886+
PhiRecipe = new VPReductionPHIRecipe(
8887+
Phi, RdxDesc, *StartV, CM.isInLoopReduction(Phi),
8888+
CM.useOrderedReductions(RdxDesc), ScaleFactor);
87818889
} else {
87828890
// TODO: Currently fixed-order recurrences are modeled as chains of
87838891
// first-order recurrences. If there are no users of the intermediate
@@ -8809,6 +8917,9 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
88098917
if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr))
88108918
return tryToWidenMemory(Instr, Operands, Range);
88118919

8920+
if (getScaledReductionForInstr(Instr))
8921+
return tryToCreatePartialReduction(Instr, Operands);
8922+
88128923
if (!shouldWiden(Instr, Range))
88138924
return nullptr;
88148925

@@ -8829,6 +8940,21 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
88298940
return tryToWiden(Instr, Operands, VPBB);
88308941
}
88318942

8943+
/// Create a VPPartialReductionRecipe for \p Reduction, normalizing the
/// operand order so the accumulator (reduction phi) is the second operand.
/// \param Reduction the scalar reduction update instruction being widened.
/// \param Operands  the two VPValue operands of the update (binary op and
///                  reduction phi, in either order).
VPRecipeBase *
VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
                                             ArrayRef<VPValue *> Operands) {
  assert(Operands.size() == 2 &&
         "Unexpected number of operands for partial reduction");

  VPValue *BinOp = Operands[0];
  VPValue *Phi = Operands[1];
  // getDefiningRecipe() is null for live-in VPValues; isa_and_nonnull avoids
  // dereferencing a null recipe when operand 0 is a live-in rather than a
  // recipe-defined value.
  if (isa_and_nonnull<VPReductionPHIRecipe>(BinOp->getDefiningRecipe()))
    std::swap(BinOp, Phi);

  return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp, Phi,
                                      Reduction);
}
8957+
88328958
void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
88338959
ElementCount MaxVF) {
88348960
assert(OrigLoop->isInnermost() && "Inner loop expected.");
@@ -9252,7 +9378,8 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
92529378
bool HasNUW = !IVUpdateMayOverflow || Style == TailFoldingStyle::None;
92539379
addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), HasNUW, DL);
92549380

9255-
VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder);
9381+
VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
9382+
Builder);
92569383

92579384
// ---------------------------------------------------------------------------
92589385
// Pre-construction: record ingredients whose recipes we'll need to further
@@ -9298,6 +9425,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
92989425
bool NeedsBlends = BB != HeaderBB && !BB->phis().empty();
92999426
return Legal->blockNeedsPredication(BB) || NeedsBlends;
93009427
});
9428+
9429+
RecipeBuilder.collectScaledReductions(Range);
9430+
93019431
auto *MiddleVPBB = Plan->getMiddleBlock();
93029432
VPBasicBlock::iterator MBIP = MiddleVPBB->getFirstNonPhi();
93039433
for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) {
@@ -9521,7 +9651,8 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
95219651

95229652
// Collect mapping of IR header phis to header phi recipes, to be used in
95239653
// addScalarResumePhis.
9524-
VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder);
9654+
VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
9655+
Builder);
95259656
for (auto &R : Plan->getVectorLoopRegion()->getEntryBasicBlock()->phis()) {
95269657
if (isa<VPCanonicalIVPHIRecipe>(&R))
95279658
continue;

0 commit comments

Comments
 (0)