Skip to content

Commit 11792a6

Browse files
committed
Improve MLIR-Query by adding matcher combinators
Limit backward-slice with nested matching Add variadic operators
1 parent 84c1564 commit 11792a6

File tree

10 files changed

+392
-15
lines changed

10 files changed

+392
-15
lines changed

mlir/include/mlir/Query/Matcher/Marshallers.h

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@ class MatcherDescriptor {
108108
const llvm::ArrayRef<ParserValue> args,
109109
Diagnostics *error) const = 0;
110110

111+
// If the matcher is variadic, it can take any number of arguments.
112+
virtual bool isVariadic() const = 0;
113+
111114
// Returns the number of arguments accepted by the matcher.
112115
virtual unsigned getNumArgs() const = 0;
113116

@@ -140,6 +143,8 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
140143
return marshaller(matcherFunc, matcherName, nameRange, args, error);
141144
}
142145

146+
bool isVariadic() const override { return false; }
147+
143148
unsigned getNumArgs() const override { return argKinds.size(); }
144149

145150
void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
@@ -153,6 +158,54 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
153158
const std::vector<ArgKind> argKinds;
154159
};
155160

161+
class VariadicOperatorMatcherDescriptor : public MatcherDescriptor {
162+
public:
163+
using VarOp = DynMatcher::VariadicOperator;
164+
VariadicOperatorMatcherDescriptor(unsigned minCount, unsigned maxCount,
165+
VarOp varOp, StringRef matcherName)
166+
: minCount(minCount), maxCount(maxCount), varOp(varOp),
167+
matcherName(matcherName) {}
168+
169+
VariantMatcher create(SourceRange nameRange, ArrayRef<ParserValue> args,
170+
Diagnostics *error) const override {
171+
if (args.size() < minCount || maxCount < args.size()) {
172+
addError(error, nameRange, ErrorType::RegistryWrongArgCount,
173+
{llvm::Twine("requires between "), llvm::Twine(minCount),
174+
llvm::Twine(" and "), llvm::Twine(maxCount),
175+
llvm::Twine(" args, got "), llvm::Twine(args.size())});
176+
return VariantMatcher();
177+
}
178+
179+
std::vector<VariantMatcher> innerArgs;
180+
for (size_t i = 0, e = args.size(); i != e; ++i) {
181+
const ParserValue &arg = args[i];
182+
const VariantValue &value = arg.value;
183+
if (!value.isMatcher()) {
184+
addError(error, arg.range, ErrorType::RegistryWrongArgType,
185+
{llvm::Twine(i + 1), llvm::Twine("Matcher: "),
186+
llvm::Twine(value.getTypeAsString())});
187+
return VariantMatcher();
188+
}
189+
innerArgs.push_back(value.getMatcher());
190+
}
191+
return VariantMatcher::VariadicOperatorMatcher(varOp, std::move(innerArgs));
192+
}
193+
194+
bool isVariadic() const override { return true; }
195+
196+
unsigned getNumArgs() const override { return 0; }
197+
198+
void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
199+
kinds.push_back(ArgKind(ArgKind::Matcher));
200+
}
201+
202+
private:
203+
const unsigned minCount;
204+
const unsigned maxCount;
205+
const VarOp varOp;
206+
const StringRef matcherName;
207+
};
208+
156209
// Helper function to check if argument count matches expected count
157210
inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount,
158211
llvm::ArrayRef<ParserValue> args,
@@ -224,6 +277,14 @@ makeMatcherAutoMarshall(ReturnType (*matcherFunc)(ArgTypes...),
224277
reinterpret_cast<void (*)()>(matcherFunc), matcherName, argKinds);
225278
}
226279

280+
// Variadic operator overload.
281+
template <unsigned MinCount, unsigned MaxCount>
282+
std::unique_ptr<MatcherDescriptor>
283+
makeMatcherAutoMarshall(VariadicOperatorMatcherFunc<MinCount, MaxCount> func,
284+
StringRef matcherName) {
285+
return std::make_unique<VariadicOperatorMatcherDescriptor>(
286+
MinCount, MaxCount, func.varOp, matcherName);
287+
}
227288
} // namespace mlir::query::matcher::internal
228289

