Skip to content

Commit 02f731d

Browse files
committed
Improve MLIR-Query by adding matcher combinators
Limit backward-slice with nested matching Add variadic operators
1 parent 84c1564 commit 02f731d

File tree

9 files changed

+333
-12
lines changed

9 files changed

+333
-12
lines changed

mlir/include/mlir/Query/Matcher/Marshallers.h

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@ class MatcherDescriptor {
108108
const llvm::ArrayRef<ParserValue> args,
109109
Diagnostics *error) const = 0;
110110

111+
// If the matcher is variadic, it can take any number of arguments.
112+
virtual bool isVariadic() const = 0;
113+
111114
// Returns the number of arguments accepted by the matcher.
112115
virtual unsigned getNumArgs() const = 0;
113116

@@ -140,6 +143,8 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
140143
return marshaller(matcherFunc, matcherName, nameRange, args, error);
141144
}
142145

146+
bool isVariadic() const override { return false; }
147+
143148
unsigned getNumArgs() const override { return argKinds.size(); }
144149

145150
void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
@@ -153,6 +158,54 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
153158
const std::vector<ArgKind> argKinds;
154159
};
155160

161+
class VariadicOperatorMatcherDescriptor : public MatcherDescriptor {
162+
public:
163+
using VarOp = DynMatcher::VariadicOperator;
164+
VariadicOperatorMatcherDescriptor(unsigned minCount, unsigned maxCount,
165+
VarOp varOp, StringRef matcherName)
166+
: minCount(minCount), maxCount(maxCount), varOp(varOp),
167+
matcherName(matcherName) {}
168+
169+
VariantMatcher create(SourceRange nameRange, ArrayRef<ParserValue> args,
170+
Diagnostics *error) const override {
171+
if (args.size() < minCount || maxCount < args.size()) {
172+
addError(error, nameRange, ErrorType::RegistryWrongArgCount,
173+
{llvm::Twine("requires between "), llvm::Twine(minCount),
174+
llvm::Twine(" and "), llvm::Twine(maxCount),
175+
llvm::Twine(" args, got "), llvm::Twine(args.size())});
176+
return VariantMatcher();
177+
}
178+
179+
std::vector<VariantMatcher> innerArgs;
180+
for (size_t i = 0, e = args.size(); i != e; ++i) {
181+
const ParserValue &arg = args[i];
182+
const VariantValue &value = arg.value;
183+
if (!value.isMatcher()) {
184+
addError(error, arg.range, ErrorType::RegistryWrongArgType,
185+
{llvm::Twine(i + 1), llvm::Twine("Matcher: "),
186+
llvm::Twine(value.getTypeAsString())});
187+
return VariantMatcher();
188+
}
189+
innerArgs.push_back(value.getMatcher());
190+
}
191+
return VariantMatcher::VariadicOperatorMatcher(varOp, std::move(innerArgs));
192+
}
193+
194+
bool isVariadic() const override { return true; }
195+
196+
unsigned getNumArgs() const override { return 0; }
197+
198+
void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
199+
kinds.push_back(ArgKind(ArgKind::Matcher));
200+
}
201+
202+
private:
203+
const unsigned minCount;
204+
const unsigned maxCount;
205+
const VarOp varOp;
206+
const StringRef matcherName;
207+
};
208+
156209
// Helper function to check if argument count matches expected count
157210
inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount,
158211
llvm::ArrayRef<ParserValue> args,
@@ -224,6 +277,14 @@ makeMatcherAutoMarshall(ReturnType (*matcherFunc)(ArgTypes...),
224277
reinterpret_cast<void (*)()>(matcherFunc), matcherName, argKinds);
225278
}
226279

280+
// Variadic operator overload.
281+
template <unsigned MinCount, unsigned MaxCount>
282+
std::unique_ptr<MatcherDescriptor>
283+
makeMatcherAutoMarshall(VariadicOperatorMatcherFunc<MinCount, MaxCount> func,
284+
StringRef matcherName) {
285+
return std::make_unique<VariadicOperatorMatcherDescriptor>(
286+
MinCount, MaxCount, func.varOp, matcherName);
287+
}
227288
} // namespace mlir::query::matcher::internal
228289

