-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir] Improve mlir-query tool by implementing getBackwardSlice
matcher
#115670
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir] Improve mlir-query tool by implementing getBackwardSlice
matcher
#115670
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Quick skim comments.
bafbd37
to
88a01fd
Compare
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir-core Author: Denzel-Brian Budii (chios202) ChangesThis Pull Request aims to improve MLIR-QUERY tool by implementing Example of current matcher. The query was made to the file:
Patch is 34.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/115670.diff 17 Files Affected:
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 1dce055db1b4a7..2204a68be26b10 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -59,7 +59,7 @@ struct NameOpMatcher {
NameOpMatcher(StringRef name) : name(name) {}
bool match(Operation *op) { return op->getName().getStringRef() == name; }
- StringRef name;
+ std::string name;
};
/// The matcher that matches operations that have the specified attribute name.
@@ -67,7 +67,7 @@ struct AttrOpMatcher {
AttrOpMatcher(StringRef attrName) : attrName(attrName) {}
bool match(Operation *op) { return op->hasAttr(attrName); }
- StringRef attrName;
+ std::string attrName;
};
/// The matcher that matches operations that have the `ConstantLike` trait, and
diff --git a/mlir/include/mlir/Query/Matcher/ExtraMatchers.h b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
new file mode 100644
index 00000000000000..57adc3241b0bef
--- /dev/null
+++ b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
@@ -0,0 +1,180 @@
+//===- ExtraMatchers.h - Various common matchers ---------------------*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides extra matchers that are very useful for mlir-query
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_EXTRAMATCHERS_H
+#define MLIR_IR_EXTRAMATCHERS_H
+
+#include "MatchFinder.h"
+#include "MatchersInternal.h"
+#include "mlir/IR/Region.h"
+#include "mlir/Query/Query.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+
+namespace query {
+
+namespace extramatcher {
+
+namespace detail {
+
+class BackwardSliceMatcher {
+public:
+ BackwardSliceMatcher(matcher::DynMatcher &&innerMatcher, unsigned hops)
+ : innerMatcher(std::move(innerMatcher)), hops(hops) {}
+
+private:
+ bool matches(Operation *op, SetVector<Operation *> &backwardSlice,
+ QueryOptions &options, unsigned tempHops) {
+
+ if (op->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
+ return false;
+ }
+
+ auto processValue = [&](Value value) {
+ if (tempHops == 0) {
+ return;
+ }
+ if (auto *definingOp = value.getDefiningOp()) {
+ if (backwardSlice.count(definingOp) == 0)
+ matches(definingOp, backwardSlice, options, tempHops - 1);
+ } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
+ if (options.omitBlockArguments)
+ return;
+ Block *block = blockArg.getOwner();
+
+ Operation *parentOp = block->getParentOp();
+
+ if (parentOp && backwardSlice.count(parentOp) == 0) {
+ assert(parentOp->getNumRegions() == 1 &&
+ parentOp->getRegion(0).getBlocks().size() == 1);
+ matches(parentOp, backwardSlice, options, tempHops-1);
+ }
+ } else {
+ llvm_unreachable("No definingOp and not a block argument.");
+ }
+ };
+
+ if (!options.omitUsesFromAbove) {
+ llvm::for_each(op->getRegions(), [&](Region ®ion) {
+ SmallPtrSet<Region *, 4> descendents;
+ region.walk(
+ [&](Region *childRegion) { descendents.insert(childRegion); });
+ region.walk([&](Operation *op) {
+ for (OpOperand &operand : op->getOpOperands()) {
+ if (!descendents.contains(operand.get().getParentRegion()))
+ processValue(operand.get());
+ }
+ });
+ });
+ }
+
+ llvm::for_each(op->getOperands(), processValue);
+ backwardSlice.insert(op);
+ return true;
+ }
+
+public:
+ bool match(Operation *op, SetVector<Operation *> &backwardSlice,
+ QueryOptions &options) {
+ if (innerMatcher.match(op) && matches(op, backwardSlice, options, hops)) {
+ if (!options.inclusive) {
+ backwardSlice.remove(op);
+ }
+ return true;
+ }
+ return false;
+ }
+
+private:
+ matcher::DynMatcher innerMatcher;
+ unsigned hops;
+};
+
+class ForwardSliceMatcher {
+public:
+ ForwardSliceMatcher(matcher::DynMatcher &&innerMatcher, unsigned hops)
+ : innerMatcher(std::move(innerMatcher)), hops(hops) {}
+
+private:
+ bool matches(Operation *op, SetVector<Operation *> &forwardSlice,
+ QueryOptions &options, unsigned tempHops) {
+
+ if (tempHops == 0) {
+ forwardSlice.insert(op);
+ return true;
+ }
+
+ for (Region ®ion : op->getRegions())
+ for (Block &block : region)
+ for (Operation &blockOp : block)
+ if (forwardSlice.count(&blockOp) == 0)
+ matches(&blockOp, forwardSlice, options, tempHops - 1);
+ for (Value result : op->getResults()) {
+ for (Operation *userOp : result.getUsers())
+ if (forwardSlice.count(userOp) == 0)
+ matches(userOp, forwardSlice, options, tempHops - 1);
+ }
+
+ forwardSlice.insert(op);
+ return true;
+ }
+
+public:
+ bool match(Operation *op, SetVector<Operation *> &forwardSlice,
+ QueryOptions &options) {
+ if (innerMatcher.match(op) && matches(op, forwardSlice, options, hops)) {
+ if (!options.inclusive) {
+ forwardSlice.remove(op);
+ }
+ SmallVector<Operation *, 0> v(forwardSlice.takeVector());
+ forwardSlice.insert(v.rbegin(), v.rend());
+ return true;
+ }
+ return false;
+ }
+
+private:
+ matcher::DynMatcher innerMatcher;
+ unsigned hops;
+};
+
+} // namespace detail
+
+inline detail::BackwardSliceMatcher
+definedBy(mlir::query::matcher::DynMatcher innerMatcher) {
+ return detail::BackwardSliceMatcher(std::move(innerMatcher), 1);
+}
+
+inline detail::BackwardSliceMatcher
+getDefinitions(mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
+ return detail::BackwardSliceMatcher(std::move(innerMatcher), hops);
+}
+
+inline detail::ForwardSliceMatcher
+usedBy(mlir::query::matcher::DynMatcher innerMatcher) {
+ return detail::ForwardSliceMatcher(std::move(innerMatcher), 1);
+}
+
+inline detail::ForwardSliceMatcher
+getUses(mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
+ return detail::ForwardSliceMatcher(std::move(innerMatcher), hops);
+}
+
+} // namespace extramatcher
+
+} // namespace query
+
+} // namespace mlir
+
+#endif // MLIR_IR_EXTRAMATCHERS_H
diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h
index 6ed35ac0ddccc7..c775dbc5c86da0 100644
--- a/mlir/include/mlir/Query/Matcher/Marshallers.h
+++ b/mlir/include/mlir/Query/Matcher/Marshallers.h
@@ -50,6 +50,21 @@ struct ArgTypeTraits<llvm::StringRef> {
}
};
+template <>
+struct ArgTypeTraits<unsigned> {
+ static bool hasCorrectType(const VariantValue &value) {
+ return value.isUnsigned();
+ }
+
+ static unsigned get(const VariantValue &value) { return value.getUnsigned(); }
+
+ static ArgKind getKind() { return ArgKind::Unsigned; }
+
+ static std::optional<std::string> getBestGuess(const VariantValue &) {
+ return std::nullopt;
+ }
+};
+
template <>
struct ArgTypeTraits<DynMatcher> {
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index b008a21f53ae2a..2175db86a91bdf 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
//
// This file contains the MatchFinder class, which is used to find operations
-// that match a given matcher.
+// that match a given matcher and print them.
//
//===----------------------------------------------------------------------===//
@@ -15,24 +15,52 @@
#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
#include "MatchersInternal.h"
+#include "mlir/Query/QuerySession.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
namespace mlir::query::matcher {
-// MatchFinder is used to find all operations that match a given matcher.
class MatchFinder {
+private:
+ static void printMatch(llvm::raw_ostream &os, QuerySession &qs,
+ mlir::Operation *op, const std::string &binding) {
+ auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
+ auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
+ qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
+ qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
+ "\"" + binding + "\" binds here");
+ };
+
public:
- // Returns all operations that match the given matcher.
- static std::vector<Operation *> getMatches(Operation *root,
- DynMatcher matcher) {
- std::vector<Operation *> matches;
+ static std::vector<Operation *>
+ getMatches(Operation *root, QueryOptions &options, DynMatcher matcher,
+ llvm::raw_ostream &os, QuerySession &qs) {
+ unsigned matchCount = 0;
+ std::vector<Operation *> matchedOps;
+ SetVector<Operation *> tempStorage;
- // Simple match finding with walk.
root->walk([&](Operation *subOp) {
- if (matcher.match(subOp))
- matches.push_back(subOp);
- });
+ if (matcher.match(subOp)) {
+ matchedOps.push_back(subOp);
+ os << "Match #" << ++matchCount << ":\n\n";
+ printMatch(os, qs, subOp, "root");
+ } else {
+ SmallVector<Operation *> printingOps;
- return matches;
+ if (matcher.match(subOp, tempStorage, options)) {
+ os << "Match #" << ++matchCount << ":\n\n";
+ SmallVector<Operation *> printingOps(tempStorage.takeVector());
+ for (auto op : printingOps) {
+ printMatch(os, qs, op, "root");
+ matchedOps.push_back(op);
+ }
+ printingOps.clear();
+ }
+ }
+ });
+ return matchedOps;
}
};
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index 117f7d4edef9e3..b532b47be7d051 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -1,4 +1,3 @@
-//===- MatchersInternal.h - Structural query framework ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -10,28 +9,53 @@
//
// Matchers are methods that return a Matcher which provides a method
// match(Operation *op)
+// match(Operation *op, SetVector<Operation *> &matchedOps, QueryOptions
+// &options)
//
// The matcher functions are defined in include/mlir/IR/Matchers.h.
// This file contains the wrapper classes needed to construct matchers for
// mlir-query.
//
//===----------------------------------------------------------------------===//
-
#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/IntrusiveRefCntPtr.h"
+namespace mlir {
+namespace query {
+struct QueryOptions;
+}
+} // namespace mlir
+
namespace mlir::query::matcher {
+template <typename T, typename = void>
+struct has_simple_match : std::false_type {};
+
+template <typename T>
+struct has_simple_match<T, std::void_t<decltype(std::declval<T>().match(
+ std::declval<Operation *>()))>>
+ : std::true_type {};
+
+template <typename T, typename = void>
+struct has_bound_match : std::false_type {};
+
+template <typename T>
+struct has_bound_match<T, std::void_t<decltype(std::declval<T>().match(
+ std::declval<Operation *>(),
+ std::declval<SetVector<Operation *> &>(),
+ std::declval<QueryOptions &>()))>>
+ : std::true_type {};
// Generic interface for matchers on an MLIR operation.
class MatcherInterface
: public llvm::ThreadSafeRefCountedBase<MatcherInterface> {
public:
virtual ~MatcherInterface() = default;
-
virtual bool match(Operation *op) = 0;
+ virtual bool match(Operation *op, SetVector<Operation *> &matchedOps,
+ QueryOptions &options) = 0;
};
// MatcherFnImpl takes a matcher function object and implements
@@ -40,14 +64,26 @@ template <typename MatcherFn>
class MatcherFnImpl : public MatcherInterface {
public:
MatcherFnImpl(MatcherFn &matcherFn) : matcherFn(matcherFn) {}
- bool match(Operation *op) override { return matcherFn.match(op); }
+
+ bool match(Operation *op) override {
+ if constexpr (has_simple_match<MatcherFn>::value)
+ return matcherFn.match(op);
+ return false;
+ }
+
+ bool match(Operation *op, SetVector<Operation *> &matchedOps,
+ QueryOptions &options) override {
+ if constexpr (has_bound_match<MatcherFn>::value)
+ return matcherFn.match(op, matchedOps, options);
+ return false;
+ }
private:
MatcherFn matcherFn;
};
-// Matcher wraps a MatcherInterface implementation and provides a match()
-// method that redirects calls to the underlying implementation.
+// Matcher wraps a MatcherInterface implementation and provides match()
+// methods that redirect calls to the underlying implementation.
class DynMatcher {
public:
// Takes ownership of the provided implementation pointer.
@@ -62,12 +98,14 @@ class DynMatcher {
}
bool match(Operation *op) const { return implementation->match(op); }
+ bool match(Operation *op, SetVector<Operation *> &matchedOps,
+ QueryOptions &options) const {
+ return implementation->match(op, matchedOps, options);
+ }
- void setFunctionName(StringRef name) { functionName = name.str(); };
-
- bool hasFunctionName() const { return !functionName.empty(); };
-
- StringRef getFunctionName() const { return functionName; };
+ void setFunctionName(StringRef name) { functionName = name.str(); }
+ bool hasFunctionName() const { return !functionName.empty(); }
+ StringRef getFunctionName() const { return functionName; }
private:
llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h
index 449f8b3a01e021..6b57119df7a9bf 100644
--- a/mlir/include/mlir/Query/Matcher/VariantValue.h
+++ b/mlir/include/mlir/Query/Matcher/VariantValue.h
@@ -21,7 +21,7 @@
namespace mlir::query::matcher {
// All types that VariantValue can contain.
-enum class ArgKind { Matcher, String };
+enum class ArgKind { Matcher, String, Unsigned };
// A variant matcher object to abstract simple and complex matchers into a
// single object type.
@@ -81,6 +81,7 @@ class VariantValue {
// Specific constructors for each supported type.
VariantValue(const llvm::StringRef string);
VariantValue(const VariantMatcher &matcher);
+ VariantValue(unsigned Unsigned);
// String value functions.
bool isString() const;
@@ -92,8 +93,15 @@ class VariantValue {
const VariantMatcher &getMatcher() const;
void setMatcher(const VariantMatcher &matcher);
+ // Unsigned value functions.
+ bool isUnsigned() const;
+ unsigned getUnsigned() const;
+ void setUnsigned(unsigned Unsigned);
+
// String representation of the type of the value.
std::string getTypeAsString() const;
+ explicit operator bool() const { return hasValue(); }
+ bool hasValue() const { return type != ValueType::Nothing; }
private:
void reset();
@@ -103,12 +111,14 @@ class VariantValue {
Nothing,
String,
Matcher,
+ Unsigned,
};
// All supported value types.
union AllValues {
llvm::StringRef *String;
VariantMatcher *Matcher;
+ unsigned Unsigned;
};
ValueType type;
diff --git a/mlir/include/mlir/Query/Query.h b/mlir/include/mlir/Query/Query.h
index 18f2172c9510a3..89d48773d2c3e6 100644
--- a/mlir/include/mlir/Query/Query.h
+++ b/mlir/include/mlir/Query/Query.h
@@ -17,7 +17,13 @@
namespace mlir::query {
-enum class QueryKind { Invalid, NoOp, Help, Match, Quit };
+struct QueryOptions {
+ bool omitBlockArguments = false;
+ bool omitUsesFromAbove = true;
+ bool inclusive = true;
+};
+
+enum class QueryKind { Invalid, NoOp, Help, Match, Quit, Let, SetBool };
class QuerySession;
@@ -103,6 +109,47 @@ struct MatchQuery : Query {
}
};
+struct LetQuery : Query {
+ LetQuery(llvm::StringRef name, const matcher::VariantValue &value)
+ : Query(QueryKind::Let), name(name), value(value) {}
+
+ llvm::LogicalResult run(llvm::raw_ostream &os,
+ QuerySession &qs) const override;
+
+ std::string name;
+ matcher::VariantValue value;
+
+ static bool classof(const Query *query) {
+ return query->kind == QueryKind::Let;
+ }
+};
+
+template <typename T>
+struct SetQueryKind {};
+
+template <>
+struct SetQueryKind<bool> {
+ static const QueryKind value = QueryKind::SetBool;
+};
+template <typename T>
+struct SetQuery : Query {
+ SetQuery(T QuerySession::*var, T value)
+ : Query(SetQueryKind<T>::value), var(var), value(value) {}
+
+ llvm::LogicalResult run(llvm::raw_ostream &os,
+ QuerySession &qs) const override {
+ qs.*var = value;
+ return mlir::success();
+ }
+
+ static bool classof(const Query *query) {
+ return query->kind == SetQueryKind<T>::value;
+ }
+
+ T QuerySession::*var;
+ T value;
+};
+
} // namespace mlir::query
#endif
diff --git a/mlir/include/mlir/Query/QuerySession.h b/mlir/include/mlir/Query/QuerySession.h
index fe552d750fc771..495358e8f36f94 100644
--- a/mlir/include/mlir/Query/QuerySession.h
+++ b/mlir/include/mlir/Query/QuerySession.h
@@ -9,14 +9,18 @@
#ifndef MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H
#define MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H
+#include "Matcher/VariantValue.h"
#include "mlir/IR/Operation.h"
#include "mlir/Query/Matcher/Registry.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/SourceMgr.h"
+namespace mlir::query::matcher {
+class Registry;
+}
+
namespace mlir::query {
-class Registry;
// Represents the state for a particular mlir-query session.
class QuerySession {
public:
@@ -33,6 +37,11 @@ class QuerySession {
llvm::StringMap<matcher::VariantValue> namedValues;
bool terminate = false;
+public:
+ bool omitBlockArguments = false;
+ bool omitUsesFromAbove = true;
+ bool inclusive = true;
+
private:
Operation *rootOp;
llvm::SourceMgr &sourceMgr;
diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
index 3609e24f9939f7..726f1188d7e4c8 100644
--- a/mlir/lib/Query/Matcher/Parser.cpp
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -135,6 +135,18 @@ class Parser::CodeTokenizer {
case '\'':
consumeStringLiteral(&result);
break;
+ case '0':
+ case '1':
+ case '2':
+ case '3':
+ case '4':
+ case '5':
+ case '6':
+ case '7':
+ case '8':
+ case '9':
+ consumeNumberLiteral(&result);
+ break;
default:
parseIdentifierOrInvalid(&result);
break;
@@ -144,6 +156,30 @@ class Parser::CodeTokenizer {
return result;
}
+ void consumeNumberLiteral(TokenInfo *result) {
+ unsigned length = 1;
+ if (code.size() > 1) {
+ // Consume the 'x' or 'b' radix modifier, if present.
+ switch (tolower(code[1])) {
+ case 'x':
+ case 'b':
+ length = 2;
+ }
+ }
+ while (length < code.size() && isdigit(code[length]))
+ ++length;
+
+ result->text = code.take_front(length);
+ code = code.drop_front(length);
+
+ unsigned value;
+ if (!result->text.getAsInteger(0, value)) {
+ result->kind = TokenKind::Literal;
+ result->value = static_cast<unsigned>(value);
+ return;
+ }
+ }
+
// Consume a string literal, handle escape sequences and missing closing
// quote.
void consumeStringLiteral(TokenInfo *result) {
@@ -257,13 +293,19 @@ bool Parser::parseIdentifierPrefixImpl(VariantValue *value) {
if (tokenizer->nextTokenKind() != TokenKind::OpenParen) {
// Parse as a named value.
- auto namedValue =
- namedValues ? namedValues->lookup(nameToken.text) : VariantValue();
+ if (auto namedValue = namedValues ? namedValues->lookup(nameToken.te...
[truncated]
|
88a01fd
to
0a12247
Compare
0a12247
to
30aff7a
Compare
I see the presubmit is failing on the newly added tests (mlir-query/complex-test.mlir failing) |
return false; | ||
} | ||
|
||
auto processValue = [&](Value value) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could mlir/include/mlir/Analysis/SliceAnalysis.h's methods be used here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
temHops parameter which specifies how far into the graph we are searching (how many levels) limits us to reuse code from mlir/include/mlir/Analysis/SliceAnalysis.h
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think maxDepth
would be more fitting for this parameter.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking at it, it may be possible (but less efficient) by keeping track of the depth in a map and then having the filter return whether the minimum depth of any of the operation's argument's depth is greater or equal to maxDepth.
62593f0
to
7a9fbd6
Compare
SmallVector<Operation *> printingOps(tempStorage.takeVector()); | ||
for (auto op : printingOps) { | ||
if (printMatchingOps) { | ||
printMatch(os, qs, op, "root"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should use the overloaded print method with no argument for the binding. I accidentally reverted this back to using the old method, but will update it accordingly with the next changes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SG
// | ||
// getMatches walks the IR and prints operations as soon as it matches them | ||
// if a matcher is to be further extracted into the function, then it does not | ||
// print operations |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is printing needed as a match is no longer just an Operation? How about returning a vector of "Match" which consists of the root operation match + environment? If we could decouple the matching from the printing of the matching, this would enable reuse in more places.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would we want to return a vector of vectors (each one containing rootOp + matching environment) for printing and for reusability we could return a vector containing sequences of rootOp + matching environment?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It now follows a similar design to clang-query's MatchFinder class
return false; | ||
} | ||
|
||
auto processValue = [&](Value value) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking at it, it may be possible (but less efficient) by keeping track of the depth in a map and then having the filter return whether the minimum depth of any of the operation's argument's depth is greater or equal to maxDepth.
Edit: suitable solution: backwardSlice.clear();
llvm::DenseMap<Operation *, int64_t> opDepths;
opDepths[rootOp] = 0;
query::QueryOptions depthOptions = options;
depthOptions.filter = [&](Operation *op) {
if (opDepths[op] > maxDepth)
return false;
for (auto operand : op->getOperands()) {
if (auto definingOp = operand.getDefiningOp()) {
auto it = opDepths.find(definingOp);
if (it == opDepths.end()) {
opDepths[definingOp] = opDepths[op] + 1;
if (opDepths[op] > maxDepth)
return false;
}
} else {
auto blockArgument = cast<BlockArgument>(operand));
Operation *parentOp = blockArgument.getOwner()->getParentOp();
auto it = opDepths.find(parentOp);
if (it == opDepths.end()) {
opDepths[parentOp] = opDepths[op] + 1;
if (opDepths[op] > maxDepth)
return false;
}
}
}
return true;
};
getBackwardSlice(rootOp, &backwardSlice, depthOptions); |
7a9fbd6
to
e6bc9b3
Compare
@jpienaar could you take another look? mlir-query MatchFinder now follows a similar design to clang-query's MatchFinder class and we now reuse the existing algorithm logic with a custom filter |
qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note, | ||
"\"" + binding + "\" binds here"); | ||
} | ||
|
||
// TODO: Extract into a helper function that can be reused outside query | ||
// context. | ||
static Operation *extractFunction(std::vector<Operation *> &ops, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could this method be a candidate for the MatchFinder class?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, or even higher, such as a utility function in Analysis or Transform or so, as this is rather common.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I could move it to MatchFinder. I could also create an issue for this and assign it to myself. I’d need to investigate where in Analysis or Transform would be the most suitable place.
c51a7b9
to
95bb64e
Compare
524318d
to
cad295c
Compare
cad295c
to
e07e1fe
Compare
I'm not sure why complex-test.mlir is failing in CI for Linux. Locally I use wsl and it seems to pass. Also I did not modify the printing logic. |
69472f0
to
5f940da
Compare
562b1a1
to
9a16aed
Compare
@jpienaar could you provide feedback on the latest changes? |
654ac76
to
d637929
Compare
// RUN: mlir-query %s -c "m getAllDefinitions(hasOpName(\"arith.addf\"),2)" | FileCheck %s | ||
|
||
#map = affine_map<(d0, d1) -> (d0, d1)> | ||
func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've copied this test from mlir/test/IR/slice.mlir
I coud at least rename it to "slicing-query" and/or change the contents to avoid duplication?
d637929
to
cfabf87
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this look overall good. Thanks. Lets do the one rename you were also thinking about, and then can approve. Do you need help landing it?
@@ -0,0 +1,85 @@ | |||
//===- ExtraMatchers.h - Various common matchers --------------------------===// | |||
// |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about just SliceMatchers? (says what it is and then we don't need to think about what is and what isn't extra)
Relocate backwardSlice matcher to Query specific headers Remove unncecessary code
Make BackwardSlice matcher more generic Capture values in tests
af3bb15
to
87e2e44
Compare
@chios202 Congratulations on having your first Pull Request (PR) merged into the LLVM Project! Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR. Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues. How to do this, and the rest of the post-merge process, is covered in detail here. If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again. If you don't get any reports, no action is required from you. Your changes are working as expected, well done! |
getBackwardSlice
and getForwardSlice
matchersgetBackwardSlice
matcher
Improve mlir-query tool by implementing
getBackwardSlice
matcherNote: backwardSlice and forwardSlice algoritms are the same as the ones in
mlir/lib/Analysis/SliceAnalysis.cpp
Example of current matcher. The query was made to the file:
mlir/test/mlir-query/complex-test.mlir