Skip to content

Commit 3874ca2

Browse files
SamTebbs33fhahn
authored andcommitted
Reland "[LoopVectorizer] Add support for partial reductions" with non-phi operand fix. (llvm#121744)
This relands the reverted llvm#120721 with a fix for cases where neither reduction operand is the reduction phi. Only 6311423 and 6311423 are new on top of the reverted PR. --------- Co-authored-by: Nicholas Guy <nicholas.guy@arm.com> (cherry-picked from 795e35a)
1 parent 8a823e3 commit 3874ca2

15 files changed

+7369
-18
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

+44
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,12 @@ typedef TargetTransformInfo TTI;
213213
/// for IR-level transformations.
214214
class TargetTransformInfo {
215215
public:
216+
enum PartialReductionExtendKind { PR_None, PR_SignExtend, PR_ZeroExtend };
217+
218+
/// Get the kind of extension that an instruction represents.
219+
static PartialReductionExtendKind
220+
getPartialReductionExtendKind(Instruction *I);
221+
216222
/// Construct a TTI object using a type implementing the \c Concept
217223
/// API below.
218224
///
@@ -1257,6 +1263,20 @@ class TargetTransformInfo {
12571263
/// \return if target want to issue a prefetch in address space \p AS.
12581264
bool shouldPrefetchAddressSpace(unsigned AS) const;
12591265

1266+
/// \return The cost of a partial reduction, which is a reduction from a
1267+
/// vector to another vector with fewer elements of larger size. They are
1268+
/// represented by the llvm.experimental.partial.reduce.add intrinsic, which
1269+
/// takes an accumulator and a binary operation operand that itself is fed by
1270+
/// two extends. An example of an operation that uses a partial reduction is a
1271+
/// dot product, which reduces two vectors to another of 4 times fewer and 4
1272+
/// times larger elements.
1273+
InstructionCost
1274+
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
1275+
Type *AccumType, ElementCount VF,
1276+
PartialReductionExtendKind OpAExtend,
1277+
PartialReductionExtendKind OpBExtend,
1278+
std::optional<unsigned> BinOp = std::nullopt) const;
1279+
12601280
/// \return The maximum interleave factor that any transform should try to
12611281
/// perform for this target. This number depends on the level of parallelism
12621282
/// and the number of execution units in the CPU.
@@ -2034,6 +2054,20 @@ class TargetTransformInfo::Concept {
20342054
/// \return if target want to issue a prefetch in address space \p AS.
20352055
virtual bool shouldPrefetchAddressSpace(unsigned AS) const = 0;
20362056

2057+
/// \return The cost of a partial reduction, which is a reduction from a
2058+
/// vector to another vector with fewer elements of larger size. They are
2059+
/// represented by the llvm.experimental.partial.reduce.add intrinsic, which
2060+
/// takes an accumulator and a binary operation operand that itself is fed by
2061+
/// two extends. An example of an operation that uses a partial reduction is a
2062+
/// dot product, which reduces two vectors to another of 4 times fewer and 4
2063+
/// times larger elements.
2064+
virtual InstructionCost
2065+
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
2066+
Type *AccumType, ElementCount VF,
2067+
PartialReductionExtendKind OpAExtend,
2068+
PartialReductionExtendKind OpBExtend,
2069+
std::optional<unsigned> BinOp) const = 0;
2070+
20372071
virtual unsigned getMaxInterleaveFactor(ElementCount VF) = 0;
20382072
virtual InstructionCost getArithmeticInstrCost(
20392073
unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
@@ -2669,6 +2703,16 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
26692703
return Impl.shouldPrefetchAddressSpace(AS);
26702704
}
26712705

2706+
/// Forward the partial reduction cost query, with all arguments unchanged,
/// to the wrapped target implementation.
InstructionCost
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
                        Type *AccumType, ElementCount VF,
                        PartialReductionExtendKind OpAExtend,
                        PartialReductionExtendKind OpBExtend,
                        std::optional<unsigned> BinOp = std::nullopt)
    const override {
  return Impl.getPartialReductionCost(Opcode, InputTypeA, InputTypeB, AccumType,
                                      VF, OpAExtend, OpBExtend, BinOp);
}
2715+
26722716
unsigned getMaxInterleaveFactor(ElementCount VF) override {
26732717
return Impl.getMaxInterleaveFactor(VF);
26742718
}

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

+9
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,15 @@ class TargetTransformInfoImplBase {
543543
bool enableWritePrefetching() const { return false; }
544544
bool shouldPrefetchAddressSpace(unsigned AS) const { return !AS; }
545545

546+
/// Default partial reduction cost: always invalid, whatever the arguments.
/// Targets provide their own version to report a usable cost.
InstructionCost getPartialReductionCost(
    unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
    ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
    TTI::PartialReductionExtendKind OpBExtend,
    std::optional<unsigned> BinOp = std::nullopt) const {
  return InstructionCost::getInvalid();
}
554+
546555
unsigned getMaxInterleaveFactor(ElementCount VF) const { return 1; }
547556

548557
InstructionCost getArithmeticInstrCost(

llvm/lib/Analysis/TargetTransformInfo.cpp

+18
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,15 @@ bool TargetTransformInfo::shouldPrefetchAddressSpace(unsigned AS) const {
829829
return TTIImpl->shouldPrefetchAddressSpace(AS);
830830
}
831831

832+
// Hand the partial reduction cost query over to the target implementation.
InstructionCost TargetTransformInfo::getPartialReductionCost(
    unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
    ElementCount VF, PartialReductionExtendKind OpAExtend,
    PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp) const {
  return TTIImpl->getPartialReductionCost(Opcode, InputTypeA, InputTypeB,
                                          AccumType, VF, OpAExtend, OpBExtend,
                                          BinOp);
}
840+
832841
unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
833842
return TTIImpl->getMaxInterleaveFactor(VF);
834843
}
@@ -940,6 +949,15 @@ InstructionCost TargetTransformInfo::getShuffleCost(
940949
return Cost;
941950
}
942951

952+
// Map an instruction onto the extend kind it represents: sign extends and
// zero extends get their own kind, anything else is PR_None.
TargetTransformInfo::PartialReductionExtendKind
TargetTransformInfo::getPartialReductionExtendKind(Instruction *I) {
  return isa<SExtInst>(I)   ? PR_SignExtend
         : isa<ZExtInst>(I) ? PR_ZeroExtend
                            : PR_None;
}
960+
943961
TTI::CastContextHint
944962
TargetTransformInfo::getCastContextHint(const Instruction *I) {
945963
if (!I)

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

+63
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "llvm/CodeGen/BasicTTIImpl.h"
2525
#include "llvm/IR/Function.h"
2626
#include "llvm/IR/Intrinsics.h"
27+
#include "llvm/Support/InstructionCost.h"
2728
#include <cstdint>
2829
#include <optional>
2930

@@ -341,6 +342,68 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
341342
return BaseT::isLegalNTLoad(DataType, Alignment);
342343
}
343344

345+
/// Estimate the cost of a partial reduction on AArch64.
///
/// An invalid cost is returned unless all of the following hold:
///  - \p Opcode is an add and both inputs have the same type;
///  - the subtarget has the features required for \p VF (SVE or streaming
///    SVE for scalable VFs, NEON together with dot-product for fixed VFs);
///  - the input type, accumulator type and known-minimum VF form one of the
///    handled combinations (i8 inputs with a VF of 8 or 16 and an i32 or i64
///    accumulator, or i16 inputs with a VF of 8 and an i64 accumulator);
///  - both operands are extended, and differing extend kinds are only
///    accepted when i8mm or SVE/streaming SVE is available;
///  - \p BinOp is present and is a multiply.
///
/// Otherwise the cost is TCC_Basic, doubled for i8 inputs with VF 8 and an
/// i32 accumulator or VF 16 and an i64 accumulator.
InstructionCost
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
                        Type *AccumType, ElementCount VF,
                        TTI::PartialReductionExtendKind OpAExtend,
                        TTI::PartialReductionExtendKind OpBExtend,
                        std::optional<unsigned> BinOp) const {
  const InstructionCost Unsupported = InstructionCost::getInvalid();

  // Only add reductions over two identically typed inputs are handled.
  if (Opcode != Instruction::Add || InputTypeA != InputTypeB)
    return Unsupported;

  EVT InputEVT = EVT::getEVT(InputTypeA);
  EVT AccumEVT = EVT::getEVT(AccumType);

  // Scalable VFs need SVE (or streaming SVE); fixed VFs need NEON and the
  // dot-product extension.
  const bool HasRequiredFeatures =
      VF.isScalable() ? ST->isSVEorStreamingSVEAvailable()
                      : (ST->isNeonAvailable() && ST->hasDotProd());
  if (!HasRequiredFeatures)
    return Unsupported;

  const unsigned MinVF = VF.getKnownMinValue();
  bool DoubleCost = false;
  if (InputEVT == MVT::i8) {
    if (MinVF == 8) {
      if (AccumEVT == MVT::i32)
        DoubleCost = true;
      else if (AccumEVT != MVT::i64)
        return Unsupported;
    } else if (MinVF == 16) {
      if (AccumEVT == MVT::i64)
        DoubleCost = true;
      else if (AccumEVT != MVT::i32)
        return Unsupported;
    } else {
      return Unsupported;
    }
  } else if (InputEVT == MVT::i16) {
    // FIXME: Allow i32 accumulator but increase cost, as we would extend
    //        it to i64.
    if (MinVF != 8 || AccumEVT != MVT::i64)
      return Unsupported;
  } else {
    return Unsupported;
  }

  // Both operands have to be extended.
  if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None)
    return Unsupported;

  // AArch64 supports lowering mixed extensions to a usdot but only if the
  // i8mm or sve/streaming features are available.
  const bool MixedExtends = OpAExtend != OpBExtend;
  if (MixedExtends &&
      !(ST->hasMatMulInt8() || ST->isSVEorStreamingSVEAvailable()))
    return Unsupported;

  // The extended operands must be combined by a multiply.
  if (!BinOp.has_value() || *BinOp != Instruction::Mul)
    return Unsupported;

  InstructionCost Cost(TTI::TCC_Basic);
  if (DoubleCost)
    Cost *= 2;
  return Cost;
}
406+
344407
bool enableOrderedReductions() const { return true; }
345408

