-
Notifications
You must be signed in to change notification settings - Fork 14.1k
[LV] Support generating masks for switch terminators. #99808
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -6453,6 +6453,17 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, | |||||||
// a predicated block since it will become a fall-through, although we | ||||||||
// may decide in the future to call TTI for all branches. | ||||||||
} | ||||||||
case Instruction::Switch: { | ||||||||
if (VF.isScalar()) | ||||||||
return TTI.getCFInstrCost(Instruction::Switch, CostKind); | ||||||||
auto *Switch = cast<SwitchInst>(I); | ||||||||
return Switch->getNumCases() * | ||||||||
TTI.getCmpSelInstrCost( | ||||||||
Instruction::ICmp, | ||||||||
ToVectorTy(Switch->getCondition()->getType(), VF), | ||||||||
ToVectorTy(Type::getInt1Ty(I->getContext()), VF), | ||||||||
CmpInst::ICMP_EQ, CostKind); | ||||||||
} | ||||||||
case Instruction::PHI: { | ||||||||
auto *Phi = cast<PHINode>(I); | ||||||||
|
||||||||
|
@@ -7841,6 +7852,62 @@ VPRecipeBuilder::mapToVPValues(User::op_range Operands) { | |||||||
return map_range(Operands, Fn); | ||||||||
} | ||||||||
|
||||||||
void VPRecipeBuilder::createSwitchEdgeMasks(SwitchInst *SI) { | ||||||||
BasicBlock *Src = SI->getParent(); | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added but moved down to iterating over all cases, as we need to look up (Src, Dst) pairs, thanks! |
||||||||
assert(!OrigLoop->isLoopExiting(Src) && | ||||||||
all_of(successors(Src), | ||||||||
[this](BasicBlock *Succ) { | ||||||||
return OrigLoop->getHeader() != Succ; | ||||||||
}) && | ||||||||
"unsupported switch either exiting loop or continuing to header"); | ||||||||
// Create masks where the terminator in Src is a switch. We create mask for | ||||||||
// all edges at the same time. This is more efficient, as we can create and | ||||||||
// collect compares for all cases once. | ||||||||
VPValue *Cond = getVPValueOrAddLiveIn(SI->getCondition(), Plan); | ||||||||
BasicBlock *DefaultDst = SI->getDefaultDest(); | ||||||||
MapVector<BasicBlock *, SmallVector<VPValue *>> Dst2Compares; | ||||||||
for (auto &C : SI->cases()) { | ||||||||
BasicBlock *Dst = C.getCaseSuccessor(); | ||||||||
assert(!EdgeMaskCache.contains({Src, Dst}) && "Edge masks already created"); | ||||||||
// Cases whose destination is the same as default are redundant and can be | ||||||||
// ignored - they will get there anyhow. | ||||||||
if (Dst == DefaultDst) | ||||||||
continue; | ||||||||
auto I = Dst2Compares.insert({Dst, {}}); | ||||||||
VPValue *V = getVPValueOrAddLiveIn(C.getCaseValue(), Plan); | ||||||||
I.first->second.push_back(Builder.createICmp(CmpInst::ICMP_EQ, Cond, V)); | ||||||||
} | ||||||||
|
||||||||
// We need to handle 2 separate cases below for all entries in Dst2Compares, | ||||||||
// which excludes destinations matching the default destination. | ||||||||
VPValue *SrcMask = getBlockInMask(Src); | ||||||||
VPValue *DefaultMask = nullptr; | ||||||||
for (const auto &[Dst, Conds] : Dst2Compares) { | ||||||||
// 1. Dst is not the default destination. Dst is reached if any of the cases | ||||||||
// with destination == Dst are taken. Join the conditions for each case | ||||||||
// whose destination == Dst using an OR. | ||||||||
VPValue *Mask = Conds[0]; | ||||||||
for (VPValue *V : ArrayRef<VPValue *>(Conds).drop_front()) | ||||||||
Mask = Builder.createOr(Mask, V); | ||||||||
if (SrcMask) | ||||||||
Mask = Builder.createLogicalAnd(SrcMask, Mask); | ||||||||
EdgeMaskCache[{Src, Dst}] = Mask; | ||||||||
|
||||||||
// 2. Create the mask for the default destination, which is reached if none | ||||||||
// of the cases with destination != default destination are taken. Join the | ||||||||
// conditions for each case where the destination is != Dst using an OR and | ||||||||
// negate it. | ||||||||
DefaultMask = DefaultMask ? Builder.createOr(DefaultMask, Mask) : Mask; | ||||||||
} | ||||||||
|
||||||||
if (DefaultMask) { | ||||||||
DefaultMask = Builder.createNot(DefaultMask); | ||||||||
if (SrcMask) | ||||||||
DefaultMask = Builder.createLogicalAnd(SrcMask, DefaultMask); | ||||||||
} | ||||||||
EdgeMaskCache[{Src, DefaultDst}] = DefaultMask; | ||||||||
} | ||||||||
|
||||||||
VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst) { | ||||||||
assert(is_contained(predecessors(Dst), Src) && "Invalid edge"); | ||||||||
|
||||||||
|
@@ -7850,12 +7917,17 @@ VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst) { | |||||||
if (ECEntryIt != EdgeMaskCache.end()) | ||||||||
return ECEntryIt->second; | ||||||||
|
||||||||
if (auto *SI = dyn_cast<SwitchInst>(Src->getTerminator())) { | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Worth asserting that SI stays in the same loop iteration, rather than breaking or continuing to its header? E.g., that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added assert here and check to LoopVectorizationLegality, thanks! |
||||||||
createSwitchEdgeMasks(SI); | ||||||||
assert(EdgeMaskCache.contains(Edge) && "Mask for Edge not created?"); | ||||||||
return EdgeMaskCache[Edge]; | ||||||||
} | ||||||||
|
||||||||
VPValue *SrcMask = getBlockInMask(Src); | ||||||||
|
||||||||
// The terminator has to be a branch inst! | ||||||||
BranchInst *BI = dyn_cast<BranchInst>(Src->getTerminator()); | ||||||||
assert(BI && "Unexpected terminator found"); | ||||||||
|
||||||||
if (!BI->isConditional() || BI->getSuccessor(0) == BI->getSuccessor(1)) | ||||||||
return EdgeMaskCache[Edge] = SrcMask; | ||||||||
|
||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The above scalar cost seems right, wonder about the vector cost below - the cost associated with predicating conditional branches is collected when visiting each phi, rather than the branch itself. May be good to calibrate with some tests, can leave behind a TODO to be done separately.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The vector code matches the cost of the generated masks, which h are costed explicitly for the version with branches due to the compares being explicit instructions. Currently it seems more like the scalar cost may be estimated by getCFInstrCost, but that probably would need to be fixed in TTI.