229290
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H

mlir/include/mlir/Query/Matcher/MatchersInternal.h

Lines changed: 103 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88
//
99
// Implements the base layer of the matcher framework.
1010
//
11-
// Matchers are methods that return a Matcher which provides a method one of the
12-
// following methods: match(Operation *op), match(Operation *op,
13-
// SetVector<Operation *> &matchedOps)
11+
// Matchers are methods that return a Matcher which provide a method
12+
// `match(...)` method. The method's parameters define the context of the match.
13+
// Support includes simple (unary) matchers as well as matcher combinators.
14+
// (anyOf, allOf, etc.)
1415
//
15-
// The matcher functions are defined in include/mlir/IR/Matchers.h.
1616
// This file contains the wrapper classes needed to construct matchers for
1717
// mlir-query.
1818
//
@@ -25,6 +25,13 @@
2525
#include "llvm/ADT/IntrusiveRefCntPtr.h"
2626

2727
namespace mlir::query::matcher {
28+
class DynMatcher;
29+
namespace internal {
30+
31+
bool allOfVariadicOperator(Operation *op, ArrayRef<DynMatcher> innerMatchers);
32+
bool anyOfVariadicOperator(Operation *op, ArrayRef<DynMatcher> innerMatchers);
33+
34+
} // namespace internal
2835

2936
// Defaults to false if T has no match() method with the signature:
3037
// match(Operation* op).
@@ -84,6 +91,26 @@ class MatcherFnImpl : public MatcherInterface {
8491
MatcherFn matcherFn;
8592
};
8693

94+
// VariadicMatcher takes a vector of Matchers and returns true if any Matchers
95+
// match the given operation.
96+
using VariadicOperatorFunction = bool (*)(Operation *op,
97+
ArrayRef<DynMatcher> innerMatchers);
98+
99+
template <VariadicOperatorFunction Func>
100+
class VariadicMatcher : public MatcherInterface {
101+
public:
102+
VariadicMatcher(std::vector<DynMatcher> matchers) : matchers(matchers) {}
103+
104+
bool match(Operation *op) override { return Func(op, matchers); }
105+
// Fallback case
106+
bool match(Operation *op, SetVector<Operation *> &matchedOps) override {
107+
return false;
108+
}
109+
110+
private:
111+
std::vector<DynMatcher> matchers;
112+
};
113+
87114
// Matcher wraps a MatcherInterface implementation and provides match()
88115
// methods that redirect calls to the underlying implementation.
89116
class DynMatcher {
@@ -92,6 +119,31 @@ class DynMatcher {
92119
DynMatcher(MatcherInterface *implementation)
93120
: implementation(implementation) {}
94121

122+
// Construct from a variadic function.
123+
enum VariadicOperator {
124+
// Matches operations for which all provided matchers match.
125+
AllOf,
126+
// Matches operations for which at least one of the provided matchers
127+
// matches.
128+
AnyOf
129+
};
130+
131+
static std::unique_ptr<DynMatcher>
132+
constructVariadic(VariadicOperator Op,
133+
std::vector<DynMatcher> innerMatchers) {
134+
switch (Op) {
135+
case AllOf:
136+
return std::make_unique<DynMatcher>(
137+
new VariadicMatcher<internal::allOfVariadicOperator>(
138+
std::move(innerMatchers)));
139+
case AnyOf:
140+
return std::make_unique<DynMatcher>(
141+
new VariadicMatcher<internal::anyOfVariadicOperator>(
142+
std::move(innerMatchers)));
143+
}
144+
llvm_unreachable("Invalid Op value.");
145+
}
146+
95147
template <typename MatcherFn>
96148
static std::unique_ptr<DynMatcher>
97149
constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) {
@@ -113,6 +165,53 @@ class DynMatcher {
113165
std::string functionName;
114166
};
115167

168+
// VariadicOperatorMatcher related types.
169+
template <typename... Ps>
170+
class VariadicOperatorMatcher {
171+
public:
172+
VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, Ps &&...params)
173+
: varOp(varOp), params(std::forward<Ps>(params)...) {}
174+
175+
operator std::unique_ptr<DynMatcher>() const & {
176+
return DynMatcher::constructVariadic(
177+
varOp, getMatchers(std::index_sequence_for<Ps...>()));
178+
}
179+
180+
operator std::unique_ptr<DynMatcher>() && {
181+
return DynMatcher::constructVariadic(
182+
varOp, std::move(*this).getMatchers(std::index_sequence_for<Ps...>()));
183+
}
184+
185+
private:
186+
// Helper method to unpack the tuple into a vector.
187+
template <std::size_t... Is>
188+
std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) const & {
189+
return {DynMatcher(std::get<Is>(params))...};
190+
}
191+
192+
template <std::size_t... Is>
193+
std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) && {
194+
return {DynMatcher(std::get<Is>(std::move(params)))...};
195+
}
196+
197+
const DynMatcher::VariadicOperator varOp;
198+
std::tuple<Ps...> params;
199+
};
200+
201+
// Overloaded function object to generate VariadicOperatorMatcher objects from
202+
// arbitrary matchers.
203+
template <unsigned MinCount, unsigned MaxCount>
204+
struct VariadicOperatorMatcherFunc {
205+
DynMatcher::VariadicOperator varOp;
206+
207+
template <typename... Ms>
208+
VariadicOperatorMatcher<Ms...> operator()(Ms &&...Ps) const {
209+
static_assert(MinCount <= sizeof...(Ms) && sizeof...(Ms) <= MaxCount,
210+
"invalid number of parameters for variadic matcher");
211+
return VariadicOperatorMatcher<Ms...>(varOp, std::forward<Ms>(Ps)...);
212+
}
213+
};
214+
116215
} // namespace mlir::query::matcher
117216

