Skip to content

Commit

Permalink
[LV][VPlan] Refactor VPReductionRecipe to use reference for member Rd…
Browse files Browse the repository at this point in the history
…xDesc

This commit refactors the implementation of VPReductionRecipe to use
reference instead of pointer for member RdxDesc. Because the member
RdxDesc in VPReductionRecipe should not be a nullptr, using a reference
will provide clearer semantics.

Reviewed By: fhahn

Differential Revision: https://reviews.llvm.org/D158058
  • Loading branch information
Mel-Chen committed Aug 17, 2023
1 parent 7c6c03e commit 463e7cb
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 22 deletions.
26 changes: 12 additions & 14 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9214,7 +9214,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
}

VPReductionRecipe *RedRecipe = new VPReductionRecipe(
&RdxDesc, CurrentLinkI, PreviousLinkV, VecOp, CondOp, &TTI);
RdxDesc, CurrentLinkI, PreviousLinkV, VecOp, CondOp, &TTI);
// Append the recipe to the end of the VPBasicBlock because we need to
// ensure that it comes after all of it's inputs, including CondOp.
// Note that this transformation may leave over dead recipes (including
Expand Down Expand Up @@ -9523,18 +9523,18 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
void VPReductionRecipe::execute(VPTransformState &State) {
assert(!State.Instance && "Reduction being replicated.");
Value *PrevInChain = State.get(getChainOp(), 0);
RecurKind Kind = RdxDesc->getRecurrenceKind();
bool IsOrdered = State.ILV->useOrderedReductions(*RdxDesc);
RecurKind Kind = RdxDesc.getRecurrenceKind();
bool IsOrdered = State.ILV->useOrderedReductions(RdxDesc);
// Propagate the fast-math flags carried by the underlying instruction.
IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder);
State.Builder.setFastMathFlags(RdxDesc->getFastMathFlags());
State.Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
for (unsigned Part = 0; Part < State.UF; ++Part) {
Value *NewVecOp = State.get(getVecOp(), Part);
if (VPValue *Cond = getCondOp()) {
Value *NewCond = State.get(Cond, Part);
VectorType *VecTy = cast<VectorType>(NewVecOp->getType());
Value *Iden = RdxDesc->getRecurrenceIdentity(
Kind, VecTy->getElementType(), RdxDesc->getFastMathFlags());
Value *Iden = RdxDesc.getRecurrenceIdentity(Kind, VecTy->getElementType(),
RdxDesc.getFastMathFlags());
Value *IdenVec =
State.Builder.CreateVectorSplat(VecTy->getElementCount(), Iden);
Value *Select = State.Builder.CreateSelect(NewCond, NewVecOp, IdenVec);
Expand All @@ -9544,27 +9544,25 @@ void VPReductionRecipe::execute(VPTransformState &State) {
Value *NextInChain;
if (IsOrdered) {
if (State.VF.isVector())
NewRed = createOrderedReduction(State.Builder, *RdxDesc, NewVecOp,
NewRed = createOrderedReduction(State.Builder, RdxDesc, NewVecOp,
PrevInChain);
else
NewRed = State.Builder.CreateBinOp(
(Instruction::BinaryOps)RdxDesc->getOpcode(Kind), PrevInChain,
(Instruction::BinaryOps)RdxDesc.getOpcode(Kind), PrevInChain,
NewVecOp);
PrevInChain = NewRed;
} else {
PrevInChain = State.get(getChainOp(), Part);
NewRed = createTargetReduction(State.Builder, TTI, *RdxDesc, NewVecOp);
NewRed = createTargetReduction(State.Builder, TTI, RdxDesc, NewVecOp);
}
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) {
NextInChain =
createMinMaxOp(State.Builder, RdxDesc->getRecurrenceKind(),
NewRed, PrevInChain);
NextInChain = createMinMaxOp(State.Builder, RdxDesc.getRecurrenceKind(),
NewRed, PrevInChain);
} else if (IsOrdered)
NextInChain = NewRed;
else
NextInChain = State.Builder.CreateBinOp(
(Instruction::BinaryOps)RdxDesc->getOpcode(Kind), NewRed,
PrevInChain);
(Instruction::BinaryOps)RdxDesc.getOpcode(Kind), NewRed, PrevInChain);
State.set(this, NextInChain, Part);
}
}
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Transforms/Vectorize/VPlan.h
Original file line number Diff line number Diff line change
Expand Up @@ -1727,12 +1727,12 @@ class VPInterleaveRecipe : public VPRecipeBase {
/// The Operands are {ChainOp, VecOp, [Condition]}.
class VPReductionRecipe : public VPRecipeBase, public VPValue {
/// The recurrence decriptor for the reduction in question.
const RecurrenceDescriptor *RdxDesc;
const RecurrenceDescriptor &RdxDesc;
/// Pointer to the TTI, needed to create the target reduction
const TargetTransformInfo *TTI;

public:
VPReductionRecipe(const RecurrenceDescriptor *R, Instruction *I,
VPReductionRecipe(const RecurrenceDescriptor &R, Instruction *I,
VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
const TargetTransformInfo *TTI)
: VPRecipeBase(VPDef::VPReductionSC, {ChainOp, VecOp}), VPValue(this, I),
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -969,14 +969,14 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
O << " +";
if (isa<FPMathOperator>(getUnderlyingInstr()))
O << getUnderlyingInstr()->getFastMathFlags();
O << " reduce." << Instruction::getOpcodeName(RdxDesc->getOpcode()) << " (";
O << " reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " (";
getVecOp()->printAsOperand(O, SlotTracker);
if (getCondOp()) {
O << ", ";
getCondOp()->printAsOperand(O, SlotTracker);
}
O << ")";
if (RdxDesc->IntermediateStore)
if (RdxDesc.IntermediateStore)
O << " (with final reduction value stored in invariant address sank "
"outside of loop)";
}
Expand Down
8 changes: 4 additions & 4 deletions llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1118,8 +1118,8 @@ TEST(VPRecipeTest, MayHaveSideEffectsAndMayReadWriteMemory) {
VPValue ChainOp;
VPValue VecOp;
VPValue CondOp;
VPReductionRecipe Recipe(nullptr, nullptr, &ChainOp, &CondOp, &VecOp,
nullptr);
VPReductionRecipe Recipe(RecurrenceDescriptor(), nullptr, &ChainOp, &CondOp,
&VecOp, nullptr);
EXPECT_FALSE(Recipe.mayHaveSideEffects());
EXPECT_FALSE(Recipe.mayReadFromMemory());
EXPECT_FALSE(Recipe.mayWriteToMemory());
Expand Down Expand Up @@ -1286,8 +1286,8 @@ TEST(VPRecipeTest, CastVPReductionRecipeToVPUser) {
VPValue ChainOp;
VPValue VecOp;
VPValue CondOp;
VPReductionRecipe Recipe(nullptr, nullptr, &ChainOp, &CondOp, &VecOp,
nullptr);
VPReductionRecipe Recipe(RecurrenceDescriptor(), nullptr, &ChainOp, &CondOp,
&VecOp, nullptr);
EXPECT_TRUE(isa<VPUser>(&Recipe));
VPRecipeBase *BaseR = &Recipe;
EXPECT_TRUE(isa<VPUser>(BaseR));
Expand Down

0 comments on commit 463e7cb

Please sign in to comment.