Skip to content

Commit 060d62b

Browse files
[LoopVectorizer] Add support for partial reductions (llvm#92418)
Following on from llvm#94499, this patch adds support to the Loop Vectorizer to emit the partial reduction intrinsics where they may be beneficial for the target. --------- Co-authored-by: Samuel Tebbs <samuel.tebbs@arm.com>
1 parent b41240b commit 060d62b

16 files changed

+3812
-31
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

+39
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,12 @@ typedef TargetTransformInfo TTI;
211211
/// for IR-level transformations.
212212
class TargetTransformInfo {
213213
public:
214+
enum PartialReductionExtendKind { PR_None, PR_SignExtend, PR_ZeroExtend };
215+
216+
/// Get the kind of extension that an instruction represents.
217+
static PartialReductionExtendKind
218+
getPartialReductionExtendKind(Instruction *I);
219+
214220
/// Construct a TTI object using a type implementing the \c Concept
215221
/// API below.
216222
///
@@ -1274,6 +1280,18 @@ class TargetTransformInfo {
12741280
/// \return true if the target wants to issue a prefetch in address space \p AS.
12751281
bool shouldPrefetchAddressSpace(unsigned AS) const;
12761282

1283+
/// \return The cost of a partial reduction, which is a reduction from a
1284+
/// vector to another vector with fewer elements of larger size. They are
1285+
/// represented by the llvm.experimental.partial.reduce.add intrinsic, which
1286+
/// takes an accumulator and a binary operation operand that itself is fed by
1287+
/// two extends. An example of an operation that uses a partial reduction is a
1288+
/// dot product, which reduces a vector to another of 4 times fewer elements.
1289+
InstructionCost
1290+
getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
1291+
ElementCount VF, PartialReductionExtendKind OpAExtend,
1292+
PartialReductionExtendKind OpBExtend,
1293+
std::optional<unsigned> BinOp = std::nullopt) const;
1294+
12771295
/// \return The maximum interleave factor that any transform should try to
12781296
/// perform for this target. This number depends on the level of parallelism
12791297
/// and the number of execution units in the CPU.
@@ -2098,6 +2116,18 @@ class TargetTransformInfo::Concept {
20982116
/// \return true if the target wants to issue a prefetch in address space \p AS.
20992117
virtual bool shouldPrefetchAddressSpace(unsigned AS) const = 0;
21002118

2119+
/// \return The cost of a partial reduction, which is a reduction from a
2120+
/// vector to another vector with fewer elements of larger size. They are
2121+
/// represented by the llvm.experimental.partial.reduce.add intrinsic, which
2122+
/// takes an accumulator and a binary operation operand that itself is fed by
2123+
/// two extends. An example of an operation that uses a partial reduction is a
2124+
/// dot product, which reduces a vector to another of 4 times fewer elements.
2125+
virtual InstructionCost
2126+
getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
2127+
ElementCount VF, PartialReductionExtendKind OpAExtend,
2128+
PartialReductionExtendKind OpBExtend,
2129+
std::optional<unsigned> BinOp) const = 0;
2130+
21012131
virtual unsigned getMaxInterleaveFactor(ElementCount VF) = 0;
21022132
virtual InstructionCost getArithmeticInstrCost(
21032133
unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
@@ -2772,6 +2802,15 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
27722802
return Impl.shouldPrefetchAddressSpace(AS);
27732803
}
27742804

2805+
/// Forwards the partial-reduction cost query, with every argument passed
/// through unchanged, to the wrapped \c Impl object and returns its result.
// NOTE(review): this override declares a default for \p BinOp; the default is
// only usable through a Model-typed call. Confirm whether it is needed here.
InstructionCost getPartialReductionCost(
2806+
unsigned Opcode, Type *InputType, Type *AccumType, ElementCount VF,
2807+
PartialReductionExtendKind OpAExtend,
2808+
PartialReductionExtendKind OpBExtend,
2809+
std::optional<unsigned> BinOp = std::nullopt) const override {
2810+
return Impl.getPartialReductionCost(Opcode, InputType, AccumType, VF,
2811+
OpAExtend, OpBExtend, BinOp);
2812+
}
2813+
27752814
unsigned getMaxInterleaveFactor(ElementCount VF) override {
27762815
return Impl.getMaxInterleaveFactor(VF);
27772816
}

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

+9
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,15 @@ class TargetTransformInfoImplBase {
580580
bool enableWritePrefetching() const { return false; }
581581
bool shouldPrefetchAddressSpace(unsigned AS) const { return !AS; }
582582

583+
/// Base implementation of the partial-reduction cost query. None of the
/// arguments are inspected; the result is always
/// \c InstructionCost::getInvalid().
InstructionCost
584+
getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
585+
ElementCount VF,
586+
TTI::PartialReductionExtendKind OpAExtend,
587+
TTI::PartialReductionExtendKind OpBExtend,
588+
std::optional<unsigned> BinOp = std::nullopt) const {
589+
return InstructionCost::getInvalid();
590+
}
591+
583592
unsigned getMaxInterleaveFactor(ElementCount VF) const { return 1; }
584593

585594
InstructionCost getArithmeticInstrCost(

llvm/lib/Analysis/TargetTransformInfo.cpp

+17
Original file line numberDiff line numberDiff line change
@@ -858,6 +858,14 @@ bool TargetTransformInfo::shouldPrefetchAddressSpace(unsigned AS) const {
858858
return TTIImpl->shouldPrefetchAddressSpace(AS);
859859
}
860860

861+
/// Public entry point for the partial-reduction cost query. Delegates to
/// \c TTIImpl with all arguments forwarded unchanged and returns its answer.
InstructionCost TargetTransformInfo::getPartialReductionCost(
862+
unsigned Opcode, Type *InputType, Type *AccumType, ElementCount VF,
863+
PartialReductionExtendKind OpAExtend, PartialReductionExtendKind OpBExtend,
864+
std::optional<unsigned> BinOp) const {
865+
return TTIImpl->getPartialReductionCost(Opcode, InputType, AccumType, VF,
866+
OpAExtend, OpBExtend, BinOp);
867+
}
868+
861869
unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
862870
return TTIImpl->getMaxInterleaveFactor(VF);
863871
}
@@ -969,6 +977,15 @@ InstructionCost TargetTransformInfo::getShuffleCost(
969977
return Cost;
970978
}
971979

980+
/// Classify \p I by the kind of extension it performs.
/// \return \c PR_SignExtend when \p I is a \c SExtInst, \c PR_ZeroExtend when
/// it is a \c ZExtInst, and \c PR_None for every other instruction.
// NOTE(review): \p I is handed straight to isa<>, so a null pointer is not
// handled here (the neighbouring getCastContextHint checks `!I` explicitly).
// Confirm callers never pass null.
TargetTransformInfo::PartialReductionExtendKind
981+
TargetTransformInfo::getPartialReductionExtendKind(Instruction *I) {
982+
if (isa<SExtInst>(I))
983+
return PR_SignExtend;
984+
if (isa<ZExtInst>(I))
985+
return PR_ZeroExtend;
986+
return PR_None;
987+
}
988+
972989
TTI::CastContextHint
973990
TargetTransformInfo::getCastContextHint(const Instruction *I) {
974991
if (!I)

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

+56
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "llvm/CodeGen/BasicTTIImpl.h"
2424
#include "llvm/IR/Function.h"
2525
#include "llvm/IR/Intrinsics.h"
26+
#include "llvm/Support/InstructionCost.h"
2627
#include <cstdint>
2728
#include <optional>
2829

@@ -357,6 +358,61 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
357358
return BaseT::isLegalNTLoad(DataType, Alignment);
358359
}
359360

361+
/// AArch64 cost of a partial reduction.
///
/// Starts from a cost of \c TTI::TCC_Basic and returns an invalid cost unless
/// all of the following hold:
///  - \p Opcode is \c Instruction::Add;
///  - a scalable \p VF has SVE or streaming SVE available, and a fixed \p VF
///    has both NEON and the dot-product feature available;
///  - the (input type, accumulator type, VF) triple is one of:
///      i8  input, VF min 8,  i32 accumulator (cost doubled)
///      i8  input, VF min 8,  i64 accumulator
///      i8  input, VF min 16, i32 accumulator
///      i8  input, VF min 16, i64 accumulator (cost doubled)
///      i16 input, VF min 8,  i64 accumulator
///  - neither \p OpAExtend nor \p OpBExtend is \c TTI::PR_None;
///  - \p BinOp is present and equal to \c Instruction::Mul.
InstructionCost
362+
getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
363+
ElementCount VF,
364+
TTI::PartialReductionExtendKind OpAExtend,
365+
TTI::PartialReductionExtendKind OpBExtend,
366+
std::optional<unsigned> BinOp) const {
367+
368+
InstructionCost Invalid = InstructionCost::getInvalid();
369+
InstructionCost Cost(TTI::TCC_Basic);
370+
371+
// Only add reductions are costed.
if (Opcode != Instruction::Add)
372+
return Invalid;
373+
374+
EVT InputEVT = EVT::getEVT(InputType);
375+
EVT AccumEVT = EVT::getEVT(AccumType);
376+
377+
// Subtarget feature gates, selected by whether the VF is scalable or fixed.
if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
378+
return Invalid;
379+
if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
380+
return Invalid;
381+
382+
// Accept only the known input/accumulator/VF combinations; two of the i8
// combinations are charged twice the basic cost.
if (InputEVT == MVT::i8) {
383+
switch (VF.getKnownMinValue()) {
384+
default:
385+
return Invalid;
386+
case 8:
387+
if (AccumEVT == MVT::i32)
388+
Cost *= 2;
389+
else if (AccumEVT != MVT::i64)
390+
return Invalid;
391+
break;
392+
case 16:
393+
if (AccumEVT == MVT::i64)
394+
Cost *= 2;
395+
else if (AccumEVT != MVT::i32)
396+
return Invalid;
397+
break;
398+
}
399+
} else if (InputEVT == MVT::i16) {
400+
// FIXME: Allow i32 accumulator but increase cost, as we would extend
401+
// it to i64.
402+
if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64)
403+
return Invalid;
404+
} else
405+
return Invalid;
406+
407+
// Both operands of the binary operation must be extends.
if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None)
408+
return Invalid;
409+
410+
// The binary operation feeding the reduction must be a multiply.
if (!BinOp || (*BinOp) != Instruction::Mul)
411+
return Invalid;
412+
413+
return Cost;
414+
}
415+
360416
bool enableOrderedReductions() const { return true; }
361417