229290
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H

mlir/include/mlir/Query/Matcher/MatchFinder.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121

2222
namespace mlir::query::matcher {
2323

24-
/// A class that provides utilities to find operations in the IR.
24+
/// Finds and collects matches from the IR. After construction
25+
/// `collectMatches` can be used to traverse the IR and apply
26+
/// matchers.
2527
class MatchFinder {
2628

2729
public:

mlir/include/mlir/Query/Matcher/MatchersInternal.h

Lines changed: 105 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88
//
99
// Implements the base layer of the matcher framework.
1010
//
11-
// Matchers are methods that return a Matcher which provides a method one of the
12-
// following methods: match(Operation *op), match(Operation *op,
13-
// SetVector<Operation *> &matchedOps)
11+
// Matchers are methods that return a Matcher which provide a method
12+
// `match(...)` method. The method's parameters define the context of the match.
13+
// Support includes simple (unary) matchers as well as matcher combinators.
14+
// (anyOf, allOf, etc.)
1415
//
15-
// The matcher functions are defined in include/mlir/IR/Matchers.h.
1616
// This file contains the wrapper classes needed to construct matchers for
1717
// mlir-query.
1818
//
@@ -25,6 +25,15 @@
2525
#include "llvm/ADT/IntrusiveRefCntPtr.h"
2626

2727
namespace mlir::query::matcher {
28+
class DynMatcher;
29+
namespace internal {
30+
31+
bool allOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
32+
ArrayRef<DynMatcher> innerMatchers);
33+
bool anyOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
34+
ArrayRef<DynMatcher> innerMatchers);
35+
36+
} // namespace internal
2837

2938
// Defaults to false if T has no match() method with the signature:
3039
// match(Operation* op).
@@ -84,6 +93,26 @@ class MatcherFnImpl : public MatcherInterface {
8493
MatcherFn matcherFn;
8594
};
8695

96+
// VariadicMatcher takes a vector of Matchers and returns true if any Matchers
97+
// match the given operation.
98+
using VariadicOperatorFunction = bool (*)(Operation *op,
99+
SetVector<Operation *> *matchedOps,
100+
ArrayRef<DynMatcher> innerMatchers);
101+
102+
template <VariadicOperatorFunction Func>
103+
class VariadicMatcher : public MatcherInterface {
104+
public:
105+
VariadicMatcher(std::vector<DynMatcher> matchers) : matchers(matchers) {}
106+
107+
bool match(Operation *op) override { return Func(op, nullptr, matchers); }
108+
bool match(Operation *op, SetVector<Operation *> &matchedOps) override {
109+
return Func(op, &matchedOps, matchers);
110+
}
111+
112+
private:
113+
std::vector<DynMatcher> matchers;
114+
};
115+
87116
// Matcher wraps a MatcherInterface implementation and provides match()
88117
// methods that redirect calls to the underlying implementation.
89118
class DynMatcher {
@@ -92,6 +121,31 @@ class DynMatcher {
92121
DynMatcher(MatcherInterface *implementation)
93122
: implementation(implementation) {}
94123

124+
// Construct from a variadic function.
125+
enum VariadicOperator {
126+
// Matches operations for which all provided matchers match.
127+
AllOf,
128+
// Matches operations for which at least one of the provided matchers
129+
// matches.
130+
AnyOf
131+
};
132+
133+
static std::unique_ptr<DynMatcher>
134+
constructVariadic(VariadicOperator Op,
135+
std::vector<DynMatcher> innerMatchers) {
136+
switch (Op) {
137+
case AllOf:
138+
return std::make_unique<DynMatcher>(
139+
new VariadicMatcher<internal::allOfVariadicOperator>(
140+
std::move(innerMatchers)));
141+
case AnyOf:
142+
return std::make_unique<DynMatcher>(
143+
new VariadicMatcher<internal::anyOfVariadicOperator>(
144+
std::move(innerMatchers)));
145+
}
146+
llvm_unreachable("Invalid Op value.");
147+
}
148+
95149
template <typename MatcherFn>
96150
static std::unique_ptr<DynMatcher>
97151
constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) {
@@ -113,6 +167,53 @@ class DynMatcher {
113167
std::string functionName;
114168
};
115169

170+
// VariadicOperatorMatcher related types.
171+
template <typename... Ps>
172+
class VariadicOperatorMatcher {
173+
public:
174+
VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, Ps &&...params)
175+
: varOp(varOp), params(std::forward<Ps>(params)...) {}
176+
177+
operator std::unique_ptr<DynMatcher>() const & {
178+
return DynMatcher::constructVariadic(
179+
varOp, getMatchers(std::index_sequence_for<Ps...>()));
180+
}
181+
182+
operator std::unique_ptr<DynMatcher>() && {
183+
return DynMatcher::constructVariadic(
184+
varOp, std::move(*this).getMatchers(std::index_sequence_for<Ps...>()));
185+
}
186+
187+
private:
188+
// Helper method to unpack the tuple into a vector.
189+
template <std::size_t... Is>
190+
std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) const & {
191+
return {DynMatcher(std::get<Is>(params))...};
192+
}
193+
194+
template <std::size_t... Is>
195+
std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) && {
196+
return {DynMatcher(std::get<Is>(std::move(params)))...};
197+
}
198+
199+
const DynMatcher::VariadicOperator varOp;
200+
std::tuple<Ps...> params;
201+
};
202+
203+
// Overloaded function object to generate VariadicOperatorMatcher objects from
204+
// arbitrary matchers.
205+
template <unsigned MinCount, unsigned MaxCount>
206+
struct VariadicOperatorMatcherFunc {
207+
DynMatcher::VariadicOperator varOp;
208+
209+
template <typename... Ms>
210+
VariadicOperatorMatcher<Ms...> operator()(Ms &&...Ps) const {
211+
static_assert(MinCount <= sizeof...(Ms) && sizeof...(Ms) <= MaxCount,
212+
"invalid number of parameters for variadic matcher");
213+
return VariadicOperatorMatcher<Ms...>(varOp, std::forward<Ms>(Ps)...);
214+
}
215+
};
216+
116217
} // namespace mlir::query::matcher
117218

