
Commit c858bf6

Reland "[LoopVectorizer] Add support for partial reductions" (llvm#120721)
This re-lands the reverted llvm#92418.

When the VF is small enough that dividing it by the scaling factor results in 1 (for instance, a VF of 4 with a scale factor of 4), the reduction phi execution thinks the VF is scalar and sets the reduction's output to a scalar value, tripping assertions that expect a vector value. The latest commit in this PR fixes that by using `State.VF` in the scalar check, rather than the divided VF.

---------

Co-authored-by: Nicholas Guy <nicholas.guy@arm.com>
1 parent b2073fb, commit c858bf6

16 files changed: +3927, -30 lines
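
To make the feature concrete, the following is a minimal sketch (not part of the commit) of the kind of loop partial reductions target: an i8 dot product whose sign-extended products are accumulated into a wider i32 sum. In the cost hooks added below, this corresponds to an input type of i8, an accumulator type of i32, and two sign extends feeding a multiply.

#include <cstdint>

// Illustrative only: a dot-product loop of the shape the new recipes target.
// Each i8 operand is sign-extended, the products are formed in i32, and the
// running sum accumulates into an i32, i.e. the vector of products is
// "partially reduced" into a vector with 4x fewer, 4x wider elements.
int32_t dot_product(const int8_t *A, const int8_t *B, int N) {
  int32_t Sum = 0;
  for (int I = 0; I < N; ++I)
    Sum += static_cast<int32_t>(A[I]) * static_cast<int32_t>(B[I]);
  return Sum;
}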

llvm/include/llvm/Analysis/TargetTransformInfo.h (+39)

@@ -211,6 +211,12 @@ typedef TargetTransformInfo TTI;
 /// for IR-level transformations.
 class TargetTransformInfo {
 public:
+  enum PartialReductionExtendKind { PR_None, PR_SignExtend, PR_ZeroExtend };
+
+  /// Get the kind of extension that an instruction represents.
+  static PartialReductionExtendKind
+  getPartialReductionExtendKind(Instruction *I);
+
   /// Construct a TTI object using a type implementing the \c Concept
   /// API below.
   ///
@@ -1280,6 +1286,18 @@ class TargetTransformInfo {
   /// \return if target want to issue a prefetch in address space \p AS.
   bool shouldPrefetchAddressSpace(unsigned AS) const;
 
+  /// \return The cost of a partial reduction, which is a reduction from a
+  /// vector to another vector with fewer elements of larger size. They are
+  /// represented by the llvm.experimental.partial.reduce.add intrinsic, which
+  /// takes an accumulator and a binary operation operand that itself is fed by
+  /// two extends. An example of an operation that uses a partial reduction is a
+  /// dot product, which reduces a vector to another of 4 times fewer elements.
+  InstructionCost
+  getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
+                          ElementCount VF, PartialReductionExtendKind OpAExtend,
+                          PartialReductionExtendKind OpBExtend,
+                          std::optional<unsigned> BinOp = std::nullopt) const;
+
   /// \return The maximum interleave factor that any transform should try to
   /// perform for this target. This number depends on the level of parallelism
   /// and the number of execution units in the CPU.
@@ -2107,6 +2125,18 @@ class TargetTransformInfo::Concept {
   /// \return if target want to issue a prefetch in address space \p AS.
   virtual bool shouldPrefetchAddressSpace(unsigned AS) const = 0;
 
+  /// \return The cost of a partial reduction, which is a reduction from a
+  /// vector to another vector with fewer elements of larger size. They are
+  /// represented by the llvm.experimental.partial.reduce.add intrinsic, which
+  /// takes an accumulator and a binary operation operand that itself is fed by
+  /// two extends. An example of an operation that uses a partial reduction is a
+  /// dot product, which reduces a vector to another of 4 times fewer elements.
+  virtual InstructionCost
+  getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
+                          ElementCount VF, PartialReductionExtendKind OpAExtend,
+                          PartialReductionExtendKind OpBExtend,
+                          std::optional<unsigned> BinOp) const = 0;
+
   virtual unsigned getMaxInterleaveFactor(ElementCount VF) = 0;
   virtual InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
@@ -2786,6 +2816,15 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.shouldPrefetchAddressSpace(AS);
   }
 
+  InstructionCost getPartialReductionCost(
+      unsigned Opcode, Type *InputType, Type *AccumType, ElementCount VF,
+      PartialReductionExtendKind OpAExtend,
+      PartialReductionExtendKind OpBExtend,
+      std::optional<unsigned> BinOp = std::nullopt) const override {
+    return Impl.getPartialReductionCost(Opcode, InputType, AccumType, VF,
+                                        OpAExtend, OpBExtend, BinOp);
+  }
+
   unsigned getMaxInterleaveFactor(ElementCount VF) override {
     return Impl.getMaxInterleaveFactor(VF);
   }
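
As a rough usage sketch (not taken from the patch), a client holding a TargetTransformInfo could probe the new hook along these lines; the types, VF, and helper name are illustrative assumptions:

#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Type.h"
using namespace llvm;

// Hypothetical helper: asks whether the target can lower an i8 x i8 -> i32
// dot-product style partial reduction at a fixed VF of 16. An invalid cost
// means the target does not support (or does not want) the partial reduction.
static bool canUseI8ToI32PartialReduction(const TargetTransformInfo &TTI,
                                          LLVMContext &Ctx) {
  InstructionCost Cost = TTI.getPartialReductionCost(
      Instruction::Add, /*InputType=*/Type::getInt8Ty(Ctx),
      /*AccumType=*/Type::getInt32Ty(Ctx), ElementCount::getFixed(16),
      TargetTransformInfo::PR_SignExtend, TargetTransformInfo::PR_SignExtend,
      /*BinOp=*/Instruction::Mul);
  return Cost.isValid();
}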

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+9)

@@ -585,6 +585,15 @@ class TargetTransformInfoImplBase {
   bool enableWritePrefetching() const { return false; }
   bool shouldPrefetchAddressSpace(unsigned AS) const { return !AS; }
 
+  InstructionCost
+  getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
+                          ElementCount VF,
+                          TTI::PartialReductionExtendKind OpAExtend,
+                          TTI::PartialReductionExtendKind OpBExtend,
+                          std::optional<unsigned> BinOp = std::nullopt) const {
+    return InstructionCost::getInvalid();
+  }
+
   unsigned getMaxInterleaveFactor(ElementCount VF) const { return 1; }
 
   InstructionCost getArithmeticInstrCost(

llvm/lib/Analysis/TargetTransformInfo.cpp (+17)

@@ -863,6 +863,14 @@ bool TargetTransformInfo::shouldPrefetchAddressSpace(unsigned AS) const {
   return TTIImpl->shouldPrefetchAddressSpace(AS);
 }
 
+InstructionCost TargetTransformInfo::getPartialReductionCost(
+    unsigned Opcode, Type *InputType, Type *AccumType, ElementCount VF,
+    PartialReductionExtendKind OpAExtend, PartialReductionExtendKind OpBExtend,
+    std::optional<unsigned> BinOp) const {
+  return TTIImpl->getPartialReductionCost(Opcode, InputType, AccumType, VF,
+                                          OpAExtend, OpBExtend, BinOp);
+}
+
 unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
   return TTIImpl->getMaxInterleaveFactor(VF);
 }
@@ -974,6 +982,15 @@ InstructionCost TargetTransformInfo::getShuffleCost(
   return Cost;
 }
 
+TargetTransformInfo::PartialReductionExtendKind
+TargetTransformInfo::getPartialReductionExtendKind(Instruction *I) {
+  if (isa<SExtInst>(I))
+    return PR_SignExtend;
+  if (isa<ZExtInst>(I))
+    return PR_ZeroExtend;
+  return PR_None;
+}
+
 TTI::CastContextHint
 TargetTransformInfo::getCastContextHint(const Instruction *I) {
   if (!I)

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h (+56)

@@ -23,6 +23,7 @@
 #include "llvm/CodeGen/BasicTTIImpl.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Intrinsics.h"
+#include "llvm/Support/InstructionCost.h"
 #include <cstdint>
 #include <optional>
 
@@ -357,6 +358,61 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
     return BaseT::isLegalNTLoad(DataType, Alignment);
   }
 
+  InstructionCost
+  getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
+                          ElementCount VF,
+                          TTI::PartialReductionExtendKind OpAExtend,
+                          TTI::PartialReductionExtendKind OpBExtend,
+                          std::optional<unsigned> BinOp) const {
+
+    InstructionCost Invalid = InstructionCost::getInvalid();
+    InstructionCost Cost(TTI::TCC_Basic);
+
+    if (Opcode != Instruction::Add)
+      return Invalid;
+
+    EVT InputEVT = EVT::getEVT(InputType);
+    EVT AccumEVT = EVT::getEVT(AccumType);
+
+    if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
+      return Invalid;
+    if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
+      return Invalid;
+
+    if (InputEVT == MVT::i8) {
+      switch (VF.getKnownMinValue()) {
+      default:
+        return Invalid;
+      case 8:
+        if (AccumEVT == MVT::i32)
+          Cost *= 2;
+        else if (AccumEVT != MVT::i64)
+          return Invalid;
+        break;
+      case 16:
+        if (AccumEVT == MVT::i64)
+          Cost *= 2;
+        else if (AccumEVT != MVT::i32)
+          return Invalid;
+        break;
+      }
+    } else if (InputEVT == MVT::i16) {
+      // FIXME: Allow i32 accumulator but increase cost, as we would extend
+      //        it to i64.
+      if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64)
+        return Invalid;
+    } else
+      return Invalid;
+
+    if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None)
+      return Invalid;
+
+    if (!BinOp || (*BinOp) != Instruction::Mul)
+      return Invalid;
+
+    return Cost;
+  }
+
   bool enableOrderedReductions() const { return true; }
 
   InstructionCost getInterleavedMemoryOpCost(
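
For readers skimming the hook above, the combinations it accepts are summarized below as comments; this is a reading of the code in this diff, not an authoritative statement of AArch64 costs.

// Reader's summary of AArch64TTIImpl::getPartialReductionCost above
// (derived from this diff, illustrative only). Beyond the type/VF table,
// the opcode must be Add, both extends must be sign or zero extends, the
// inner binary op must be Mul, scalable VFs need SVE (or streaming SVE),
// and fixed VFs need NEON with the dotprod feature.
//
//   input i8,  VF 8,  accum i64 -> TCC_Basic
//   input i8,  VF 8,  accum i32 -> 2 * TCC_Basic
//   input i8,  VF 16, accum i32 -> TCC_Basic
//   input i8,  VF 16, accum i64 -> 2 * TCC_Basic
//   input i16, VF 8,  accum i64 -> TCC_Basic
//   anything else               -> InstructionCost::getInvalid()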

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+132, -4)

@@ -7605,6 +7605,10 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
         }
         continue;
       }
+      // The VPlan-based cost model is more accurate for partial reduction and
+      // comparing against the legacy cost isn't desirable.
+      if (isa<VPPartialReductionRecipe>(&R))
+        return true;
       if (Instruction *UI = GetInstructionForCost(&R))
         SeenInstrs.insert(UI);
     }
@@ -8827,6 +8831,103 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
   return Recipe;
 }
 
+/// Find all possible partial reductions in the loop and track all of those that
+/// are valid so recipes can be formed later.
+void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
+  // Find all possible partial reductions.
+  SmallVector<std::pair<PartialReductionChain, unsigned>, 1>
+      PartialReductionChains;
+  for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
+    if (std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
+            getScaledReduction(Phi, RdxDesc, Range))
+      PartialReductionChains.push_back(*Pair);
+
+  // A partial reduction is invalid if any of its extends are used by
+  // something that isn't another partial reduction. This is because the
+  // extends are intended to be lowered along with the reduction itself.
+
+  // Build up a set of partial reduction bin ops for efficient use checking.
+  SmallSet<User *, 4> PartialReductionBinOps;
+  for (const auto &[PartialRdx, _] : PartialReductionChains)
+    PartialReductionBinOps.insert(PartialRdx.BinOp);
+
+  auto ExtendIsOnlyUsedByPartialReductions =
+      [&PartialReductionBinOps](Instruction *Extend) {
+        return all_of(Extend->users(), [&](const User *U) {
+          return PartialReductionBinOps.contains(U);
+        });
+      };
+
+  // Check if each use of a chain's two extends is a partial reduction
+  // and only add those that don't have non-partial reduction users.
+  for (auto Pair : PartialReductionChains) {
+    PartialReductionChain Chain = Pair.first;
+    if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
+        ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB))
+      ScaledReductionExitInstrs.insert(std::make_pair(Chain.Reduction, Pair));
+  }
+}
+
+std::optional<std::pair<PartialReductionChain, unsigned>>
+VPRecipeBuilder::getScaledReduction(PHINode *PHI,
+                                    const RecurrenceDescriptor &Rdx,
+                                    VFRange &Range) {
+  // TODO: Allow scaling reductions when predicating. The select at
+  // the end of the loop chooses between the phi value and most recent
+  // reduction result, both of which have different VFs to the active lane
+  // mask when scaling.
+  if (CM.blockNeedsPredicationForAnyReason(Rdx.getLoopExitInstr()->getParent()))
+    return std::nullopt;
+
+  auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr());
+  if (!Update)
+    return std::nullopt;
+
+  Value *Op = Update->getOperand(0);
+  if (Op == PHI)
+    Op = Update->getOperand(1);
+
+  auto *BinOp = dyn_cast<BinaryOperator>(Op);
+  if (!BinOp || !BinOp->hasOneUse())
+    return std::nullopt;
+
+  using namespace llvm::PatternMatch;
+  Value *A, *B;
+  if (!match(BinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
+      !match(BinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
+    return std::nullopt;
+
+  Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
+  Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));
+
+  // Check that the extends extend from the same type.
+  if (A->getType() != B->getType())
+    return std::nullopt;
+
+  TTI::PartialReductionExtendKind OpAExtend =
+      TargetTransformInfo::getPartialReductionExtendKind(ExtA);
+  TTI::PartialReductionExtendKind OpBExtend =
+      TargetTransformInfo::getPartialReductionExtendKind(ExtB);
+
+  PartialReductionChain Chain(Rdx.getLoopExitInstr(), ExtA, ExtB, BinOp);
+
+  unsigned TargetScaleFactor =
+      PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
+          A->getType()->getPrimitiveSizeInBits());
+
+  if (LoopVectorizationPlanner::getDecisionAndClampRange(
+          [&](ElementCount VF) {
+            InstructionCost Cost = TTI->getPartialReductionCost(
+                Update->getOpcode(), A->getType(), PHI->getType(), VF,
+                OpAExtend, OpBExtend, std::make_optional(BinOp->getOpcode()));
+            return Cost.isValid();
+          },
+          Range))
+    return std::make_pair(Chain, TargetScaleFactor);
+
+  return std::nullopt;
+}
+
 VPRecipeBase *
 VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
                                         ArrayRef<VPValue *> Operands,
@@ -8851,9 +8952,14 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
         Legal->getReductionVars().find(Phi)->second;
     assert(RdxDesc.getRecurrenceStartValue() ==
            Phi->getIncomingValueForBlock(OrigLoop->getLoopPreheader()));
-    PhiRecipe = new VPReductionPHIRecipe(Phi, RdxDesc, *StartV,
-                                         CM.isInLoopReduction(Phi),
-                                         CM.useOrderedReductions(RdxDesc));
+
+    // If the PHI is used by a partial reduction, set the scale factor.
+    std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
+        getScaledReductionForInstr(RdxDesc.getLoopExitInstr());
+    unsigned ScaleFactor = Pair ? Pair->second : 1;
+    PhiRecipe = new VPReductionPHIRecipe(
+        Phi, RdxDesc, *StartV, CM.isInLoopReduction(Phi),
+        CM.useOrderedReductions(RdxDesc), ScaleFactor);
   } else {
     // TODO: Currently fixed-order recurrences are modeled as chains of
     // first-order recurrences. If there are no users of the intermediate
@@ -8885,6 +8991,9 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
   if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr))
     return tryToWidenMemory(Instr, Operands, Range);
 
+  if (getScaledReductionForInstr(Instr))
+    return tryToCreatePartialReduction(Instr, Operands);
+
   if (!shouldWiden(Instr, Range))
     return nullptr;
 
@@ -8905,6 +9014,21 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
   return tryToWiden(Instr, Operands, VPBB);
 }
 
+VPRecipeBase *
+VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
+                                             ArrayRef<VPValue *> Operands) {
+  assert(Operands.size() == 2 &&
+         "Unexpected number of operands for partial reduction");
+
+  VPValue *BinOp = Operands[0];
+  VPValue *Phi = Operands[1];
+  if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe()))
+    std::swap(BinOp, Phi);
+
+  return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp, Phi,
+                                      Reduction);
+}
+
 void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
                                                         ElementCount MaxVF) {
   assert(OrigLoop->isInnermost() && "Inner loop expected.");
@@ -9222,7 +9346,8 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
   bool HasNUW = !IVUpdateMayOverflow || Style == TailFoldingStyle::None;
   addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), HasNUW, DL);
 
-  VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder);
+  VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
+                                Builder);
 
   // ---------------------------------------------------------------------------
   // Pre-construction: record ingredients whose recipes we'll need to further
@@ -9268,6 +9393,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
     bool NeedsBlends = BB != HeaderBB && !BB->phis().empty();
    return Legal->blockNeedsPredication(BB) || NeedsBlends;
   });
+
+  RecipeBuilder.collectScaledReductions(Range);
+
   auto *MiddleVPBB = Plan->getMiddleBlock();
   VPBasicBlock::iterator MBIP = MiddleVPBB->getFirstNonPhi();
   for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) {
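
To tie the recipe-building code back to the IR it matches, here is a hedged sketch of the loop-carried chain getScaledReduction looks for, with the scale factor worked through for the i8-to-i32 case (value names are invented for illustration):

// Shape of the update chain that getScaledReduction matches, shown as IR in
// comments (names invented for illustration):
//
//   %a.ext = sext i8 %a to i32        ; ExtendA (zext also accepted)
//   %b.ext = sext i8 %b to i32        ; ExtendB, same source type as ExtendA
//   %mul   = mul i32 %a.ext, %b.ext   ; BinOp, must have a single use
//   %acc   = add i32 %phi, %mul       ; Rdx.getLoopExitInstr(), the Update
//
// The scale factor is bits(accumulator phi) / bits(pre-extend input),
// here 32 / 8 = 4, so a VF of 16 i8 lanes feeds a reduction phi of only
// 16 / 4 = 4 i32 lanes; the fix described in the commit message keeps the
// phi from being treated as scalar when that division yields 1.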