362418
InstructionCost getInterleavedMemoryOpCost(

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

+132-4
Original file line numberDiff line numberDiff line change
@@ -7605,6 +7605,10 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
76057605
}
76067606
continue;
76077607
}
7608+
// The VPlan-based cost model is more accurate for partial reduction and
7609+
// comparing against the legacy cost isn't desirable.
7610+
if (isa<VPPartialReductionRecipe>(&R))
7611+
return true;
76087612
if (Instruction *UI = GetInstructionForCost(&R))
76097613
SeenInstrs.insert(UI);
76107614
}
@@ -8827,6 +8831,103 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
88278831
return Recipe;
88288832
}
88298833

8834+
/// Find all possible partial reductions in the loop and track all of those that
8835+
/// are valid so recipes can be formed later.
///
/// Candidates come from calling getScaledReduction() on every reduction
/// variable reported by \c Legal. A candidate is recorded in
/// \c ScaledReductionExitInstrs, keyed by its \c Reduction instruction, only
/// if every user of both of its extends is the bin-op of some candidate chain.
/// \p Range is passed through to getScaledReduction().
8836+
void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
8837+
// Find all possible partial reductions.
8838+
SmallVector<std::pair<PartialReductionChain, unsigned>, 1>
8839+
PartialReductionChains;
8840+
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
8841+
if (std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
8842+
getScaledReduction(Phi, RdxDesc, Range))
8843+
PartialReductionChains.push_back(*Pair);
8844+
8845+
// A partial reduction is invalid if any of its extends are used by
8846+
// something that isn't another partial reduction. This is because the
8847+
// extends are intended to be lowered along with the reduction itself.
8848+
8849+
// Build up a set of partial reduction bin ops for efficient use checking.
// NOTE(review): the set is filled from every candidate before any filtering,
// so the bin-op of a chain that is rejected below still counts as an
// acceptable user for other chains' extends. Confirm this is intended.
8850+
SmallSet<User *, 4> PartialReductionBinOps;
8851+
for (const auto &[PartialRdx, _] : PartialReductionChains)
8852+
PartialReductionBinOps.insert(PartialRdx.BinOp);
8853+
8854+
// True when every user of Extend is one of the collected bin-ops.
auto ExtendIsOnlyUsedByPartialReductions =
8855+
[&PartialReductionBinOps](Instruction *Extend) {
8856+
return all_of(Extend->users(), [&](const User *U) {
8857+
return PartialReductionBinOps.contains(U);
8858+
});
8859+
};
8860+
8861+
// Check if each use of a chain's two extends is a partial reduction
8862+
// and only add those that don't have non-partial reduction users.
8863+
for (auto Pair : PartialReductionChains) {
8864+
PartialReductionChain Chain = Pair.first;
8865+
if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
8866+
ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB))
8867+
ScaledReductionExitInstrs.insert(std::make_pair(Chain.Reduction, Pair));
8868+
}
8869+
}
8870+
8871+
/// Try to recognise the reduction described by \p Rdx (with header phi \p PHI)
/// as a partial-reduction chain of the form
///   update(PHI, binop(ext(A), ext(B)))
/// where both extends are zero- or sign-extends from the same source type.
///
/// \return the chain paired with its scale factor (the ratio of the bit width
/// of \p PHI's type to that of the extends' source type) when the target
/// reports a valid partial-reduction cost for \p Range; \c std::nullopt
/// otherwise. \p Range is passed to getDecisionAndClampRange() and may be
/// clamped by it.
std::optional<std::pair<PartialReductionChain, unsigned>>
8872+
VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8873+
const RecurrenceDescriptor &Rdx,
8874+
VFRange &Range) {
8875+
// TODO: Allow scaling reductions when predicating. The select at
8876+
// the end of the loop chooses between the phi value and most recent
8877+
// reduction result, both of which have different VFs to the active lane
8878+
// mask when scaling.
8879+
if (CM.blockNeedsPredicationForAnyReason(Rdx.getLoopExitInstr()->getParent()))
8880+
return std::nullopt;
8881+
8882+
// The loop-exit instruction must itself be the binary update operation.
auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr());
8883+
if (!Update)
8884+
return std::nullopt;
8885+
8886+
// Pick the operand of the update that is not the phi.
// NOTE(review): if operand 0 is not PHI, nothing here verifies that operand 1
// is PHI, so an update whose accumulator operand is some other value would
// still be considered. Confirm whether an explicit check is required.
Value *Op = Update->getOperand(0);
8887+
if (Op == PHI)
8888+
Op = Update->getOperand(1);
8889+
8890+
// The non-phi operand must be a binary operator whose only use is the update.
auto *BinOp = dyn_cast<BinaryOperator>(Op);
8891+
if (!BinOp || !BinOp->hasOneUse())
8892+
return std::nullopt;
8893+
8894+
// Both operands of the binary operator must be zext/sext; A and B bind to
// the values being extended.
using namespace llvm::PatternMatch;
8895+
Value *A, *B;
8896+
if (!match(BinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
8897+
!match(BinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
8898+
return std::nullopt;
8899+
8900+
Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
8901+
Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));
8902+
8903+
// Check that the extends extend from the same type.
8904+
if (A->getType() != B->getType())
8905+
return std::nullopt;
8906+
8907+
TTI::PartialReductionExtendKind OpAExtend =
8908+
TargetTransformInfo::getPartialReductionExtendKind(ExtA);
8909+
TTI::PartialReductionExtendKind OpBExtend =
8910+
TargetTransformInfo::getPartialReductionExtendKind(ExtB);
8911+
8912+
PartialReductionChain Chain(Rdx.getLoopExitInstr(), ExtA, ExtB, BinOp);
8913+
8914+
// Ratio between the accumulator (phi) bit width and the extends' source
// bit width.
unsigned TargetScaleFactor =
8915+
PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
8916+
A->getType()->getPrimitiveSizeInBits());
8917+
8918+
// Accept the chain only if the target gives a valid cost for the update
// opcode, the two types, the extend kinds and the binary operator's opcode.
if (LoopVectorizationPlanner::getDecisionAndClampRange(
8919+
[&](ElementCount VF) {
8920+
InstructionCost Cost = TTI->getPartialReductionCost(
8921+
Update->getOpcode(), A->getType(), PHI->getType(), VF,
8922+
OpAExtend, OpBExtend, std::make_optional(BinOp->getOpcode()));
8923+
return Cost.isValid();
8924+
},
8925+
Range))
8926+
return std::make_pair(Chain, TargetScaleFactor);
8927+
8928+
return std::nullopt;
8929+
}
8930+
88308931
VPRecipeBase *
88318932
VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
88328933
ArrayRef<VPValue *> Operands,
@@ -8851,9 +8952,14 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
88518952
Legal->getReductionVars().find(Phi)->second;
88528953
assert(RdxDesc.getRecurrenceStartValue() ==
88538954
Phi->getIncomingValueForBlock(OrigLoop->getLoopPreheader()));
8854-
PhiRecipe = new VPReductionPHIRecipe(Phi, RdxDesc, *StartV,
8855-
CM.isInLoopReduction(Phi),
8856-
CM.useOrderedReductions(RdxDesc));
8955+
8956+
// If the PHI is used by a partial reduction, set the scale factor.
8957+
std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
8958+
getScaledReductionForInstr(RdxDesc.getLoopExitInstr());
8959+
unsigned ScaleFactor = Pair ? Pair->second : 1;
8960+
PhiRecipe = new VPReductionPHIRecipe(
8961+
Phi, RdxDesc, *StartV, CM.isInLoopReduction(Phi),
8962+
CM.useOrderedReductions(RdxDesc), ScaleFactor);
88578963
} else {
88588964
// TODO: Currently fixed-order recurrences are modeled as chains of
88598965
// first-order recurrences. If there are no users of the intermediate
@@ -8885,6 +8991,9 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
88858991
if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr))
88868992
return tryToWidenMemory(Instr, Operands, Range);
88878993

