Skip to content

[LV] Support generating masks for switch terminators. #99808

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1340,12 +1340,21 @@ bool LoopVectorizationLegality::canVectorizeWithIfConvert() {

// Collect the blocks that need predication.
for (BasicBlock *BB : TheLoop->blocks()) {
// We don't support switch statements inside loops.
if (!isa<BranchInst>(BB->getTerminator())) {
reportVectorizationFailure("Loop contains a switch statement",
"loop contains a switch statement",
"LoopContainsSwitch", ORE, TheLoop,
BB->getTerminator());
// We support only branches and switch statements as terminators inside the
// loop.
if (isa<SwitchInst>(BB->getTerminator())) {
if (TheLoop->isLoopExiting(BB)) {
reportVectorizationFailure("Loop contains an unsupported switch",
"loop contains an unsupported switch",
"LoopContainsUnsupportedSwitch", ORE,
TheLoop, BB->getTerminator());
return false;
}
} else if (!isa<BranchInst>(BB->getTerminator())) {
reportVectorizationFailure("Loop contains an unsupported terminator",
"loop contains an unsupported terminator",
"LoopContainsUnsupportedTerminator", ORE,
TheLoop, BB->getTerminator());
return false;
}

Expand Down
74 changes: 73 additions & 1 deletion llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6453,6 +6453,17 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I,
// a predicated block since it will become a fall-through, although we
// may decide in the future to call TTI for all branches.
}
case Instruction::Switch: {
if (VF.isScalar())
return TTI.getCFInstrCost(Instruction::Switch, CostKind);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The above scalar cost seems right, wonder about the vector cost below - the cost associated with predicating conditional branches is collected when visiting each phi, rather than the branch itself. May be good to calibrate with some tests, can leave behind a TODO to be done separately.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The vector code matches the cost of the generated masks, which h are costed explicitly for the version with branches due to the compares being explicit instructions. Currently it seems more like the scalar cost may be estimated by getCFInstrCost, but that probably would need to be fixed in TTI.

auto *Switch = cast<SwitchInst>(I);
return Switch->getNumCases() *
TTI.getCmpSelInstrCost(
Instruction::ICmp,
ToVectorTy(Switch->getCondition()->getType(), VF),
ToVectorTy(Type::getInt1Ty(I->getContext()), VF),
CmpInst::ICMP_EQ, CostKind);
}
case Instruction::PHI: {
auto *Phi = cast<PHINode>(I);

Expand Down Expand Up @@ -7841,6 +7852,62 @@ VPRecipeBuilder::mapToVPValues(User::op_range Operands) {
return map_range(Operands, Fn);
}

void VPRecipeBuilder::createSwitchEdgeMasks(SwitchInst *SI) {
BasicBlock *Src = SI->getParent();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
BasicBlock *Src = SI->getParent();
BasicBlock *Src = SI->getParent();
assert(!EdgeMaskCache.contains(Src) && "Edge masks already created");

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added but moved down to iterating over all cases, as we need to look up (Src, Dst) pairs, thanks!

assert(!OrigLoop->isLoopExiting(Src) &&
all_of(successors(Src),
[this](BasicBlock *Succ) {
return OrigLoop->getHeader() != Succ;
}) &&
"unsupported switch either exiting loop or continuing to header");
// Create masks where the terminator in Src is a switch. We create mask for
// all edges at the same time. This is more efficient, as we can create and
// collect compares for all cases once.
VPValue *Cond = getVPValueOrAddLiveIn(SI->getCondition(), Plan);
BasicBlock *DefaultDst = SI->getDefaultDest();
MapVector<BasicBlock *, SmallVector<VPValue *>> Dst2Compares;
for (auto &C : SI->cases()) {
BasicBlock *Dst = C.getCaseSuccessor();
assert(!EdgeMaskCache.contains({Src, Dst}) && "Edge masks already created");
// Cases whose destination is the same as default are redundant and can be
// ignored - they will get there anyhow.
if (Dst == DefaultDst)
continue;
auto I = Dst2Compares.insert({Dst, {}});
VPValue *V = getVPValueOrAddLiveIn(C.getCaseValue(), Plan);
I.first->second.push_back(Builder.createICmp(CmpInst::ICMP_EQ, Cond, V));
}

// We need to handle 2 separate cases below for all entries in Dst2Compares,
// which excludes destinations matching the default destination.
VPValue *SrcMask = getBlockInMask(Src);
VPValue *DefaultMask = nullptr;
for (const auto &[Dst, Conds] : Dst2Compares) {
// 1. Dst is not the default destination. Dst is reached if any of the cases
// with destination == Dst are taken. Join the conditions for each case
// whose destination == Dst using an OR.
VPValue *Mask = Conds[0];
for (VPValue *V : ArrayRef<VPValue *>(Conds).drop_front())
Mask = Builder.createOr(Mask, V);
if (SrcMask)
Mask = Builder.createLogicalAnd(SrcMask, Mask);
EdgeMaskCache[{Src, Dst}] = Mask;

// 2. Create the mask for the default destination, which is reached if none
// of the cases with destination != default destination are taken. Join the
// conditions for each case where the destination is != Dst using an OR and
// negate it.
DefaultMask = DefaultMask ? Builder.createOr(DefaultMask, Mask) : Mask;
}

if (DefaultMask) {
DefaultMask = Builder.createNot(DefaultMask);
if (SrcMask)
DefaultMask = Builder.createLogicalAnd(SrcMask, DefaultMask);
}
EdgeMaskCache[{Src, DefaultDst}] = DefaultMask;
}

VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst) {
assert(is_contained(predecessors(Dst), Src) && "Invalid edge");

Expand All @@ -7850,12 +7917,17 @@ VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst) {
if (ECEntryIt != EdgeMaskCache.end())
return ECEntryIt->second;

if (auto *SI = dyn_cast<SwitchInst>(Src->getTerminator())) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worth asserting that SI stays in the same loop iteration, rather than breaking or continuing to its header? E.g., that !OrigLoop->isLoopExiting(Src).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added assert here and check to LoopVectorizationLegality, thanks!

createSwitchEdgeMasks(SI);
assert(EdgeMaskCache.contains(Edge) && "Mask for Edge not created?");
return EdgeMaskCache[Edge];
}

VPValue *SrcMask = getBlockInMask(Src);

// The terminator has to be a branch inst!
BranchInst *BI = dyn_cast<BranchInst>(Src->getTerminator());
assert(BI && "Unexpected terminator found");

if (!BI->isConditional() || BI->getSuccessor(0) == BI->getSuccessor(1))
return EdgeMaskCache[Edge] = SrcMask;

Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ class VPRecipeBuilder {
/// Returns the *entry* mask for the block \p BB.
VPValue *getBlockInMask(BasicBlock *BB) const;

/// Create an edge mask for every destination of cases and/or default.
void createSwitchEdgeMasks(SwitchInst *SI);

/// A helper function that computes the predicate of the edge between SRC
/// and DST.
VPValue *createEdgeMask(BasicBlock *Src, BasicBlock *Dst);
Expand Down
Loading
Loading