118217
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H

mlir/include/mlir/Query/Matcher/SliceMatchers.h

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
//
77
//===----------------------------------------------------------------------===//
88
//
9-
// This file provides matchers for MLIRQuery that peform slicing analysis
9+
// This file defines slicing-analysis matchers that extend and abstract the
10+
// core implementations from `SliceAnalysis.h`.
1011
//
1112
//===----------------------------------------------------------------------===//
1213

@@ -15,9 +16,9 @@
1516

1617
#include "mlir/Analysis/SliceAnalysis.h"
1718

18-
/// A matcher encapsulating `getBackwardSlice` method from SliceAnalysis.h.
19-
/// Additionally, it limits the slice computation to a certain depth level using
20-
/// a custom filter.
19+
/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
20+
/// if `innerMatcher` matches. The traversal stops once the desired depth level
21+
/// is reached.
2122
///
2223
/// Example: starting from node 9, assuming the matcher
2324
/// computes the slice for the first two depth levels:
@@ -116,6 +117,51 @@ bool BackwardSliceMatcher<Matcher>::matches(
116117
: backwardSlice.size() >= 1;
117118
}
118119

120+
/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
121+
/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
122+
template <typename BaseMatcher, typename Filter>
123+
class PredicateBackwardSliceMatcher {
124+
public:
125+
PredicateBackwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
126+
bool inclusive, bool omitBlockArguments,
127+
bool omitUsesFromAbove)
128+
: innerMatcher(std::move(innerMatcher)),
129+
filterMatcher(std::move(filterMatcher)), inclusive(inclusive),
130+
omitBlockArguments(omitBlockArguments),
131+
omitUsesFromAbove(omitUsesFromAbove) {}
132+
133+
bool match(Operation *rootOp, SetVector<Operation *> &backwardSlice) {
134+
backwardSlice.clear();
135+
BackwardSliceOptions options;
136+
options.inclusive = inclusive;
137+
options.omitUsesFromAbove = omitUsesFromAbove;
138+
options.omitBlockArguments = omitBlockArguments;
139+
if (innerMatcher.match(rootOp)) {
140+
options.filter = [&](Operation *subOp) {
141+
return !filterMatcher.match(subOp);
142+
};
143+
getBackwardSlice(rootOp, &backwardSlice, options);
144+
return options.inclusive ? backwardSlice.size() > 1
145+
: backwardSlice.size() >= 1;
146+
}
147+
return false;
148+
}
149+
150+
private:
151+
BaseMatcher innerMatcher;
152+
Filter filterMatcher;
153+
bool inclusive;
154+
bool omitBlockArguments;
155+
bool omitUsesFromAbove;
156+
};
157+
158+
const matcher::VariadicOperatorMatcherFunc<1,
159+
std::numeric_limits<unsigned>::max()>
160+
anyOf = {matcher::DynMatcher::AnyOf};
161+
const matcher::VariadicOperatorMatcherFunc<1,
162+
std::numeric_limits<unsigned>::max()>
163+
allOf = {matcher::DynMatcher::AllOf};
164+
119165
/// Matches transitive defs of a top-level operation up to N levels.
120166
template <typename Matcher>
121167
inline BackwardSliceMatcher<Matcher>
@@ -127,7 +173,7 @@ m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
127173
omitUsesFromAbove);
128174
}
129175

