Skip to content

[mlir] Improve mlir-query tool by implementing getBackwardSlice matcher #115670

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions mlir/include/mlir/Query/Matcher/Marshallers.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,36 @@ struct ArgTypeTraits<llvm::StringRef> {
}
};

template <>
struct ArgTypeTraits<int64_t> {
static bool hasCorrectType(const VariantValue &value) {
return value.isSigned();
}

static unsigned get(const VariantValue &value) { return value.getSigned(); }

static ArgKind getKind() { return ArgKind::Signed; }

static std::optional<std::string> getBestGuess(const VariantValue &) {
return std::nullopt;
}
};

template <>
struct ArgTypeTraits<bool> {
static bool hasCorrectType(const VariantValue &value) {
return value.isBoolean();
}

static unsigned get(const VariantValue &value) { return value.getBoolean(); }

static ArgKind getKind() { return ArgKind::Boolean; }

static std::optional<std::string> getBestGuess(const VariantValue &) {
return std::nullopt;
}
};

template <>
struct ArgTypeTraits<DynMatcher> {

Expand Down
48 changes: 33 additions & 15 deletions mlir/include/mlir/Query/Matcher/MatchFinder.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,51 @@
//===----------------------------------------------------------------------===//
//
// This file contains the MatchFinder class, which is used to find operations
// that match a given matcher.
// that match a given matcher and print them.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H

#include "MatchersInternal.h"
#include "mlir/Query/Query.h"
#include "mlir/Query/QuerySession.h"
#include "llvm/ADT/SetVector.h"

namespace mlir::query::matcher {

// MatchFinder is used to find all operations that match a given matcher.
/// A class that provides utilities to find operations in the IR.
class MatchFinder {

public:
// Returns all operations that match the given matcher.
static std::vector<Operation *> getMatches(Operation *root,
DynMatcher matcher) {
std::vector<Operation *> matches;

// Simple match finding with walk.
root->walk([&](Operation *subOp) {
if (matcher.match(subOp))
matches.push_back(subOp);
});

return matches;
}
/// A subclass which preserves the matching information. Each instance
/// contains the `rootOp` along with the matching environment.
struct MatchResult {
MatchResult() = default;
MatchResult(Operation *rootOp, std::vector<Operation *> matchedOps);

Operation *rootOp = nullptr;
/// Contains the matching environment.
std::vector<Operation *> matchedOps;
};

/// Traverses the IR and returns a vector of `MatchResult` for each match of
/// the `matcher`.
std::vector<MatchResult> collectMatches(Operation *root,
DynMatcher matcher) const;

/// Prints the matched operation.
void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op) const;

/// Labels the matched operation with the given binding (e.g., `"root"`) and
/// prints it.
void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
const std::string &binding) const;

/// Flattens a vector of `MatchResult` into a vector of operations.
std::vector<Operation *>
flattenMatchedOps(std::vector<MatchResult> &matches) const;
};

} // namespace mlir::query::matcher
Expand Down
59 changes: 49 additions & 10 deletions mlir/include/mlir/Query/Matcher/MatchersInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
//
// Implements the base layer of the matcher framework.
//
// Matchers are methods that return a Matcher which provides a method
// match(Operation *op)
// Matchers are methods that return a Matcher which provides a method one of the
// following methods: match(Operation *op), match(Operation *op,
// SetVector<Operation *> &matchedOps)
//
// The matcher functions are defined in include/mlir/IR/Matchers.h.
// This file contains the wrapper classes needed to construct matchers for
Expand All @@ -25,13 +26,39 @@