118219
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H

mlir/include/mlir/Query/Matcher/SliceMatchers.h

Lines changed: 105 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
//
77
//===----------------------------------------------------------------------===//
88
//
9-
// This file provides matchers for MLIRQuery that peform slicing analysis
9+
// This file defines slicing-analysis matchers that extend and abstract the
10+
// core implementations from `SliceAnalysis.h`.
1011
//
1112
//===----------------------------------------------------------------------===//
1213

@@ -15,9 +16,9 @@
1516

1617
#include "mlir/Analysis/SliceAnalysis.h"
1718

18-
/// A matcher encapsulating `getBackwardSlice` method from SliceAnalysis.h.
19-
/// Additionally, it limits the slice computation to a certain depth level using
20-
/// a custom filter.
19+
/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
20+
/// if `innerMatcher` matches. The traversal stops once the desired depth level
21+
/// is reached.
2122
///
2223
/// Example: starting from node 9, assuming the matcher
2324
/// computes the slice for the first two depth levels:
@@ -116,6 +117,83 @@ bool BackwardSliceMatcher<Matcher>::matches(
116117
: backwardSlice.size() >= 1;
117118
}
118119

120+
/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
121+
/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
122+
template <typename BaseMatcher, typename Filter>
123+
class PredicateBackwardSliceMatcher {
124+
public:
125+
PredicateBackwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
126+
bool inclusive, bool omitBlockArguments,
127+
bool omitUsesFromAbove)
128+
: innerMatcher(std::move(innerMatcher)),
129+
filterMatcher(std::move(filterMatcher)), inclusive(inclusive),
130+
omitBlockArguments(omitBlockArguments),
131+
omitUsesFromAbove(omitUsesFromAbove) {}
132+
133+
bool match(Operation *rootOp, SetVector<Operation *> &backwardSlice) {
134+
backwardSlice.clear();
135+
BackwardSliceOptions options;
136+
options.inclusive = inclusive;
137+
options.omitUsesFromAbove = omitUsesFromAbove;
138+
options.omitBlockArguments = omitBlockArguments;
139+
if (innerMatcher.match(rootOp)) {
140+
options.filter = [&](Operation *subOp) {
141+
return !filterMatcher.match(subOp);
142+
};
143+
getBackwardSlice(rootOp, &backwardSlice, options);
144+
return options.inclusive ? backwardSlice.size() > 1
145+
: backwardSlice.size() >= 1;
146+
}
147+
return false;
148+
}
149+
150+
private:
151+
BaseMatcher innerMatcher;
152+
Filter filterMatcher;
153+
bool inclusive;
154+
bool omitBlockArguments;
155+
bool omitUsesFromAbove;
156+
};
157+
158+
/// Computes the forward-slice of all users reachable from `rootOp`,
159+
/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
160+
template <typename BaseMatcher, typename Filter>
161+
class PredicateForwardSliceMatcher {
162+
public:
163+
PredicateForwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
164+
bool inclusive)
165+
: innerMatcher(std::move(innerMatcher)),
166+
filterMatcher(std::move(filterMatcher)), inclusive(inclusive) {}
167+
168+
bool match(Operation *rootOp, SetVector<Operation *> &forwardSlice) {
169+
forwardSlice.clear();
170+
ForwardSliceOptions options;
171+
options.inclusive = inclusive;
172+
if (innerMatcher.match(rootOp)) {
173+
options.filter = [&](Operation *subOp) {
174+
return !filterMatcher.match(subOp);
175+
};
176+
getForwardSlice(rootOp, &forwardSlice, options);
177+
return options.inclusive ? forwardSlice.size() > 1
178+
: forwardSlice.size() >= 1;
179+
}
180+
return false;
181+
}
182+
183+
private:
184+
BaseMatcher innerMatcher;
185+
Filter filterMatcher;
186+
bool inclusive;
187+
};
188+
189+
namespace internal {
190+
const matcher::VariadicOperatorMatcherFunc<1,
191+
std::numeric_limits<unsigned>::max()>
192+
anyOf = {matcher::DynMatcher::AnyOf};
193+
const matcher::VariadicOperatorMatcherFunc<1,
194+
std::numeric_limits<unsigned>::max()>
195+
allOf = {matcher::DynMatcher::AllOf};
196+
} // namespace internal
119197
/// Matches transitive defs of a top-level operation up to N levels.
120198
template <typename Matcher>
121199
inline BackwardSliceMatcher<Matcher>
@@ -127,7 +205,7 @@ m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
127205
omitUsesFromAbove);
128206
}
129207

