[PatternMatch][VPlan] Add std::function match overload. NFCI #146374

Open · wants to merge 1 commit into main

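In brief (summarized from the diff below): the patch adds match(Pattern) overloads to PatternMatch and VPlanPatternMatch that return a unary predicate, so call sites can drop single-purpose lambdas when combined with all_of/any_of/none_of. For example, from InstructionSimplify.cpp:

  // Before:
  if (any_of(Ops, [](Value *V) { return match(V, m_Poison()); }))
    return PoisonValue::get(Ops[0]->getType());

  // After:
  if (any_of(Ops, match(m_Poison())))
    return PoisonValue::get(Ops[0]->getType());
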
5 changes: 5 additions & 0 deletions llvm/include/llvm/IR/PatternMatch.h
@@ -50,6 +50,11 @@ template <typename Val, typename Pattern> bool match(Val *V, const Pattern &P) {
   return P.match(V);
 }

+template <typename Val = const Value, typename Pattern>
+std::function<bool(Val *)> match(const Pattern &P) {
Review comment (Contributor):

Needs to be implemented in a way that does not use std::function. You probably need to create a dedicated functor class for this.

I'm also not sure this should be just an overload of match(). Maybe something more explicit like match_fn()?

+  return [&P](Val *V) { return P.match(V); };
+}
+
 template <typename Pattern> bool match(ArrayRef<int> Mask, const Pattern &P) {
   return P.match(Mask);
 }
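The reviewer's alternative could look something like the following — a minimal sketch inside namespace llvm::PatternMatch, assuming the match_fn name floated above; neither the name nor the exact shape is part of this PR. A dedicated functor avoids std::function's type erasure (indirect call, possible heap allocation, lost inlining), and storing the pattern by value sidesteps lifetime questions about the by-reference lambda capture:

  // Hypothetical functor-based replacement for the std::function overload.
  template <typename Pattern> struct MatchFn {
    Pattern P; // stored by value: patterns are small, and nothing can dangle
    template <typename Val> bool operator()(Val *V) const { return P.match(V); }
  };

  template <typename Pattern> MatchFn<Pattern> match_fn(const Pattern &P) {
    return MatchFn<Pattern>{P};
  }

  // Call sites would read the same as with the overload in this PR:
  //   if (all_of(Indices, match_fn(m_Zero())))
  //     return Ptr;
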
16 changes: 6 additions & 10 deletions llvm/lib/Analysis/InstructionSimplify.cpp
@@ -5028,14 +5028,12 @@ static Value *simplifyGEPInst(Type *SrcTy, Value *Ptr,
   }

   // All-zero GEP is a no-op, unless it performs a vector splat.
-  if (Ptr->getType() == GEPTy &&
-      all_of(Indices, [](const auto *V) { return match(V, m_Zero()); }))
+  if (Ptr->getType() == GEPTy && all_of(Indices, match(m_Zero())))
     return Ptr;

   // getelementptr poison, idx -> poison
   // getelementptr baseptr, poison -> poison
-  if (isa<PoisonValue>(Ptr) ||
-      any_of(Indices, [](const auto *V) { return isa<PoisonValue>(V); }))
+  if (isa<PoisonValue>(Ptr) || any_of(Indices, match(m_Poison())))
     return PoisonValue::get(GEPTy);

   // getelementptr undef, idx -> undef
@@ -5092,8 +5090,7 @@ static Value *simplifyGEPInst(Type *SrcTy, Value *Ptr,
   }

   if (!IsScalableVec && Q.DL.getTypeAllocSize(LastType) == 1 &&
-      all_of(Indices.drop_back(1),
-             [](Value *Idx) { return match(Idx, m_Zero()); })) {
+      all_of(Indices.drop_back(1), match(m_Zero()))) {
     unsigned IdxWidth =
         Q.DL.getIndexSizeInBits(Ptr->getType()->getPointerAddressSpace());
     if (Q.DL.getTypeSizeInBits(Indices.back()->getType()) == IdxWidth) {
@@ -5123,8 +5120,7 @@ static Value *simplifyGEPInst(Type *SrcTy, Value *Ptr,
   }

   // Check to see if this is constant foldable.
-  if (!isa<Constant>(Ptr) ||
-      !all_of(Indices, [](Value *V) { return isa<Constant>(V); }))
+  if (!isa<Constant>(Ptr) || !all_of(Indices, match(m_Constant())))
     return nullptr;

   if (!ConstantExpr::isSupportedGetElementPtr(SrcTy))
@@ -5649,7 +5645,7 @@ static Constant *simplifyFPOp(ArrayRef<Value *> Ops, FastMathFlags FMF,
                               RoundingMode Rounding) {
   // Poison is independent of anything else. It always propagates from an
   // operand to a math result.
-  if (any_of(Ops, [](Value *V) { return match(V, m_Poison()); }))
+  if (any_of(Ops, match(m_Poison())))
     return PoisonValue::get(Ops[0]->getType());

   for (Value *V : Ops) {
@@ -7116,7 +7112,7 @@ static Value *simplifyInstructionWithOperands(Instruction *I,

   switch (I->getOpcode()) {
   default:
-    if (llvm::all_of(NewOps, [](Value *V) { return isa<Constant>(V); })) {
+    if (all_of(NewOps, match(m_Constant()))) {
       SmallVector<Constant *, 8> NewConstOps(NewOps.size());
       transform(NewOps, NewConstOps.begin(),
                 [](Value *V) { return cast<Constant>(V); });
5 changes: 2 additions & 3 deletions llvm/lib/Analysis/ValueTracking.cpp
@@ -251,9 +251,8 @@ bool llvm::haveNoCommonBitsSet(const WithCache<const Value *> &LHSCache,
 }

 bool llvm::isOnlyUsedInZeroComparison(const Instruction *I) {
-  return !I->user_empty() && all_of(I->users(), [](const User *U) {
-           return match(U, m_ICmp(m_Value(), m_Zero()));
-         });
+  return !I->user_empty() &&
+         all_of(I->users(), match(m_ICmp(m_Value(), m_Zero())));
 }

 bool llvm::isOnlyUsedInZeroEqualityComparison(const Instruction *I) {
7 changes: 3 additions & 4 deletions llvm/lib/CodeGen/InterleavedAccessPass.cpp
@@ -294,10 +294,9 @@ bool InterleavedAccessImpl::lowerInterleavedLoad(
       continue;
     }
     if (auto *BI = dyn_cast<BinaryOperator>(User)) {
-      if (!BI->user_empty() && all_of(BI->users(), [](auto *U) {
-            auto *SVI = dyn_cast<ShuffleVectorInst>(U);
-            return SVI && isa<UndefValue>(SVI->getOperand(1));
-          })) {
+      using namespace PatternMatch;
+      if (!BI->user_empty() &&
+          all_of(BI->users(), match(m_Shuffle(m_Value(), m_Undef())))) {
         for (auto *SVI : BI->users())
           BinOpShuffles.insert(cast<ShuffleVectorInst>(SVI));
         continue;
8 changes: 2 additions & 6 deletions llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -2307,12 +2307,8 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
   // and let's try to sink `(sub 0, b)` into `b` itself. But only if this isn't
   // a pure negation used by a select that looks like abs/nabs.
   bool IsNegation = match(Op0, m_ZeroInt());
-  if (!IsNegation || none_of(I.users(), [&I, Op1](const User *U) {
-        const Instruction *UI = dyn_cast<Instruction>(U);
-        if (!UI)
-          return false;
-        return match(UI, m_c_Select(m_Specific(Op1), m_Specific(&I)));
-      })) {
+  if (!IsNegation ||
+      none_of(I.users(), match(m_c_Select(m_Specific(Op1), m_Specific(&I))))) {
     if (Value *NegOp1 = Negator::Negate(IsNegation, /* IsNSW */ IsNegation &&
                                                         I.hasNoSignedWrap(),
                                         Op1, *this))
4 changes: 1 addition & 3 deletions llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1418,9 +1418,7 @@ InstCombinerImpl::foldShuffledIntrinsicOperands(IntrinsicInst *II) {

   // At least 1 operand must be a shuffle with 1 use because we are creating 2
   // instructions.
-  if (none_of(II->args(), [](Value *V) {
-        return isa<ShuffleVectorInst>(V) && V->hasOneUse();
-      }))
+  if (none_of(II->args(), match(m_OneUse(m_Shuffle(m_Value(), m_Value())))))
     return nullptr;

   // See if all arguments are shuffled with the same mask.
2 changes: 1 addition & 1 deletion llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1341,7 +1341,7 @@ Instruction *InstCombinerImpl::foldICmpWithConstant(ICmpInst &Cmp) {
     return nullptr;

   if (auto *Phi = dyn_cast<PHINode>(Op0))
-    if (all_of(Phi->operands(), [](Value *V) { return isa<Constant>(V); })) {
+    if (all_of(Phi->operands(), match(m_Constant()))) {
       SmallVector<Constant *> Ops;
       for (Value *V : Phi->incoming_values()) {
         Constant *Res =
4 changes: 2 additions & 2 deletions llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
@@ -339,7 +339,7 @@ bool InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) {
 Instruction *InstCombinerImpl::foldPHIArgIntToPtrToPHI(PHINode &PN) {
   // convert ptr2int ( phi[ int2ptr(ptr2int(x))] ) --> ptr2int ( phi [ x ] )
   // Make sure all uses of phi are ptr2int.
-  if (!all_of(PN.users(), [](User *U) { return isa<PtrToIntInst>(U); }))
+  if (!all_of(PN.users(), match(m_PtrToInt(m_Value()))))
     return nullptr;

   // Iterating over all operands to check presence of target pointers for
@@ -1298,7 +1298,7 @@ static Value *simplifyUsingControlFlow(InstCombiner &Self, PHINode &PN,
   //      \   /
   //   phi [v1] [v2]
   // Make sure all inputs are constants.
-  if (!all_of(PN.operands(), [](Value *V) { return isa<ConstantInt>(V); }))
+  if (!all_of(PN.operands(), match(m_ConstantInt())))
     return nullptr;

   BasicBlock *BB = PN.getParent();
2 changes: 1 addition & 1 deletion llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3142,7 +3142,7 @@ static Instruction *foldNestedSelects(SelectInst &OuterSelVal,

   // Profitability check - avoid increasing instruction count.
   if (none_of(ArrayRef<Value *>({OuterSelVal.getCondition(), InnerSelVal}),
-              [](Value *V) { return V->hasOneUse(); }))
+              match(m_OneUse(m_Value()))))
     return nullptr;

   // The appropriate hand of the outermost `select` must be a select itself.
7 changes: 3 additions & 4 deletions llvm/lib/Transforms/Scalar/LICM.cpp
@@ -435,10 +435,9 @@ bool LoopInvariantCodeMotion::runOnLoop(Loop *L, AAResults *AA, LoopInfo *LI,
   // potentially happen in other passes where instructions are being moved
   // across that edge.
   bool HasCoroSuspendInst = llvm::any_of(L->getBlocks(), [](BasicBlock *BB) {
-    return llvm::any_of(*BB, [](Instruction &I) {
-      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
-      return II && II->getIntrinsicID() == Intrinsic::coro_suspend;
-    });
+    using namespace PatternMatch;
+    return any_of(make_pointer_range(*BB),
+                  match(m_Intrinsic<Intrinsic::coro_suspend>()));
   });

   MemorySSAUpdater MSSAU(MSSA);
35 changes: 16 additions & 19 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7036,11 +7036,12 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
       // Unused FOR splices are removed by VPlan transforms, so the VPlan-based
       // cost model won't cost it whilst the legacy will.
       if (auto *FOR = dyn_cast<VPFirstOrderRecurrencePHIRecipe>(&R)) {
-        if (none_of(FOR->users(), [](VPUser *U) {
-              auto *VPI = dyn_cast<VPInstruction>(U);
-              return VPI && VPI->getOpcode() ==
-                                VPInstruction::FirstOrderRecurrenceSplice;
-            }))
+        using namespace VPlanPatternMatch;
+        if (none_of(
+                FOR->users(),
+                match(
+                    m_VPInstruction<VPInstruction::FirstOrderRecurrenceSplice>(
+                        m_VPValue(), m_VPValue()))))
           return true;
       }
       // The VPlan-based cost model is more accurate for partial reduction and
@@ -7449,13 +7450,11 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
   Hints.setAlreadyVectorized();

   // Check if it's EVL-vectorized and mark the corresponding metadata.
+  using namespace VPlanPatternMatch;
   bool IsEVLVectorized =
-      llvm::any_of(*HeaderVPBB, [](const VPRecipeBase &Recipe) {
-        // Looking for the ExplictVectorLength VPInstruction.
-        if (const auto *VI = dyn_cast<VPInstruction>(&Recipe))
-          return VI->getOpcode() == VPInstruction::ExplicitVectorLength;
-        return false;
-      });
+      any_of(make_pointer_range(*HeaderVPBB),
+             match(m_VPInstruction<VPInstruction::ExplicitVectorLength>(
+                 m_VPValue())));
   if (IsEVLVectorized) {
     LLVMContext &Context = L->getHeader()->getContext();
     MDNode *LoopID = L->getLoopID();
@@ -9737,10 +9736,9 @@ static void preparePlanForMainVectorLoop(VPlan &MainPlan, VPlan &EpiPlan) {
   // If there is a suitable resume value for the canonical induction in the
   // scalar (which will become vector) epilogue loop we are done. Otherwise
   // create it below.
-  if (any_of(*MainScalarPH, [VectorTC](VPRecipeBase &R) {
-        return match(&R, m_VPInstruction<Instruction::PHI>(m_Specific(VectorTC),
-                                                           m_SpecificInt(0)));
-      }))
+  if (any_of(make_pointer_range(*MainScalarPH),
+             match(m_VPInstruction<Instruction::PHI>(m_Specific(VectorTC),
+                                                     m_SpecificInt(0)))))
     return;
   VPBuilder ScalarPHBuilder(MainScalarPH, MainScalarPH->begin());
   ScalarPHBuilder.createScalarPhi(
@@ -9778,10 +9776,9 @@ preparePlanForEpilogueVectorLoop(VPlan &Plan, Loop *L,
             match(
                 P.getIncomingValueForBlock(EPI.MainLoopIterationCountCheck),
                 m_SpecificInt(0)) &&
-            all_of(P.incoming_values(), [&EPI](Value *Inc) {
-              return Inc == EPI.VectorTripCount ||
-                     match(Inc, m_SpecificInt(0));
-            }))
+            all_of(P.incoming_values(),
+                   match(m_CombineOr(m_Specific(EPI.VectorTripCount),
+                                     m_SpecificInt(0)))))
           return &P;
         return nullptr;
       });
7 changes: 3 additions & 4 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -20708,10 +20708,9 @@ void BoUpSLP::computeMinimumValueSizes() {
     IsTruncRoot = true;
   }
   bool IsSignedCmp = false;
-  if (UserIgnoreList && all_of(*UserIgnoreList, [](Value *V) {
-        return match(V, m_SMin(m_Value(), m_Value())) ||
-               match(V, m_SMax(m_Value(), m_Value()));
-      }))
+  if (UserIgnoreList &&
+      all_of(*UserIgnoreList, match(m_CombineOr(m_SMin(m_Value(), m_Value()),
+                                                m_SMax(m_Value(), m_Value())))))
     IsSignedCmp = true;
   while (NodeIdx < VectorizableTree.size()) {
     ArrayRef<Value *> TreeRoot = VectorizableTree[NodeIdx]->Scalars;
10 changes: 10 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
@@ -29,11 +29,21 @@ template <typename Val, typename Pattern> bool match(Val *V, const Pattern &P) {
   return P.match(V);
 }

+template <typename Val, typename Pattern>
+std::function<bool(Val *)> match(const Pattern &P) {
+  return [&P](Val *V) { return P.match(V); };
+}
+
 template <typename Pattern> bool match(VPUser *U, const Pattern &P) {
   auto *R = dyn_cast<VPRecipeBase>(U);
   return R && match(R, P);
 }

+template <typename Pattern>
+std::function<bool(VPUser *)> match(const Pattern &P) {
+  return [&P](VPUser *U) { return match(U, P); };
+}
+
 template <typename Class> struct class_match {
   template <typename ITy> bool match(ITy *V) const { return isa<Class>(V); }
 };
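One subtlety worth noting about both new overloads (an observation about the C++ here, not a point raised in the review thread): the lambda captures the pattern by reference, so the returned callable is only safe to use within the full expression that built it. A sketch of the distinction, with hypothetical variable names:

  // Fine: the m_Poison() temporary outlives the any_of call.
  if (any_of(Ops, match(m_Poison())))
    return PoisonValue::get(Ops[0]->getType());

  // Dangling: the temporary pattern dies at the end of this statement,
  // so IsPoison holds a reference to a destroyed object.
  auto IsPoison = match(m_Poison());
  bool Bad = any_of(Ops, IsPoison); // use of a dangling reference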