Skip to content

Commit 4d8f959

Browse files
committed
Revert "Reland "[LoopVectorizer] Add support for partial reductions" (llvm#120721)"
This reverts commit c858bf6 because it causes an optimization crash at -O2; see llvm#120721 (comment).
1 parent 1f90797 commit 4d8f959

16 files changed

+30
-3927
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

-39
Original file line numberDiff line numberDiff line change
@@ -211,12 +211,6 @@ typedef TargetTransformInfo TTI;
211211
/// for IR-level transformations.
212212
class TargetTransformInfo {
213213
public:
214-
enum PartialReductionExtendKind { PR_None, PR_SignExtend, PR_ZeroExtend };
215-
216-
/// Get the kind of extension that an instruction represents.
217-
static PartialReductionExtendKind
218-
getPartialReductionExtendKind(Instruction *I);
219-
220214
/// Construct a TTI object using a type implementing the \c Concept
221215
/// API below.
222216
///
@@ -1286,18 +1280,6 @@ class TargetTransformInfo {
12861280
/// \return if target want to issue a prefetch in address space \p AS.
12871281
bool shouldPrefetchAddressSpace(unsigned AS) const;
12881282

1289-
/// \return The cost of a partial reduction, which is a reduction from a
1290-
/// vector to another vector with fewer elements of larger size. They are
1291-
/// represented by the llvm.experimental.partial.reduce.add intrinsic, which
1292-
/// takes an accumulator and a binary operation operand that itself is fed by
1293-
/// two extends. An example of an operation that uses a partial reduction is a
1294-
/// dot product, which reduces a vector to another of 4 times fewer elements.
1295-
InstructionCost
1296-
getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
1297-
ElementCount VF, PartialReductionExtendKind OpAExtend,
1298-
PartialReductionExtendKind OpBExtend,
1299-
std::optional<unsigned> BinOp = std::nullopt) const;
1300-
13011283
/// \return The maximum interleave factor that any transform should try to
13021284
/// perform for this target. This number depends on the level of parallelism
13031285
/// and the number of execution units in the CPU.
@@ -2125,18 +2107,6 @@ class TargetTransformInfo::Concept {
21252107
/// \return if target want to issue a prefetch in address space \p AS.
21262108
virtual bool shouldPrefetchAddressSpace(unsigned AS) const = 0;
21272109

2128-
/// \return The cost of a partial reduction, which is a reduction from a
2129-
/// vector to another vector with fewer elements of larger size. They are
2130-
/// represented by the llvm.experimental.partial.reduce.add intrinsic, which
2131-
/// takes an accumulator and a binary operation operand that itself is fed by
2132-
/// two extends. An example of an operation that uses a partial reduction is a
2133-
/// dot product, which reduces a vector to another of 4 times fewer elements.
2134-
virtual InstructionCost
2135-
getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
2136-
ElementCount VF, PartialReductionExtendKind OpAExtend,
2137-
PartialReductionExtendKind OpBExtend,
2138-
std::optional<unsigned> BinOp) const = 0;
2139-
21402110
virtual unsigned getMaxInterleaveFactor(ElementCount VF) = 0;
21412111
virtual InstructionCost getArithmeticInstrCost(
21422112
unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
@@ -2816,15 +2786,6 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
28162786
return Impl.shouldPrefetchAddressSpace(AS);
28172787
}
28182788

2819-
InstructionCost getPartialReductionCost(
2820-
unsigned Opcode, Type *InputType, Type *AccumType, ElementCount VF,
2821-
PartialReductionExtendKind OpAExtend,
2822-
PartialReductionExtendKind OpBExtend,
2823-
std::optional<unsigned> BinOp = std::nullopt) const override {
2824-
return Impl.getPartialReductionCost(Opcode, InputType, AccumType, VF,
2825-
OpAExtend, OpBExtend, BinOp);
2826-
}
2827-
28282789
unsigned getMaxInterleaveFactor(ElementCount VF) override {
28292790
return Impl.getMaxInterleaveFactor(VF);
28302791
}

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

-9
Original file line numberDiff line numberDiff line change
@@ -585,15 +585,6 @@ class TargetTransformInfoImplBase {
585585
bool enableWritePrefetching() const { return false; }
586586
bool shouldPrefetchAddressSpace(unsigned AS) const { return !AS; }
587587

588-
InstructionCost
589-
getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
590-
ElementCount VF,
591-
TTI::PartialReductionExtendKind OpAExtend,
592-
TTI::PartialReductionExtendKind OpBExtend,
593-
std::optional<unsigned> BinOp = std::nullopt) const {
594-
return InstructionCost::getInvalid();
595-
}
596-
597588
unsigned getMaxInterleaveFactor(ElementCount VF) const { return 1; }
598589

599590
InstructionCost getArithmeticInstrCost(

llvm/lib/Analysis/TargetTransformInfo.cpp

-17
Original file line numberDiff line numberDiff line change
@@ -863,14 +863,6 @@ bool TargetTransformInfo::shouldPrefetchAddressSpace(unsigned AS) const {
863863
return TTIImpl->shouldPrefetchAddressSpace(AS);
864864
}
865865

866-
InstructionCost TargetTransformInfo::getPartialReductionCost(
867-
unsigned Opcode, Type *InputType, Type *AccumType, ElementCount VF,
868-
PartialReductionExtendKind OpAExtend, PartialReductionExtendKind OpBExtend,
869-
std::optional<unsigned> BinOp) const {
870-
return TTIImpl->getPartialReductionCost(Opcode, InputType, AccumType, VF,
871-
OpAExtend, OpBExtend, BinOp);
872-
}
873-
874866
unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
875867
return TTIImpl->getMaxInterleaveFactor(VF);
876868
}
@@ -982,15 +974,6 @@ InstructionCost TargetTransformInfo::getShuffleCost(
982974
return Cost;
983975
}
984976

985-
TargetTransformInfo::PartialReductionExtendKind
986-
TargetTransformInfo::getPartialReductionExtendKind(Instruction *I) {
987-
if (isa<SExtInst>(I))
988-
return PR_SignExtend;
989-
if (isa<ZExtInst>(I))
990-
return PR_ZeroExtend;
991-
return PR_None;
992-
}
993-
994977
TTI::CastContextHint
995978
TargetTransformInfo::getCastContextHint(const Instruction *I) {
996979
if (!I)

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

-56
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
#include "llvm/CodeGen/BasicTTIImpl.h"
2424
#include "llvm/IR/Function.h"
2525
#include "llvm/IR/Intrinsics.h"
26-
#include "llvm/Support/InstructionCost.h"
2726
#include <cstdint>
2827
#include <optional>
2928

@@ -358,61 +357,6 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
358357
return BaseT::isLegalNTLoad(DataType, Alignment);
359358
}
360359

361-
InstructionCost
362-
getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
363-
ElementCount VF,
364-
TTI::PartialReductionExtendKind OpAExtend,
365-
TTI::PartialReductionExtendKind OpBExtend,
366-
std::optional<unsigned> BinOp) const {
367-
368-
InstructionCost Invalid = InstructionCost::getInvalid();
369-
InstructionCost Cost(TTI::TCC_Basic);
370-
371-
if (Opcode != Instruction::Add)
372-
return Invalid;
373-
374-
EVT InputEVT = EVT::getEVT(InputType);
375-
EVT AccumEVT = EVT::getEVT(AccumType);
376-
377-
if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
378-
return Invalid;
379-
if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
380-
return Invalid;
381-
382-
if (InputEVT == MVT::i8) {
383-
switch (VF.getKnownMinValue()) {
384-
default:
385-
return Invalid;
386-
case 8:
387-
if (AccumEVT == MVT::i32)
388-
Cost *= 2;
389-
else if (AccumEVT != MVT::i64)
390-
return Invalid;
391-
break;
392-
case 16:
393-
if (AccumEVT == MVT::i64)
394-
Cost *= 2;
395-
else if (AccumEVT != MVT::i32)
396-
return Invalid;
397-
break;
398-
}
399-
} else if (InputEVT == MVT::i16) {
400-
// FIXME: Allow i32 accumulator but increase cost, as we would extend
401-
// it to i64.
402-
if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64)
403-
return Invalid;
404-
} else
405-
return Invalid;
406-
407-
if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None)
408-
return Invalid;
409-
410-
if (!BinOp || (*BinOp) != Instruction::Mul)
411-
return Invalid;
412-
413-
return Cost;
414-
}
415-
416360
bool enableOrderedReductions() const { return true; }
417361