130-
/// Matches all transitive defs of a top-level operation up to N levels
208+
/// Matches all transitive defs of a top-level operation up to N levels.
131209
template <typename Matcher>
132210
inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
133211
int64_t maxDepth) {
@@ -136,6 +214,28 @@ inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
136214
false, false);
137215
}
138216

217+
/// Matches all transitive defs of a top-level operation and stops where
218+
/// `filterMatcher` rejects.
219+
template <typename BaseMatcher, typename Filter>
220+
inline PredicateBackwardSliceMatcher<BaseMatcher, Filter>
221+
m_GetDefinitionsByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
222+
bool inclusive, bool omitBlockArguments,
223+
bool omitUsesFromAbove) {
224+
return PredicateBackwardSliceMatcher<BaseMatcher, Filter>(
225+
std::move(innerMatcher), std::move(filterMatcher), inclusive,
226+
omitBlockArguments, omitUsesFromAbove);
227+
}
228+
229+
/// Matches all users of a top-level operation and stops where
230+
/// `filterMatcher` rejects.
231+
template <typename BaseMatcher, typename Filter>
232+
inline PredicateForwardSliceMatcher<BaseMatcher, Filter>
233+
m_GetUsersByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
234+
bool inclusive) {
235+
return PredicateForwardSliceMatcher<BaseMatcher, Filter>(
236+
std::move(innerMatcher), std::move(filterMatcher), inclusive);
237+
}
238+
139239
} // namespace mlir::query::matcher
140240

141241
#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H

0 commit comments

Comments
 (0)