8994+
if (getScaledReductionForInstr(Instr))
8995+
return tryToCreatePartialReduction(Instr, Operands);
8996+
88888997
if (!shouldWiden(Instr, Range))
88898998
return nullptr;
88908999

@@ -8905,6 +9014,21 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
89059014
return tryToWiden(Instr, Operands, VPBB);
89069015
}
89079016

9017+
/// Build a \c VPPartialReductionRecipe for \p Reduction from its two
/// \p Operands (asserted to be exactly two).
///
/// The operands are taken in order as (bin-op, phi); if the first one is
/// defined by a \c VPReductionPHIRecipe the two are swapped, so the reduction
/// phi ends up as the second recipe operand. The recipe is created with
/// \p Reduction's opcode and \p Reduction itself as the last argument.
/// \return the newly allocated recipe.
VPRecipeBase *
9018+
VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
9019+
ArrayRef<VPValue *> Operands) {
9020+
assert(Operands.size() == 2 &&
9021+
"Unexpected number of operands for partial reduction");
9022+
9023+
VPValue *BinOp = Operands[0];
9024+
VPValue *Phi = Operands[1];
9025+
// NOTE(review): the result of getDefiningRecipe() is passed to isa<> without
// a null check; confirm it can never be null for these operands.
if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe()))
9026+
std::swap(BinOp, Phi);
9027+
9028+
return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp, Phi,
9029+
Reduction);
9030+
}
9031+
89089032
void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
89099033
ElementCount MaxVF) {
89109034
assert(OrigLoop->isInnermost() && "Inner loop expected.");
@@ -9222,7 +9346,8 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
92229346
bool HasNUW = !IVUpdateMayOverflow || Style == TailFoldingStyle::None;
92239347
addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), HasNUW, DL);
92249348

9225-
VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder);
9349+
VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
9350+
Builder);
92269351

92279352
// ---------------------------------------------------------------------------
92289353
// Pre-construction: record ingredients whose recipes we'll need to further
@@ -9268,6 +9393,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
92689393
bool NeedsBlends = BB != HeaderBB && !BB->phis().empty();
92699394
return Legal->blockNeedsPredication(BB) || NeedsBlends;
92709395
});
9396+
9397+
RecipeBuilder.collectScaledReductions(Range);
9398+
92719399
auto *MiddleVPBB = Plan->getMiddleBlock();
92729400
VPBasicBlock::iterator MBIP = MiddleVPBB->getFirstNonPhi();
92739401
for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) {

0 commit comments

Comments
 (0)