418362
InstructionCost getInterleavedMemoryOpCost(

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

+4-132
Original file line numberDiff line numberDiff line change
@@ -7605,10 +7605,6 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
76057605
}
76067606
continue;
76077607
}
7608-
// The VPlan-based cost model is more accurate for partial reduction and
7609-
// comparing against the legacy cost isn't desirable.
7610-
if (isa<VPPartialReductionRecipe>(&R))
7611-
return true;
76127608
if (Instruction *UI = GetInstructionForCost(&R))
76137609
SeenInstrs.insert(UI);
76147610
}
@@ -8831,103 +8827,6 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
88318827
return Recipe;
88328828
}
88338829

8834-
/// Find all possible partial reductions in the loop and track all of those that
8835-
/// are valid so recipes can be formed later.
8836-
void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
8837-
// Find all possible partial reductions.
8838-
SmallVector<std::pair<PartialReductionChain, unsigned>, 1>
8839-
PartialReductionChains;
8840-
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
8841-
if (std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
8842-
getScaledReduction(Phi, RdxDesc, Range))
8843-
PartialReductionChains.push_back(*Pair);
8844-
8845-
// A partial reduction is invalid if any of its extends are used by
8846-
// something that isn't another partial reduction. This is because the
8847-
// extends are intended to be lowered along with the reduction itself.
8848-
8849-
// Build up a set of partial reduction bin ops for efficient use checking.
8850-
SmallSet<User *, 4> PartialReductionBinOps;
8851-
for (const auto &[PartialRdx, _] : PartialReductionChains)
8852-
PartialReductionBinOps.insert(PartialRdx.BinOp);
8853-
8854-
auto ExtendIsOnlyUsedByPartialReductions =
8855-
[&PartialReductionBinOps](Instruction *Extend) {
8856-
return all_of(Extend->users(), [&](const User *U) {
8857-
return PartialReductionBinOps.contains(U);
8858-
});
8859-
};
8860-
8861-
// Check if each use of a chain's two extends is a partial reduction
8862-
// and only add those that don't have non-partial reduction users.
8863-
for (auto Pair : PartialReductionChains) {
8864-
PartialReductionChain Chain = Pair.first;
8865-
if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
8866-
ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB))
8867-
ScaledReductionExitInstrs.insert(std::make_pair(Chain.Reduction, Pair));
8868-
}
8869-
}
8870-
8871-
std::optional<std::pair<PartialReductionChain, unsigned>>
8872-
VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8873-
const RecurrenceDescriptor &Rdx,
8874-
VFRange &Range) {
8875-
// TODO: Allow scaling reductions when predicating. The select at
8876-
// the end of the loop chooses between the phi value and most recent
8877-
// reduction result, both of which have different VFs to the active lane
8878-
// mask when scaling.
8879-
if (CM.blockNeedsPredicationForAnyReason(Rdx.getLoopExitInstr()->getParent()))
8880-
return std::nullopt;
8881-
8882-
auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr());
8883-
if (!Update)
8884-
return std::nullopt;
8885-
8886-
Value *Op = Update->getOperand(0);
8887-
if (Op == PHI)
8888-
Op = Update->getOperand(1);
8889-
8890-
auto *BinOp = dyn_cast<BinaryOperator>(Op);
8891-
if (!BinOp || !BinOp->hasOneUse())
8892-
return std::nullopt;
8893-
8894-
using namespace llvm::PatternMatch;
8895-
Value *A, *B;
8896-
if (!match(BinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
8897-
!match(BinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
8898-
return std::nullopt;
8899-
8900-
Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
8901-
Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));
8902-
8903-
// Check that the extends extend from the same type.
8904-
if (A->getType() != B->getType())
8905-
return std::nullopt;
8906-
8907-
TTI::PartialReductionExtendKind OpAExtend =
8908-
TargetTransformInfo::getPartialReductionExtendKind(ExtA);
8909-
TTI::PartialReductionExtendKind OpBExtend =
8910-
TargetTransformInfo::getPartialReductionExtendKind(ExtB);
8911-
8912-
PartialReductionChain Chain(Rdx.getLoopExitInstr(), ExtA, ExtB, BinOp);
8913-
8914-
unsigned TargetScaleFactor =
8915-
PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
8916-
A->getType()->getPrimitiveSizeInBits());
8917-
8918-
if (LoopVectorizationPlanner::getDecisionAndClampRange(
8919-
[&](ElementCount VF) {
8920-
InstructionCost Cost = TTI->getPartialReductionCost(
8921-
Update->getOpcode(), A->getType(), PHI->getType(), VF,
8922-
OpAExtend, OpBExtend, std::make_optional(BinOp->getOpcode()));
8923-
return Cost.isValid();
8924-
},
8925-
Range))
8926-
return std::make_pair(Chain, TargetScaleFactor);
8927-
8928-
return std::nullopt;
8929-
}
8930-
89318830
VPRecipeBase *
89328831
VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
89338832
ArrayRef<VPValue *> Operands,
@@ -8952,14 +8851,9 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
89528851
Legal->getReductionVars().find(Phi)->second;
89538852
assert(RdxDesc.getRecurrenceStartValue() ==
89548853
Phi->getIncomingValueForBlock(OrigLoop->getLoopPreheader()));
8955-
8956-
// If the PHI is used by a partial reduction, set the scale factor.
8957-
std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
8958-
getScaledReductionForInstr(RdxDesc.getLoopExitInstr());
8959-
unsigned ScaleFactor = Pair ? Pair->second : 1;
8960-
PhiRecipe = new VPReductionPHIRecipe(
8961-
Phi, RdxDesc, *StartV, CM.isInLoopReduction(Phi),
8962-
CM.useOrderedReductions(RdxDesc), ScaleFactor);
8854+
PhiRecipe = new VPReductionPHIRecipe(Phi, RdxDesc, *StartV,
8855+
CM.isInLoopReduction(Phi),
8856+
CM.useOrderedReductions(RdxDesc));
89638857
} else {
89648858
// TODO: Currently fixed-order recurrences are modeled as chains of
89658859
// first-order recurrences. If there are no users of the intermediate
@@ -8991,9 +8885,6 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
89918885
if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr))
89928886
return tryToWidenMemory(Instr, Operands, Range);
89938887

