Skip to content

[VPlan] Unroll VPReplicateRecipe by VF. #142433

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7557,6 +7557,7 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
// cost model is complete for better cost estimates.
VPlanTransforms::runPass(VPlanTransforms::unrollByUF, BestVPlan, BestUF,
OrigLoop->getHeader()->getContext());
VPlanTransforms::runPass(VPlanTransforms::unrollByVF, BestVPlan, BestVF);
VPlanTransforms::runPass(VPlanTransforms::materializeBroadcasts, BestVPlan);
VPlanTransforms::optimizeForVFAndUF(BestVPlan, BestVF, BestUF, PSE);
VPlanTransforms::simplifyRecipes(BestVPlan, *Legal->getWidestInductionType());
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPlan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,14 @@ Value *VPTransformState::get(const VPValue *Def, const VPLane &Lane) {
return Data.VPV2Scalars[Def][0];
}

// Look through BuildVector to avoid redundant extracts.
// TODO: Remove once replicate regions are unrolled explicitly.
auto *BV = dyn_cast<VPInstruction>(Def);
if (Lane.getKind() == VPLane::Kind::First && BV &&
BV->getOpcode() == VPInstruction::BuildVector) {
return get(BV->getOperand(Lane.getKnownLane()), true);
}

assert(hasVectorValue(Def));
auto *VecPart = Data.VPV2Vector[Def];
if (!VecPart->getType()->isVectorTy()) {
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPlan.h
Original file line number Diff line number Diff line change
Expand Up @@ -907,6 +907,12 @@ class VPInstruction : public VPRecipeWithIRFlags,
BranchOnCount,
BranchOnCond,
Broadcast,
/// Creates a vector containing all operands. The vector element count
/// matches the number of operands.
BuildVector,
/// Creates a struct of vectors containing all operands. The vector element
/// count matches the number of operands.
BuildStructVector,
ComputeFindLastIVResult,
ComputeReductionResult,
// Extracts the last lane from its operand if it is a vector, or the last
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
case VPInstruction::CalculateTripCountMinusVF:
case VPInstruction::CanonicalIVIncrementForPart:
case VPInstruction::AnyOf:
case VPInstruction::BuildVector:
case VPInstruction::BuildStructVector:
return SetResultTyFromOp();
case VPInstruction::FirstActiveLane:
return Type::getIntNTy(Ctx, 64);
Expand Down
62 changes: 44 additions & 18 deletions llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,9 @@ Value *VPInstruction::generate(VPTransformState &State) {
}
case Instruction::ExtractElement: {
assert(State.VF.isVector() && "Only extract elements from vectors");
return State.get(getOperand(0),
VPLane(cast<ConstantInt>(getOperand(1)->getLiveInIRValue())
->getZExtValue()));
Value *Vec = State.get(getOperand(0));
Value *Idx = State.get(getOperand(1), /*IsScalar=*/true);
return Builder.CreateExtractElement(Vec, Idx, Name);
Expand Down Expand Up @@ -604,6 +607,34 @@ Value *VPInstruction::generate(VPTransformState &State) {
return Builder.CreateVectorSplat(
State.VF, State.get(getOperand(0), /*IsScalar*/ true), "broadcast");
}
case VPInstruction::BuildVector: {
auto *ScalarTy = State.TypeAnalysis.inferScalarType(getOperand(0));
Value *Res = PoisonValue::get(
toVectorizedTy(ScalarTy, ElementCount::getFixed(getNumOperands())));
for (const auto &[Idx, Op] : enumerate(operands()))
Res = State.Builder.CreateInsertElement(Res, State.get(Op, true),
State.Builder.getInt32(Idx));
return Res;
}
case VPInstruction::BuildStructVector: {
// For struct types, we need to build a new 'wide' struct type, where each
// element is widened.
auto *STy =
cast<StructType>(State.TypeAnalysis.inferScalarType(getOperand(0)));
Value *Res = PoisonValue::get(
toVectorizedTy(STy, ElementCount::getFixed(getNumOperands())));
for (const auto &[Idx, Op] : enumerate(operands())) {
for (unsigned I = 0, E = STy->getNumElements(); I != E; I++) {
Value *ScalarValue = Builder.CreateExtractValue(State.get(Op, true), I);
Value *VectorValue = Builder.CreateExtractValue(Res, I);
VectorValue =
Builder.CreateInsertElement(VectorValue, ScalarValue, Idx);
Res = Builder.CreateInsertValue(Res, VectorValue, I);
}
}
return Res;
}

case VPInstruction::ComputeFindLastIVResult: {
// FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
// and will be removed by breaking up the recipe further.
Expand Down Expand Up @@ -864,10 +895,11 @@ void VPInstruction::execute(VPTransformState &State) {
if (!hasResult())
return;
assert(GeneratedValue && "generate must produce a value");
assert(
(GeneratedValue->getType()->isVectorTy() == !GeneratesPerFirstLaneOnly ||
State.VF.isScalar()) &&
"scalar value but not only first lane defined");
assert((((GeneratedValue->getType()->isVectorTy() ||
GeneratedValue->getType()->isStructTy()) ==
!GeneratesPerFirstLaneOnly) ||
State.VF.isScalar()) &&
"scalar value but not only first lane defined");
State.set(this, GeneratedValue,
/*IsScalar*/ GeneratesPerFirstLaneOnly);
}
Expand All @@ -881,6 +913,8 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
case Instruction::ICmp:
case Instruction::Select:
case VPInstruction::AnyOf:
case VPInstruction::BuildVector:
case VPInstruction::BuildStructVector:
case VPInstruction::CalculateTripCountMinusVF:
case VPInstruction::CanonicalIVIncrementForPart:
case VPInstruction::ExtractLastElement:
Expand Down Expand Up @@ -999,6 +1033,12 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
case VPInstruction::Broadcast:
O << "broadcast";
break;
case VPInstruction::BuildVector:
O << "buildvector";
break;
case VPInstruction::BuildStructVector:
O << "buildstructvector";
break;
case VPInstruction::ExtractLastElement:
O << "extract-last-element";
break;
Expand Down Expand Up @@ -2758,20 +2798,6 @@ void VPReplicateRecipe::execute(VPTransformState &State) {
scalarizeInstruction(UI, this, VPLane(0), State);
return;
}

// A store of a loop varying value to a uniform address only needs the last
// copy of the store.
if (isa<StoreInst>(UI) && vputils::isSingleScalar(getOperand(1))) {
auto Lane = VPLane::getLastLaneForVF(State.VF);
scalarizeInstruction(UI, this, VPLane(Lane), State);
return;
}

// Generate scalar instances for all VF lanes.
assert(!State.VF.isScalable() && "Can't scalarize a scalable vector");
const unsigned EndLane = State.VF.getKnownMinValue();
for (unsigned Lane = 0; Lane < EndLane; ++Lane)
scalarizeInstruction(UI, this, VPLane(Lane), State);
}

bool VPReplicateRecipe::shouldPack() const {
Expand Down
16 changes: 16 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1142,6 +1142,22 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
return;
}

// Look through Extract(Last|Penultimate)Element (BuildVector ....).
if (match(&R,
m_VPInstruction<VPInstruction::ExtractLastElement>(m_VPValue(A))) ||
match(&R, m_VPInstruction<VPInstruction::ExtractPenultimateElement>(
m_VPValue(A)))) {
unsigned Offset = cast<VPInstruction>(&R)->getOpcode() ==
VPInstruction::ExtractLastElement
? 1
: 2;
auto *BV = dyn_cast<VPInstruction>(A);
if (BV && BV->getOpcode() == VPInstruction::BuildVector) {
Def->replaceAllUsesWith(BV->getOperand(BV->getNumOperands() - Offset));
return;
}
}

// Some simplifications can only be applied after unrolling. Perform them
// below.
if (!Plan->isUnrolled())
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPlanTransforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ struct VPlanTransforms {
/// Explicitly unroll \p Plan by \p UF.
static void unrollByUF(VPlan &Plan, unsigned UF, LLVMContext &Ctx);

/// Explicitly unroll VPReplicateRecipes outside of replicate regions by \p
/// VF.
static void unrollByVF(VPlan &Plan, ElementCount VF);

/// Optimize \p Plan based on \p BestVF and \p BestUF. This may restrict the
/// resulting plan to \p BestVF and \p BestUF.
static void optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF,
Expand Down
81 changes: 81 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "VPlan.h"
#include "VPlanAnalysis.h"
#include "VPlanCFG.h"
#include "VPlanHelpers.h"
#include "VPlanPatternMatch.h"
#include "VPlanTransforms.h"
#include "VPlanUtils.h"
Expand Down Expand Up @@ -428,3 +429,83 @@ void VPlanTransforms::unrollByUF(VPlan &Plan, unsigned UF, LLVMContext &Ctx) {

VPlanTransforms::removeDeadRecipes(Plan);
}

/// Create a single-scalar clone of \p RepR for lane \p Lane.
static VPReplicateRecipe *cloneForLane(VPlan &Plan, VPBuilder &Builder,
                                       Type *IdxTy, VPReplicateRecipe *RepR,
                                       VPLane Lane) {
  // Gather the per-lane operands, materializing extracts for any operand that
  // is not already a single scalar.
  SmallVector<VPValue *> LaneOps;
  for (VPValue *Op : RepR->operands()) {
    // Uniform operands can be reused as-is for every lane.
    if (vputils::isSingleScalar(Op)) {
      LaneOps.push_back(Op);
      continue;
    }
    // The last lane of a scalable vector has no known constant index; use the
    // dedicated extract-last-element opcode instead.
    if (Lane.getKind() == VPLane::Kind::ScalableLast) {
      LaneOps.push_back(
          Builder.createNaryOp(VPInstruction::ExtractLastElement, {Op}));
      continue;
    }
    // Peek through BuildVector so we grab the original scalar directly rather
    // than inserting it into a vector only to extract it again.
    auto *BuildVec = dyn_cast<VPInstruction>(Op);
    if (BuildVec && BuildVec->getOpcode() == VPInstruction::BuildVector) {
      LaneOps.push_back(BuildVec->getOperand(Lane.getKnownLane()));
      continue;
    }
    VPValue *LaneIdx =
        Plan.getOrAddLiveIn(ConstantInt::get(IdxTy, Lane.getKnownLane()));
    LaneOps.push_back(
        Builder.createNaryOp(Instruction::ExtractElement, {Op, LaneIdx}));
  }

  auto *Clone =
      new VPReplicateRecipe(RepR->getUnderlyingInstr(), LaneOps,
                            /*IsSingleScalar=*/true, /*Mask=*/nullptr, *RepR);
  Clone->insertBefore(RepR);
  return Clone;
}

// Explicitly unroll wide VPReplicateRecipes outside replicate regions: each
// one is replaced by VF single-scalar clones (one per lane), with a
// Build(Struct)Vector recreating the vector value for any vector users.
void VPlanTransforms::unrollByVF(VPlan &Plan, ElementCount VF) {
  // i32 is used for the constant lane indices of the ExtractElement recipes
  // created by cloneForLane.
  Type *IdxTy = IntegerType::get(
      Plan.getScalarHeader()->getIRBasicBlock()->getContext(), 32);
  for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
           vp_depth_first_shallow(Plan.getVectorLoopRegion()->getEntry()))) {
    // early_inc_range: recipes may be erased while iterating.
    for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
      auto *RepR = dyn_cast<VPReplicateRecipe>(&R);
      // Only wide (per-lane) replicate recipes need unrolling.
      if (!RepR || RepR->isSingleScalar())
        continue;

      VPBuilder Builder(RepR);
      SmallVector<VPValue *> LaneDefs;
      // Stores to invariant addresses only need to store the last lane.
      if (isa<StoreInst>(RepR->getUnderlyingInstr()) &&
          vputils::isSingleScalar(RepR->getOperand(1))) {
        cloneForLane(Plan, Builder, IdxTy, RepR, VPLane::getLastLaneForVF(VF));
        RepR->eraseFromParent();
        continue;
      }

      // Create single-scalar version of RepR for all lanes.
      for (unsigned I = 0; I != VF.getKnownMinValue(); ++I)
        LaneDefs.push_back(cloneForLane(Plan, Builder, IdxTy, RepR, VPLane(I)));

      // Users that only demand the first lane can use the definition for lane
      // 0 directly, avoiding a redundant extract.
      RepR->replaceUsesWithIf(LaneDefs[0], [RepR](VPUser &U, unsigned) {
        return U.onlyFirstLaneUsed(RepR);
      });

      Type *ResTy = RepR->getUnderlyingInstr()->getType();
      // If needed, create a Build(Struct)Vector recipe to insert the scalar
      // lane values into a vector, for the remaining (vector) users.
      if (!ResTy->isVoidTy()) {
        VPValue *VecRes = Builder.createNaryOp(
            ResTy->isStructTy() ? VPInstruction::BuildStructVector
                                : VPInstruction::BuildVector,
            LaneDefs);
        RepR->replaceAllUsesWith(VecRes);
      }
      RepR->eraseFromParent();
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -398,12 +398,6 @@ define void @test_for_tried_to_force_scalar(ptr noalias %A, ptr noalias %B, ptr
; CHECK-NEXT: [[STRIDED_VEC:%.*]] = shufflevector <12 x float> [[WIDE_VEC]], <12 x float> poison, <4 x i32> <i32 0, i32 3, i32 6, i32 9>
; CHECK-NEXT: [[TMP30:%.*]] = extractelement <4 x float> [[STRIDED_VEC]], i32 3
; CHECK-NEXT: store float [[TMP30]], ptr [[C:%.*]], align 4
; CHECK-NEXT: [[TMP31:%.*]] = extractelement <4 x ptr> [[TMP29]], i32 0
; CHECK-NEXT: [[TMP38:%.*]] = load float, ptr [[TMP31]], align 4
; CHECK-NEXT: [[TMP33:%.*]] = extractelement <4 x ptr> [[TMP29]], i32 1
; CHECK-NEXT: [[TMP32:%.*]] = load float, ptr [[TMP33]], align 4
; CHECK-NEXT: [[TMP35:%.*]] = extractelement <4 x ptr> [[TMP29]], i32 2
; CHECK-NEXT: [[TMP34:%.*]] = load float, ptr [[TMP35]], align 4
; CHECK-NEXT: [[TMP37:%.*]] = extractelement <4 x ptr> [[TMP29]], i32 3
; CHECK-NEXT: [[TMP36:%.*]] = load float, ptr [[TMP37]], align 4
; CHECK-NEXT: store float [[TMP36]], ptr [[B:%.*]], align 4
Expand Down
Loading
Loading