Skip to content

Commit 7e878aa

Browse files
authored
[PatternMatch] Add support for capture-and-match (NFC) (#149825)
When using PatternMatch, there is a common problem where we want to both match something against a pattern, but also capture the value/instruction for various reasons (e.g. to access flags). Currently the two ways to do that is to either capture using m_Value/m_Instruction and do a separate match on the result, or to use the somewhat awkward `m_CombineAnd(m_XYZ, m_Value(V))` pattern. This PR introduces to add a variant of `m_Value`/`m_Instruction` which does both a capture and a match. `m_Value(V, m_XYZ)` is basically equivalent to `m_CombineAnd(m_XYZ, m_Value(V))`. I've ported two InstCombine files to this pattern as a sample.
1 parent b7889a6 commit 7e878aa

File tree

3 files changed

+77
-48
lines changed

3 files changed

+77
-48
lines changed

llvm/include/llvm/IR/PatternMatch.h

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -822,12 +822,52 @@ template <typename Class> struct bind_ty {
822822
}
823823
};
824824

825+
/// Check whether the value has the given Class and matches the nested
826+
/// pattern. Capture it into the provided variable if successful.
827+
template <typename Class, typename MatchTy> struct bind_and_match_ty {
828+
Class *&VR;
829+
MatchTy Match;
830+
831+
bind_and_match_ty(Class *&V, const MatchTy &Match) : VR(V), Match(Match) {}
832+
833+
template <typename ITy> bool match(ITy *V) const {
834+
auto *CV = dyn_cast<Class>(V);
835+
if (CV && Match.match(V)) {
836+
VR = CV;
837+
return true;
838+
}
839+
return false;
840+
}
841+
};
842+
825843
/// Match a value, capturing it if we match.
826844
inline bind_ty<Value> m_Value(Value *&V) { return V; }
827845
inline bind_ty<const Value> m_Value(const Value *&V) { return V; }
828846

847+
/// Match against the nested pattern, and capture the value if we match.
848+
template <typename MatchTy>
849+
inline bind_and_match_ty<Value, MatchTy> m_Value(Value *&V,
850+
const MatchTy &Match) {
851+
return {V, Match};
852+
}
853+
854+
/// Match against the nested pattern, and capture the value if we match.
855+
template <typename MatchTy>
856+
inline bind_and_match_ty<const Value, MatchTy> m_Value(const Value *&V,
857+
const MatchTy &Match) {
858+
return {V, Match};
859+
}
860+
829861
/// Match an instruction, capturing it if we match.
830862
inline bind_ty<Instruction> m_Instruction(Instruction *&I) { return I; }
863+
864+
/// Match against the nested pattern, and capture the instruction if we match.
865+
template <typename MatchTy>
866+
inline bind_and_match_ty<Instruction, MatchTy>
867+
m_Instruction(Instruction *&I, const MatchTy &Match) {
868+
return {I, Match};
869+
}
870+
831871
/// Match a unary operator, capturing it if we match.
832872
inline bind_ty<UnaryOperator> m_UnOp(UnaryOperator *&I) { return I; }
833873
/// Match a binary operator, capturing it if we match.

llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1355,9 +1355,9 @@ Instruction *InstCombinerImpl::
13551355
// right-shift of X and a "select".
13561356
Value *X, *Select;
13571357
Instruction *LowBitsToSkip, *Extract;
1358-
if (!match(&I, m_c_BinOp(m_TruncOrSelf(m_CombineAnd(
1359-
m_LShr(m_Value(X), m_Instruction(LowBitsToSkip)),
1360-
m_Instruction(Extract))),
1358+
if (!match(&I, m_c_BinOp(m_TruncOrSelf(m_Instruction(
1359+
Extract, m_LShr(m_Value(X),
1360+
m_Instruction(LowBitsToSkip)))),
13611361
m_Value(Select))))
13621362
return nullptr;
13631363

@@ -1763,13 +1763,12 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
17631763
Constant *C;
17641764
// (add X, (sext/zext (icmp eq X, C)))
17651765
// -> (select (icmp eq X, C), (add C, (sext/zext 1)), X)
1766-
auto CondMatcher = m_CombineAnd(
1767-
m_Value(Cond),
1768-
m_SpecificICmp(ICmpInst::ICMP_EQ, m_Deferred(A), m_ImmConstant(C)));
1766+
auto CondMatcher =
1767+
m_Value(Cond, m_SpecificICmp(ICmpInst::ICMP_EQ, m_Deferred(A),
1768+
m_ImmConstant(C)));
17691769

17701770
if (match(&I,
1771-
m_c_Add(m_Value(A),
1772-
m_CombineAnd(m_Value(Ext), m_ZExtOrSExt(CondMatcher)))) &&
1771+
m_c_Add(m_Value(A), m_Value(Ext, m_ZExtOrSExt(CondMatcher)))) &&
17731772
Ext->hasOneUse()) {
17741773
Value *Add = isa<ZExtInst>(Ext) ? InstCombiner::AddOne(C)
17751774
: InstCombiner::SubOne(C);

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 30 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2025,10 +2025,9 @@ static Instruction *foldComplexAndOrPatterns(BinaryOperator &I,
20252025
if (CountUses && !Op->hasOneUse())
20262026
return false;
20272027

2028-
if (match(Op, m_c_BinOp(FlippedOpcode,
2029-
m_CombineAnd(m_Value(X),
2030-
m_Not(m_c_BinOp(Opcode, m_A, m_B))),
2031-
m_C)))
2028+
if (match(Op,
2029+
m_c_BinOp(FlippedOpcode,
2030+
m_Value(X, m_Not(m_c_BinOp(Opcode, m_A, m_B))), m_C)))
20322031
return !CountUses || X->hasOneUse();
20332032

20342033
return false;
@@ -2079,10 +2078,10 @@ static Instruction *foldComplexAndOrPatterns(BinaryOperator &I,
20792078
// result is more undefined than a source:
20802079
// (~(A & B) | C) & ~(C & (A ^ B)) --> (A ^ B ^ C) | ~(A | C) is invalid.
20812080
if (Opcode == Instruction::Or && Op0->hasOneUse() &&
2082-
match(Op1, m_OneUse(m_Not(m_CombineAnd(
2083-
m_Value(Y),
2084-
m_c_BinOp(Opcode, m_Specific(C),
2085-
m_c_Xor(m_Specific(A), m_Specific(B)))))))) {
2081+
match(Op1,
2082+
m_OneUse(m_Not(m_Value(
2083+
Y, m_c_BinOp(Opcode, m_Specific(C),
2084+
m_c_Xor(m_Specific(A), m_Specific(B)))))))) {
20862085
// X = ~(A | B)
20872086
// Y = (C | (A ^ B)
20882087
Value *Or = cast<BinaryOperator>(X)->getOperand(0);
@@ -2098,12 +2097,11 @@ static Instruction *foldComplexAndOrPatterns(BinaryOperator &I,
20982097
if (match(Op0,
20992098
m_OneUse(m_c_BinOp(FlippedOpcode,
21002099
m_BinOp(FlippedOpcode, m_Value(B), m_Value(C)),
2101-
m_CombineAnd(m_Value(X), m_Not(m_Value(A)))))) ||
2102-
match(Op0, m_OneUse(m_c_BinOp(
2103-
FlippedOpcode,
2104-
m_c_BinOp(FlippedOpcode, m_Value(C),
2105-
m_CombineAnd(m_Value(X), m_Not(m_Value(A)))),
2106-
m_Value(B))))) {
2100+
m_Value(X, m_Not(m_Value(A)))))) ||
2101+
match(Op0, m_OneUse(m_c_BinOp(FlippedOpcode,
2102+
m_c_BinOp(FlippedOpcode, m_Value(C),
2103+
m_Value(X, m_Not(m_Value(A)))),
2104+
m_Value(B))))) {
21072105
// X = ~A
21082106
// (~A & B & C) | ~(A | B | C) --> ~(A | (B ^ C))
21092107
// (~A | B | C) & ~(A & B & C) --> (~A | (B ^ C))
@@ -2434,8 +2432,7 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
24342432
// (-(X & 1)) & Y --> (X & 1) == 0 ? 0 : Y
24352433
Value *Neg;
24362434
if (match(&I,
2437-
m_c_And(m_CombineAnd(m_Value(Neg),
2438-
m_OneUse(m_Neg(m_And(m_Value(), m_One())))),
2435+
m_c_And(m_Value(Neg, m_OneUse(m_Neg(m_And(m_Value(), m_One())))),
24392436
m_Value(Y)))) {
24402437
Value *Cmp = Builder.CreateIsNull(Neg);
24412438
return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), Y);
@@ -3728,9 +3725,8 @@ static Value *foldOrUnsignedUMulOverflowICmp(BinaryOperator &I,
37283725
const APInt *C1, *C2;
37293726
if (match(&I,
37303727
m_c_Or(m_ExtractValue<1>(
3731-
m_CombineAnd(m_Intrinsic<Intrinsic::umul_with_overflow>(
3732-
m_Value(X), m_APInt(C1)),
3733-
m_Value(WOV))),
3728+
m_Value(WOV, m_Intrinsic<Intrinsic::umul_with_overflow>(
3729+
m_Value(X), m_APInt(C1)))),
37343730
m_OneUse(m_SpecificCmp(ICmpInst::ICMP_UGT,
37353731
m_ExtractValue<0>(m_Deferred(WOV)),
37363732
m_APInt(C2))))) &&
@@ -3988,12 +3984,12 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
39883984
// ~(B & ?) | (A ^ B) --> ~((B & ?) & A)
39893985
Instruction *And;
39903986
if ((Op0->hasOneUse() || Op1->hasOneUse()) &&
3991-
match(Op0, m_Not(m_CombineAnd(m_Instruction(And),
3992-
m_c_And(m_Specific(A), m_Value())))))
3987+
match(Op0,
3988+
m_Not(m_Instruction(And, m_c_And(m_Specific(A), m_Value())))))
39933989
return BinaryOperator::CreateNot(Builder.CreateAnd(And, B));
39943990
if ((Op0->hasOneUse() || Op1->hasOneUse()) &&
3995-
match(Op0, m_Not(m_CombineAnd(m_Instruction(And),
3996-
m_c_And(m_Specific(B), m_Value())))))
3991+
match(Op0,
3992+
m_Not(m_Instruction(And, m_c_And(m_Specific(B), m_Value())))))
39973993
return BinaryOperator::CreateNot(Builder.CreateAnd(And, A));
39983994

39993995
// (~A | C) | (A ^ B) --> ~(A & B) | C
@@ -4125,16 +4121,13 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
41254121
// treating any non-zero result as overflow. In that case, we overflow if both
41264122
// umul.with.overflow operands are != 0, as in that case the result can only
41274123
// be 0, iff the multiplication overflows.
4128-
if (match(&I,
4129-
m_c_Or(m_CombineAnd(m_ExtractValue<1>(m_Value(UMulWithOv)),
4130-
m_Value(Ov)),
4131-
m_CombineAnd(
4132-
m_SpecificICmp(ICmpInst::ICMP_NE,
4133-
m_CombineAnd(m_ExtractValue<0>(
4134-
m_Deferred(UMulWithOv)),
4135-
m_Value(Mul)),
4136-
m_ZeroInt()),
4137-
m_Value(MulIsNotZero)))) &&
4124+
if (match(&I, m_c_Or(m_Value(Ov, m_ExtractValue<1>(m_Value(UMulWithOv))),
4125+
m_Value(MulIsNotZero,
4126+
m_SpecificICmp(
4127+
ICmpInst::ICMP_NE,
4128+
m_Value(Mul, m_ExtractValue<0>(
4129+
m_Deferred(UMulWithOv))),
4130+
m_ZeroInt())))) &&
41384131
(Ov->hasOneUse() || (MulIsNotZero->hasOneUse() && Mul->hasOneUse()))) {
41394132
Value *A, *B;
41404133
if (match(UMulWithOv, m_Intrinsic<Intrinsic::umul_with_overflow>(
@@ -4151,9 +4144,8 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
41514144
const WithOverflowInst *WO;
41524145
const Value *WOV;
41534146
const APInt *C1, *C2;
4154-
if (match(&I, m_c_Or(m_CombineAnd(m_ExtractValue<1>(m_CombineAnd(
4155-
m_WithOverflowInst(WO), m_Value(WOV))),
4156-
m_Value(Ov)),
4147+
if (match(&I, m_c_Or(m_Value(Ov, m_ExtractValue<1>(
4148+
m_Value(WOV, m_WithOverflowInst(WO)))),
41574149
m_OneUse(m_ICmp(Pred, m_ExtractValue<0>(m_Deferred(WOV)),
41584150
m_APInt(C2))))) &&
41594151
(WO->getBinaryOp() == Instruction::Add ||
@@ -4501,8 +4493,7 @@ static Instruction *visitMaskedMerge(BinaryOperator &I,
45014493
Value *M;
45024494
if (!match(&I, m_c_Xor(m_Value(B),
45034495
m_OneUse(m_c_And(
4504-
m_CombineAnd(m_c_Xor(m_Deferred(B), m_Value(X)),
4505-
m_Value(D)),
4496+
m_Value(D, m_c_Xor(m_Deferred(B), m_Value(X))),
45064497
m_Value(M))))))
45074498
return nullptr;
45084499

@@ -5206,8 +5197,7 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
52065197
// (X ^ C) ^ Y --> (X ^ Y) ^ C
52075198
// Just like we do in other places, we completely avoid the fold
52085199
// for constantexprs, at least to avoid endless combine loop.
5209-
if (match(&I, m_c_Xor(m_OneUse(m_Xor(m_CombineAnd(m_Value(X),
5210-
m_Unless(m_ConstantExpr())),
5200+
if (match(&I, m_c_Xor(m_OneUse(m_Xor(m_Value(X, m_Unless(m_ConstantExpr())),
52115201
m_ImmConstant(C1))),
52125202
m_Value(Y))))
52135203
return BinaryOperator::CreateXor(Builder.CreateXor(X, Y), C1);

0 commit comments

Comments
 (0)