346409
InstructionCost getInterleavedMemoryOpCost(

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

+131-5
Original file line numberDiff line numberDiff line change
@@ -8268,6 +8268,105 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
82688268
return Recipe;
82698269
}
82708270

8271+
/// Find all possible partial reductions in the loop and track all of those that
8272+
/// are valid so recipes can be formed later.
8273+
void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
8274+
// Find all possible partial reductions.
8275+
SmallVector<std::pair<PartialReductionChain, unsigned>, 1>
8276+
PartialReductionChains;
8277+
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
8278+
if (std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
8279+
getScaledReduction(Phi, RdxDesc, Range))
8280+
PartialReductionChains.push_back(*Pair);
8281+
8282+
// A partial reduction is invalid if any of its extends are used by
8283+
// something that isn't another partial reduction. This is because the
8284+
// extends are intended to be lowered along with the reduction itself.
8285+
8286+
// Build up a set of partial reduction bin ops for efficient use checking.
8287+
SmallSet<User *, 4> PartialReductionBinOps;
8288+
for (const auto &[PartialRdx, _] : PartialReductionChains)
8289+
PartialReductionBinOps.insert(PartialRdx.BinOp);
8290+
8291+
auto ExtendIsOnlyUsedByPartialReductions =
8292+
[&PartialReductionBinOps](Instruction *Extend) {
8293+
return all_of(Extend->users(), [&](const User *U) {
8294+
return PartialReductionBinOps.contains(U);
8295+
});
8296+
};
8297+
8298+
// Check if each use of a chain's two extends is a partial reduction
8299+
// and only add those that don't have non-partial reduction users.
8300+
for (auto Pair : PartialReductionChains) {
8301+
PartialReductionChain Chain = Pair.first;
8302+
if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
8303+
ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB))
8304+
ScaledReductionExitInstrs.insert(std::make_pair(Chain.Reduction, Pair));
8305+
}
8306+
}
8307+
8308+
std::optional<std::pair<PartialReductionChain, unsigned>>
8309+
VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8310+
const RecurrenceDescriptor &Rdx,
8311+
VFRange &Range) {
8312+
// TODO: Allow scaling reductions when predicating. The select at
8313+
// the end of the loop chooses between the phi value and most recent
8314+
// reduction result, both of which have different VFs to the active lane
8315+
// mask when scaling.
8316+
if (CM.blockNeedsPredicationForAnyReason(Rdx.getLoopExitInstr()->getParent()))
8317+
return std::nullopt;
8318+
8319+
auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr());
8320+
if (!Update)
8321+
return std::nullopt;
8322+
8323+
Value *Op = Update->getOperand(0);
8324+
Value *PhiOp = Update->getOperand(1);
8325+
if (Op == PHI) {
8326+
Op = Update->getOperand(1);
8327+
PhiOp = Update->getOperand(0);
8328+
}
8329+
if (PhiOp != PHI)
8330+
return std::nullopt;
8331+
8332+
auto *BinOp = dyn_cast<BinaryOperator>(Op);
8333+
if (!BinOp || !BinOp->hasOneUse())
8334+
return std::nullopt;
8335+
8336+
using namespace llvm::PatternMatch;
8337+
Value *A, *B;
8338+
if (!match(BinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
8339+
!match(BinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
8340+
return std::nullopt;
8341+
8342+
Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
8343+
Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));
8344+
8345+
TTI::PartialReductionExtendKind OpAExtend =
8346+
TargetTransformInfo::getPartialReductionExtendKind(ExtA);
8347+
TTI::PartialReductionExtendKind OpBExtend =
8348+
TargetTransformInfo::getPartialReductionExtendKind(ExtB);
8349+
8350+
PartialReductionChain Chain(Rdx.getLoopExitInstr(), ExtA, ExtB, BinOp);
8351+
8352+
unsigned TargetScaleFactor =
8353+
PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
8354+
A->getType()->getPrimitiveSizeInBits());
8355+
8356+
if (LoopVectorizationPlanner::getDecisionAndClampRange(
8357+
[&](ElementCount VF) {
8358+
InstructionCost Cost = TTI->getPartialReductionCost(
8359+
Update->getOpcode(), A->getType(), B->getType(), PHI->getType(),
8360+
VF, OpAExtend, OpBExtend,
8361+
std::make_optional(BinOp->getOpcode()));
8362+
return Cost.isValid();
8363+
},
8364+
Range))
8365+
return std::make_pair(Chain, TargetScaleFactor);
8366+
8367+
return std::nullopt;
8368+
}
8369+
82718370
VPRecipeBase *
82728371
VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
82738372
ArrayRef<VPValue *> Operands,
@@ -8292,9 +8391,14 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
82928391
Legal->getReductionVars().find(Phi)->second;
82938392
assert(RdxDesc.getRecurrenceStartValue() ==
82948393
Phi->getIncomingValueForBlock(OrigLoop->getLoopPreheader()));
8295-
PhiRecipe = new VPReductionPHIRecipe(Phi, RdxDesc, *StartV,
8296-
CM.isInLoopReduction(Phi),
8297-
CM.useOrderedReductions(RdxDesc));
8394+
8395+
// If the PHI is used by a partial reduction, set the scale factor.
8396+
std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
8397+
getScaledReductionForInstr(RdxDesc.getLoopExitInstr());
8398+
unsigned ScaleFactor = Pair ? Pair->second : 1;
8399+
PhiRecipe = new VPReductionPHIRecipe(
8400+
Phi, RdxDesc, *StartV, CM.isInLoopReduction(Phi),
8401+
CM.useOrderedReductions(RdxDesc), ScaleFactor);
82988402
} else {
82998403
// TODO: Currently fixed-order recurrences are modeled as chains of
83008404
// first-order recurrences. If there are no users of the intermediate
@@ -8322,6 +8426,9 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
83228426
if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr))
83238427
return tryToWidenMemory(Instr, Operands, Range);
83248428