8994-
if (getScaledReductionForInstr(Instr))
8995-
return tryToCreatePartialReduction(Instr, Operands);
8996-
89978888
if (!shouldWiden(Instr, Range))
89988889
return nullptr;
89998890

@@ -9014,21 +8905,6 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
90148905
return tryToWiden(Instr, Operands, VPBB);
90158906
}
90168907

9017-
VPRecipeBase *
9018-
VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
9019-
ArrayRef<VPValue *> Operands) {
9020-
assert(Operands.size() == 2 &&
9021-
"Unexpected number of operands for partial reduction");
9022-
9023-
VPValue *BinOp = Operands[0];
9024-
VPValue *Phi = Operands[1];
9025-
if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe()))
9026-
std::swap(BinOp, Phi);
9027-
9028-
return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp, Phi,
9029-
Reduction);
9030-
}
9031-
90328908
void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
90338909
ElementCount MaxVF) {
90348910
assert(OrigLoop->isInnermost() && "Inner loop expected.");
@@ -9346,8 +9222,7 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
93469222
bool HasNUW = !IVUpdateMayOverflow || Style == TailFoldingStyle::None;
93479223
addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), HasNUW, DL);
93489224

9349-
VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
9350-
Builder);
9225+
VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder);
93519226

93529227
// ---------------------------------------------------------------------------
93539228
// Pre-construction: record ingredients whose recipes we'll need to further
@@ -9393,9 +9268,6 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
93939268
bool NeedsBlends = BB != HeaderBB && !BB->phis().empty();
93949269
return Legal->blockNeedsPredication(BB) || NeedsBlends;
93959270
});
9396-
9397-
RecipeBuilder.collectScaledReductions(Range);
9398-
93999271
auto *MiddleVPBB = Plan->getMiddleBlock();
94009272
VPBasicBlock::iterator MBIP = MiddleVPBB->getFirstNonPhi();
94019273
for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) {

0 commit comments

Comments
 (0)