130-
/// Matches all transitive defs of a top-level operation up to N levels
176+
/// Matches all transitive defs of a top-level operation up to N levels.
131177
template <typename Matcher>
132178
inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
133179
int64_t maxDepth) {
@@ -136,6 +182,18 @@ inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
136182
false, false);
137183
}
138184

185+
/// Matches all transitive defs of a top-level operation and stops where
186+
/// `filterMatcher` rejects.
187+
template <typename BaseMatcher, typename Filter>
188+
inline PredicateBackwardSliceMatcher<BaseMatcher, Filter>
189+
m_GetDefinitionsByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
190+
bool inclusive, bool omitBlockArguments,
191+
bool omitUsesFromAbove) {
192+
return PredicateBackwardSliceMatcher<BaseMatcher, Filter>(
193+
std::move(innerMatcher), std::move(filterMatcher), inclusive,
194+
omitBlockArguments, omitUsesFromAbove);
195+
}
196+
139197
} // namespace mlir::query::matcher
140198

141199
#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H

mlir/include/mlir/Query/Matcher/VariantValue.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,12 @@ enum class ArgKind { Boolean, Matcher, Signed, String };
2626
// A variant matcher object to abstract simple and complex matchers into a
2727
// single object type.
2828
class VariantMatcher {
29-
class MatcherOps;
29+
class MatcherOps {
30+
public:
31+
std::optional<DynMatcher>
32+
constructVariadicOperator(DynMatcher::VariadicOperator varOp,
33+
ArrayRef<VariantMatcher> innerMatchers) const;
34+
};
3035

3136
// Payload interface to be specialized by each matcher type. It follows a
3237
// similar interface as VariantMatcher itself.
@@ -43,6 +48,9 @@ class VariantMatcher {
4348

4449
// Clones the provided matcher.
4550
static VariantMatcher SingleMatcher(DynMatcher matcher);
51+
static VariantMatcher
52+
VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp,
53+
ArrayRef<VariantMatcher> args);
4654

4755
// Makes the matcher the "null" matcher.
4856
void reset();
@@ -61,6 +69,7 @@ class VariantMatcher {
6169
: value(std::move(value)) {}
6270

6371
class SinglePayload;
72+
class VariadicOpPayload;
6473

6574
std::shared_ptr<const Payload> value;
6675
};

mlir/lib/Query/Matcher/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_mlir_library(MLIRQueryMatcher
22
MatchFinder.cpp
3+
MatchersInternal.cpp
34
Parser.cpp
45
RegistryManager.cpp
56
VariantValue.cpp

0 commit comments

Comments
 (0)