namespace mlir::query::matcher {

// Defaults to false if T has no match() method with the signature:
// match(Operation* op).
template <typename T, typename = void>
struct has_simple_match : std::false_type {};

// Specialized type trait that evaluates to true if T has a match() method
// with the signature: match(Operation* op).
template <typename T>
struct has_simple_match<T, std::void_t<decltype(std::declval<T>().match(
std::declval<Operation *>()))>>
: std::true_type {};

// Defaults to false if T has no match() method with the signature:
// match(Operation* op, SetVector<Operation*>&).
template <typename T, typename = void>
struct has_bound_match : std::false_type {};

// Specialized type trait that evaluates to true if T has a match() method
// with the signature: match(Operation* op, SetVector<Operation*>&).
template <typename T>
struct has_bound_match<T, std::void_t<decltype(std::declval<T>().match(
std::declval<Operation *>(),
std::declval<SetVector<Operation *> &>()))>>
: std::true_type {};

// Generic interface for matchers on an MLIR operation.
class MatcherInterface
: public llvm::ThreadSafeRefCountedBase<MatcherInterface> {
public:
virtual ~MatcherInterface() = default;

virtual bool match(Operation *op) = 0;
virtual bool match(Operation *op, SetVector<Operation *> &matchedOps) = 0;
};

// MatcherFnImpl takes a matcher function object and implements
Expand All @@ -40,14 +67,25 @@ template <typename MatcherFn>
class MatcherFnImpl : public MatcherInterface {
public:
MatcherFnImpl(MatcherFn &matcherFn) : matcherFn(matcherFn) {}
bool match(Operation *op) override { return matcherFn.match(op); }

bool match(Operation *op) override {
if constexpr (has_simple_match<MatcherFn>::value)
return matcherFn.match(op);
return false;
}

bool match(Operation *op, SetVector<Operation *> &matchedOps) override {
if constexpr (has_bound_match<MatcherFn>::value)
return matcherFn.match(op, matchedOps);
return false;
}

private:
MatcherFn matcherFn;
};

// Matcher wraps a MatcherInterface implementation and provides a match()
// method that redirects calls to the underlying implementation.
// Matcher wraps a MatcherInterface implementation and provides match()
// methods that redirect calls to the underlying implementation.
class DynMatcher {
public:
// Takes ownership of the provided implementation pointer.
Expand All @@ -62,12 +100,13 @@ class DynMatcher {
}

bool match(Operation *op) const { return implementation->match(op); }
bool match(Operation *op, SetVector<Operation *> &matchedOps) const {
return implementation->match(op, matchedOps);
}

void setFunctionName(StringRef name) { functionName = name.str(); };

bool hasFunctionName() const { return !functionName.empty(); };

StringRef getFunctionName() const { return functionName; };
void setFunctionName(StringRef name) { functionName = name.str(); }
bool hasFunctionName() const { return !functionName.empty(); }
StringRef getFunctionName() const { return functionName; }

private:
llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
Expand Down
141 changes: 141 additions & 0 deletions mlir/include/mlir/Query/Matcher/SliceMatchers.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
//===- SliceMatchers.h - Matchers for slicing analysis ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file provides matchers for MLIRQuery that peform slicing analysis
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H
#define MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H

#include "mlir/Analysis/SliceAnalysis.h"

/// A matcher encapsulating `getBackwardSlice` method from SliceAnalysis.h.
/// Additionally, it limits the slice computation to a certain depth level using
/// a custom filter.
///
/// Example: starting from node 9, assuming the matcher
/// computes the slice for the first two depth levels:
/// ============================
/// 1 2 3 4
/// |_______| |______|
/// | | |
/// | 5 6
/// |___|_____________|
/// | |
/// 7 8
/// |_______________|
/// |
/// 9
///
/// Assuming all local orders match the numbering order:
/// {5, 7, 6, 8, 9}
namespace mlir::query::matcher {

template <typename Matcher>
class BackwardSliceMatcher {
public:
BackwardSliceMatcher(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
bool omitBlockArguments, bool omitUsesFromAbove)
: innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth),
inclusive(inclusive), omitBlockArguments(omitBlockArguments),
omitUsesFromAbove(omitUsesFromAbove) {}

bool match(Operation *rootOp, SetVector<Operation *> &backwardSlice) {
BackwardSliceOptions options;
options.inclusive = inclusive;
options.omitUsesFromAbove = omitUsesFromAbove;
options.omitBlockArguments = omitBlockArguments;
return (innerMatcher.match(rootOp) &&
matches(rootOp, backwardSlice, options, maxDepth));
}

private:
bool matches(Operation *rootOp, llvm::SetVector<Operation *> &backwardSlice,
BackwardSliceOptions &options, int64_t maxDepth);

private:
// The outer matcher (e.g., BackwardSliceMatcher) relies on the innerMatcher
// to determine whether we want to traverse the IR or not. For example, we
// want to explore the IR only if the top-level operation name is
// `"arith.addf"`.
Matcher innerMatcher;
// `maxDepth` specifies the maximum depth that the matcher can traverse the
// IR. For example, if `maxDepth` is 2, the matcher will explore the defining
// operations of the top-level op up to 2 levels.
int64_t maxDepth;
bool inclusive;
bool omitBlockArguments;
bool omitUsesFromAbove;
};

template <typename Matcher>
bool BackwardSliceMatcher<Matcher>::matches(
Operation *rootOp, llvm::SetVector<Operation *> &backwardSlice,
BackwardSliceOptions &options, int64_t maxDepth) {
backwardSlice.clear();
llvm::DenseMap<Operation *, int64_t> opDepths;
// Initializing the root op with a depth of 0
opDepths[rootOp] = 0;
options.filter = [&](Operation *subOp) {
// If the subOp hasn't been recorded in opDepths, it is deeper than
// maxDepth.
if (!opDepths.contains(subOp))
return false;
// Examine subOp's operands to compute depths of their defining operations.
for (auto operand : subOp->getOperands()) {
int64_t newDepth = opDepths[subOp] + 1;
// If the newDepth is greater than maxDepth, further computation can be
// skipped.
if (newDepth > maxDepth)
continue;

if (auto definingOp = operand.getDefiningOp()) {
// Registers the minimum depth
if (!opDepths.contains(definingOp) || newDepth < opDepths[definingOp])
opDepths[definingOp] = newDepth;
} else {
auto blockArgument = cast<BlockArgument>(operand);
Operation *parentOp = blockArgument.getOwner()->getParentOp();
if (!parentOp)
continue;

if (!opDepths.contains(parentOp) || newDepth < opDepths[parentOp])
opDepths[parentOp] = newDepth;
}
}
return true;
};
getBackwardSlice(rootOp, &backwardSlice, options);
return options.inclusive ? backwardSlice.size() > 1
: backwardSlice.size() >= 1;
}

/// Matches transitive defs of a top-level operation up to N levels.
template <typename Matcher>
inline BackwardSliceMatcher<Matcher>
m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
bool omitBlockArguments, bool omitUsesFromAbove) {
assert(maxDepth >= 0 && "maxDepth must be non-negative");
return BackwardSliceMatcher<Matcher>(std::move(innerMatcher), maxDepth,
inclusive, omitBlockArguments,
omitUsesFromAbove);
}

/// Matches all transitive defs of a top-level operation up to N levels
template <typename Matcher>
inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
int64_t maxDepth) {
assert(maxDepth >= 0 && "maxDepth must be non-negative");
return BackwardSliceMatcher<Matcher>(std::move(innerMatcher), maxDepth, true,
false, false);
}

} // namespace mlir::query::matcher

#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H
Loading