8429+
if (getScaledReductionForInstr(Instr))
8430+
return tryToCreatePartialReduction(Instr, Operands);
8431+
83258432
if (!shouldWiden(Instr, Range))
83268433
return nullptr;
83278434

@@ -8342,6 +8449,21 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
83428449
return tryToWiden(Instr, Operands, VPBB);
83438450
}
83448451

8452+
/// Create the partial reduction recipe for \p Reduction.
///
/// \param Reduction The reduction update instruction.
/// \param Operands  Its two operands, in either order: the value being
///                  reduced and the reduction phi.
/// \return A newly allocated VPPartialReductionRecipe taking the reduced
///         value first and the phi second.
VPRecipeBase *
VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
                                             ArrayRef<VPValue *> Operands) {
  assert(Operands.size() == 2 &&
         "Unexpected number of operands for partial reduction");

  VPValue *BinOp = Operands[0];
  VPValue *Phi = Operands[1];
  // A VPValue without a defining recipe yields a null pointer here, which
  // plain isa<> must not be given. Such a value cannot be the reduction phi,
  // so it stays in the BinOp slot.
  if (isa_and_present<VPReductionPHIRecipe>(BinOp->getDefiningRecipe()))
    std::swap(BinOp, Phi);

  return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp, Phi,
                                      Reduction);
}
8466+
83458467
void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
83468468
ElementCount MaxVF) {
83478469
assert(OrigLoop->isInnermost() && "Inner loop expected.");
@@ -8514,7 +8636,8 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
85148636
bool HasNUW = Style == TailFoldingStyle::None;
85158637
addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), HasNUW, DL);
85168638

8517-
VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder);
8639+
VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
8640+
Builder);
85188641

85198642
// ---------------------------------------------------------------------------
85208643
// Pre-construction: record ingredients whose recipes we'll need to further
@@ -8560,6 +8683,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
85608683
bool NeedsBlends = BB != HeaderBB && !BB->phis().empty();
85618684
return Legal->blockNeedsPredication(BB) || NeedsBlends;
85628685
});
8686+
8687+
RecipeBuilder.collectScaledReductions(Range);
8688+
85638689
for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) {
85648690
// Relevant instructions from basic block BB will be grouped into VPRecipe
85658691
// ingredients and fill a new VPBasicBlock.
@@ -8770,7 +8896,7 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
87708896
bool HasNUW = true;
87718897
addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), HasNUW,
87728898
DebugLoc());
8773-
assert(verifyVPlanIsValid(*Plan) && "VPlan is invalid");
8899+
assert(verifyVPlanIsValid(*Plan) && "VPlan is invalid");
87748900
return Plan;
87758901
}
87768902

0 commit comments

Comments
 (0)