Skip to content

[mlir] Improve mlir-query by adding matcher combinators #141423

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

chios202
Copy link
Contributor

@chios202 chios202 commented May 25, 2025

Whereas backward-slice matching provides support to limit traversal by specifying the desired depth level, this pull request introduces support for limiting traversal with a nested matcher (adding forward-slice also). It also adds support for variadic operators, including anyOf and allOf. Rather than simply stopping traversal when an operation named foo is encountered, one can now define a matcher that specifies different exit conditions. Variadic support implementation within mlir-query is very similar to clang-query.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels May 25, 2025
@llvmbot
Copy link
Member

llvmbot commented May 25, 2025

@llvm/pr-subscribers-mlir-core

Author: Denzel-Brian Budii (chios202)

Changes

Whereas backward-slice matching provides support to limit traversal by specifying the desired depth level, this pull request introduces support for limiting traversal with a nested matcher. It also adds support for variadic operators, including anyOf and allOf. Rather than simply stopping traversal when an operation named foo is encountered, you can now define a matcher that specifies different exit conditions.


Patch is 21.24 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/141423.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Query/Matcher/Marshallers.h (+61)
  • (modified) mlir/include/mlir/Query/Matcher/MatchersInternal.h (+103-4)
  • (modified) mlir/include/mlir/Query/Matcher/SliceMatchers.h (+63-5)
  • (modified) mlir/include/mlir/Query/Matcher/VariantValue.h (+10-1)
  • (modified) mlir/lib/Query/Matcher/CMakeLists.txt (+1)
  • (added) mlir/lib/Query/Matcher/MatchersInternal.cpp (+30)
  • (modified) mlir/lib/Query/Matcher/RegistryManager.cpp (+5-2)
  • (modified) mlir/lib/Query/Matcher/VariantValue.cpp (+54)
  • (modified) mlir/tools/mlir-query/mlir-query.cpp (+6)
diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h
index 012bf7b9ec4a9..f81e789f274e6 100644
--- a/mlir/include/mlir/Query/Matcher/Marshallers.h
+++ b/mlir/include/mlir/Query/Matcher/Marshallers.h
@@ -108,6 +108,9 @@ class MatcherDescriptor {
                                 const llvm::ArrayRef<ParserValue> args,
                                 Diagnostics *error) const = 0;
 
+  // If the matcher is variadic, it can take any number of arguments.
+  virtual bool isVariadic() const = 0;
+
   // Returns the number of arguments accepted by the matcher.
   virtual unsigned getNumArgs() const = 0;
 
@@ -140,6 +143,8 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
     return marshaller(matcherFunc, matcherName, nameRange, args, error);
   }
 
+  bool isVariadic() const override { return false; }
+
   unsigned getNumArgs() const override { return argKinds.size(); }
 
   void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
@@ -153,6 +158,54 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
   const std::vector<ArgKind> argKinds;
 };
 
+class VariadicOperatorMatcherDescriptor : public MatcherDescriptor {
+public:
+  using VarOp = DynMatcher::VariadicOperator;
+  VariadicOperatorMatcherDescriptor(unsigned minCount, unsigned maxCount,
+                                    VarOp varOp, StringRef matcherName)
+      : minCount(minCount), maxCount(maxCount), varOp(varOp),
+        matcherName(matcherName) {}
+
+  VariantMatcher create(SourceRange nameRange, ArrayRef<ParserValue> args,
+                        Diagnostics *error) const override {
+    if (args.size() < minCount || maxCount < args.size()) {
+      addError(error, nameRange, ErrorType::RegistryWrongArgCount,
+               {llvm::Twine("requires between "), llvm::Twine(minCount),
+                llvm::Twine(" and "), llvm::Twine(maxCount),
+                llvm::Twine(" args, got "), llvm::Twine(args.size())});
+      return VariantMatcher();
+    }
+
+    std::vector<VariantMatcher> innerArgs;
+    for (size_t i = 0, e = args.size(); i != e; ++i) {
+      const ParserValue &arg = args[i];
+      const VariantValue &value = arg.value;
+      if (!value.isMatcher()) {
+        addError(error, arg.range, ErrorType::RegistryWrongArgType,
+                 {llvm::Twine(i + 1), llvm::Twine("Matcher: "),
+                  llvm::Twine(value.getTypeAsString())});
+        return VariantMatcher();
+      }
+      innerArgs.push_back(value.getMatcher());
+    }
+    return VariantMatcher::VariadicOperatorMatcher(varOp, std::move(innerArgs));
+  }
+
+  bool isVariadic() const override { return true; }
+
+  unsigned getNumArgs() const override { return 0; }
+
+  void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
+    kinds.push_back(ArgKind(ArgKind::Matcher));
+  }
+
+private:
+  const unsigned minCount;
+  const unsigned maxCount;
+  const VarOp varOp;
+  const StringRef matcherName;
+};
+
 // Helper function to check if argument count matches expected count
 inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount,
                           llvm::ArrayRef<ParserValue> args,
@@ -224,6 +277,14 @@ makeMatcherAutoMarshall(ReturnType (*matcherFunc)(ArgTypes...),
       reinterpret_cast<void (*)()>(matcherFunc), matcherName, argKinds);
 }
 
+// Variadic operator overload.
+template <unsigned MinCount, unsigned MaxCount>
+std::unique_ptr<MatcherDescriptor>
+makeMatcherAutoMarshall(VariadicOperatorMatcherFunc<MinCount, MaxCount> func,
+                        StringRef matcherName) {
+  return std::make_unique<VariadicOperatorMatcherDescriptor>(
+      MinCount, MaxCount, func.varOp, matcherName);
+}
 } // namespace mlir::query::matcher::internal
 
 #endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index 183b2514e109f..904103407611e 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -8,11 +8,11 @@
 //
 // Implements the base layer of the matcher framework.
 //
-// Matchers are methods that return a Matcher which provides a method one of the
-// following methods: match(Operation *op), match(Operation *op,
-// SetVector<Operation *> &matchedOps)
+// Matchers are methods that return a Matcher which provide a method
+// `match(...)` method. The method's parameters define the context of the match.
+// Support includes simple (unary) matchers as well as matcher combinators.
+// (anyOf, allOf, etc.)
 //
-// The matcher functions are defined in include/mlir/IR/Matchers.h.
 // This file contains the wrapper classes needed to construct matchers for
 // mlir-query.
 //
@@ -25,6 +25,13 @@
 #include "llvm/ADT/IntrusiveRefCntPtr.h"
 
 namespace mlir::query::matcher {
+class DynMatcher;
+namespace internal {
+
+bool allOfVariadicOperator(Operation *op, ArrayRef<DynMatcher> innerMatchers);
+bool anyOfVariadicOperator(Operation *op, ArrayRef<DynMatcher> innerMatchers);
+
+} // namespace internal
 
 // Defaults to false if T has no match() method with the signature:
 // match(Operation* op).
@@ -84,6 +91,26 @@ class MatcherFnImpl : public MatcherInterface {
   MatcherFn matcherFn;
 };
 
+// VariadicMatcher takes a vector of Matchers and returns true if any Matchers
+// match the given operation.
+using VariadicOperatorFunction = bool (*)(Operation *op,
+                                          ArrayRef<DynMatcher> innerMatchers);
+
+template <VariadicOperatorFunction Func>
+class VariadicMatcher : public MatcherInterface {
+public:
+  VariadicMatcher(std::vector<DynMatcher> matchers) : matchers(matchers) {}
+
+  bool match(Operation *op) override { return Func(op, matchers); }
+  // Fallback case
+  bool match(Operation *op, SetVector<Operation *> &matchedOps) override {
+    return false;
+  }
+
+private:
+  std::vector<DynMatcher> matchers;
+};
+
 // Matcher wraps a MatcherInterface implementation and provides match()
 // methods that redirect calls to the underlying implementation.
 class DynMatcher {
@@ -92,6 +119,31 @@ class DynMatcher {
   DynMatcher(MatcherInterface *implementation)
       : implementation(implementation) {}
 
+  // Construct from a variadic function.
+  enum VariadicOperator {
+    // Matches operations for which all provided matchers match.
+    AllOf,
+    // Matches operations for which at least one of the provided matchers
+    // matches.
+    AnyOf
+  };
+
+  static std::unique_ptr<DynMatcher>
+  constructVariadic(VariadicOperator Op,
+                    std::vector<DynMatcher> innerMatchers) {
+    switch (Op) {
+    case AllOf:
+      return std::make_unique<DynMatcher>(
+          new VariadicMatcher<internal::allOfVariadicOperator>(
+              std::move(innerMatchers)));
+    case AnyOf:
+      return std::make_unique<DynMatcher>(
+          new VariadicMatcher<internal::anyOfVariadicOperator>(
+              std::move(innerMatchers)));
+    }
+    llvm_unreachable("Invalid Op value.");
+  }
+
   template <typename MatcherFn>
   static std::unique_ptr<DynMatcher>
   constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) {
@@ -113,6 +165,53 @@ class DynMatcher {
   std::string functionName;
 };
 
+// VariadicOperatorMatcher related types.
+template <typename... Ps>
+class VariadicOperatorMatcher {
+public:
+  VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, Ps &&...params)
+      : varOp(varOp), params(std::forward<Ps>(params)...) {}
+
+  operator std::unique_ptr<DynMatcher>() const & {
+    return DynMatcher::constructVariadic(
+        varOp, getMatchers(std::index_sequence_for<Ps...>()));
+  }
+
+  operator std::unique_ptr<DynMatcher>() && {
+    return DynMatcher::constructVariadic(
+        varOp, std::move(*this).getMatchers(std::index_sequence_for<Ps...>()));
+  }
+
+private:
+  // Helper method to unpack the tuple into a vector.
+  template <std::size_t... Is>
+  std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) const & {
+    return {DynMatcher(std::get<Is>(params))...};
+  }
+
+  template <std::size_t... Is>
+  std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) && {
+    return {DynMatcher(std::get<Is>(std::move(params)))...};
+  }
+
+  const DynMatcher::VariadicOperator varOp;
+  std::tuple<Ps...> params;
+};
+
+// Overloaded function object to generate VariadicOperatorMatcher objects from
+// arbitrary matchers.
+template <unsigned MinCount, unsigned MaxCount>
+struct VariadicOperatorMatcherFunc {
+  DynMatcher::VariadicOperator varOp;
+
+  template <typename... Ms>
+  VariadicOperatorMatcher<Ms...> operator()(Ms &&...Ps) const {
+    static_assert(MinCount <= sizeof...(Ms) && sizeof...(Ms) <= MaxCount,
+                  "invalid number of parameters for variadic matcher");
+    return VariadicOperatorMatcher<Ms...>(varOp, std::forward<Ms>(Ps)...);
+  }
+};
+
 } // namespace mlir::query::matcher
 
 #endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
diff --git a/mlir/include/mlir/Query/Matcher/SliceMatchers.h b/mlir/include/mlir/Query/Matcher/SliceMatchers.h
index 5bb8251672eb7..fec46d2ff814d 100644
--- a/mlir/include/mlir/Query/Matcher/SliceMatchers.h
+++ b/mlir/include/mlir/Query/Matcher/SliceMatchers.h
@@ -6,7 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file provides matchers for MLIRQuery that peform slicing analysis
+// This file defines slicing-analysis matchers that extend and abstract the
+// core implementations from `SliceAnalysis.h`.
 //
 //===----------------------------------------------------------------------===//
 
@@ -15,9 +16,9 @@
 
 #include "mlir/Analysis/SliceAnalysis.h"
 
-/// A matcher encapsulating `getBackwardSlice` method from SliceAnalysis.h.
-/// Additionally, it limits the slice computation to a certain depth level using
-/// a custom filter.
+/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
+/// if `innerMatcher` matches. The traversal stops once the desired depth level
+/// is reached.
 ///
 /// Example: starting from node 9, assuming the matcher
 /// computes the slice for the first two depth levels:
@@ -116,6 +117,51 @@ bool BackwardSliceMatcher<Matcher>::matches(
                            : backwardSlice.size() >= 1;
 }
 
+/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
+/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
+template <typename BaseMatcher, typename Filter>
+class PredicateBackwardSliceMatcher {
+public:
+  PredicateBackwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
+                                bool inclusive, bool omitBlockArguments,
+                                bool omitUsesFromAbove)
+      : innerMatcher(std::move(innerMatcher)),
+        filterMatcher(std::move(filterMatcher)), inclusive(inclusive),
+        omitBlockArguments(omitBlockArguments),
+        omitUsesFromAbove(omitUsesFromAbove) {}
+
+  bool match(Operation *rootOp, SetVector<Operation *> &backwardSlice) {
+    backwardSlice.clear();
+    BackwardSliceOptions options;
+    options.inclusive = inclusive;
+    options.omitUsesFromAbove = omitUsesFromAbove;
+    options.omitBlockArguments = omitBlockArguments;
+    if (innerMatcher.match(rootOp)) {
+      options.filter = [&](Operation *subOp) {
+        return !filterMatcher.match(subOp);
+      };
+      getBackwardSlice(rootOp, &backwardSlice, options);
+      return options.inclusive ? backwardSlice.size() > 1
+                               : backwardSlice.size() >= 1;
+    }
+    return false;
+  }
+
+private:
+  BaseMatcher innerMatcher;
+  Filter filterMatcher;
+  bool inclusive;
+  bool omitBlockArguments;
+  bool omitUsesFromAbove;
+};
+
+const matcher::VariadicOperatorMatcherFunc<1,
+                                           std::numeric_limits<unsigned>::max()>
+    anyOf = {matcher::DynMatcher::AnyOf};
+const matcher::VariadicOperatorMatcherFunc<1,
+                                           std::numeric_limits<unsigned>::max()>
+    allOf = {matcher::DynMatcher::AllOf};
+
 /// Matches transitive defs of a top-level operation up to N levels.
 template <typename Matcher>
 inline BackwardSliceMatcher<Matcher>
@@ -127,7 +173,7 @@ m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
                                        omitUsesFromAbove);
 }
 
-/// Matches all transitive defs of a top-level operation up to N levels
+/// Matches all transitive defs of a top-level operation up to N levels.
 template <typename Matcher>
 inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
                                                          int64_t maxDepth) {
@@ -136,6 +182,18 @@ inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
                                        false, false);
 }
 
+/// Matches all transitive defs of a top-level operation and stops where
+/// `filterMatcher` rejects.
+template <typename BaseMatcher, typename Filter>
+inline PredicateBackwardSliceMatcher<BaseMatcher, Filter>
+m_GetDefinitionsByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
+                            bool inclusive, bool omitBlockArguments,
+                            bool omitUsesFromAbove) {
+  return PredicateBackwardSliceMatcher<BaseMatcher, Filter>(
+      std::move(innerMatcher), std::move(filterMatcher), inclusive,
+      omitBlockArguments, omitUsesFromAbove);
+}
+
 } // namespace mlir::query::matcher
 
 #endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H
diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h
index 98c0a18e25101..1a47576de1841 100644
--- a/mlir/include/mlir/Query/Matcher/VariantValue.h
+++ b/mlir/include/mlir/Query/Matcher/VariantValue.h
@@ -26,7 +26,12 @@ enum class ArgKind { Boolean, Matcher, Signed, String };
 // A variant matcher object to abstract simple and complex matchers into a
 // single object type.
 class VariantMatcher {
-  class MatcherOps;
+  class MatcherOps {
+  public:
+    std::optional<DynMatcher>
+    constructVariadicOperator(DynMatcher::VariadicOperator varOp,
+                              ArrayRef<VariantMatcher> innerMatchers) const;
+  };
 
   // Payload interface to be specialized by each matcher type. It follows a
   // similar interface as VariantMatcher itself.
@@ -43,6 +48,9 @@ class VariantMatcher {
 
   // Clones the provided matcher.
   static VariantMatcher SingleMatcher(DynMatcher matcher);
+  static VariantMatcher
+  VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp,
+                          ArrayRef<VariantMatcher> args);
 
   // Makes the matcher the "null" matcher.
   void reset();
@@ -61,6 +69,7 @@ class VariantMatcher {
       : value(std::move(value)) {}
 
   class SinglePayload;
+  class VariadicOpPayload;
 
   std::shared_ptr<const Payload> value;
 };
diff --git a/mlir/lib/Query/Matcher/CMakeLists.txt b/mlir/lib/Query/Matcher/CMakeLists.txt
index 629479bf7adc1..ba202762fdfbb 100644
--- a/mlir/lib/Query/Matcher/CMakeLists.txt
+++ b/mlir/lib/Query/Matcher/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_library(MLIRQueryMatcher
   MatchFinder.cpp
+  MatchersInternal.cpp
   Parser.cpp
   RegistryManager.cpp
   VariantValue.cpp
diff --git a/mlir/lib/Query/Matcher/MatchersInternal.cpp b/mlir/lib/Query/Matcher/MatchersInternal.cpp
new file mode 100644
index 0000000000000..e3593aa001f31
--- /dev/null
+++ b/mlir/lib/Query/Matcher/MatchersInternal.cpp
@@ -0,0 +1,30 @@
+//===--- MatchersInternal.cpp----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements the base layer of the matcher framework.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Query/Matcher/MatchersInternal.h"
+
+namespace mlir::query::matcher {
+
+namespace internal {
+
+bool allOfVariadicOperator(Operation *op, ArrayRef<DynMatcher> innerMatchers) {
+  return llvm::all_of(innerMatchers, [op](const DynMatcher &matcher) {
+    return matcher.match(op);
+  });
+}
+
+bool anyOfVariadicOperator(Operation *op, ArrayRef<DynMatcher> innerMatchers) {
+  return llvm::any_of(innerMatchers, [op](const DynMatcher &matcher) {
+    return matcher.match(op);
+  });
+}
+} // namespace internal
+} // namespace mlir::query::matcher
diff --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp
index 4b511c5f009e7..08b610453b11a 100644
--- a/mlir/lib/Query/Matcher/RegistryManager.cpp
+++ b/mlir/lib/Query/Matcher/RegistryManager.cpp
@@ -64,7 +64,7 @@ std::vector<ArgKind> RegistryManager::getAcceptedCompletionTypes(
     unsigned argNumber = ctxEntry.second;
     std::vector<ArgKind> nextTypeSet;
 
-    if (argNumber < ctor->getNumArgs())
+    if (ctor->isVariadic() || argNumber < ctor->getNumArgs())
       ctor->getArgKinds(argNumber, nextTypeSet);
 
     typeSet.insert(nextTypeSet.begin(), nextTypeSet.end());
@@ -83,7 +83,7 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
     const internal::MatcherDescriptor &matcher = *m.getValue();
     llvm::StringRef name = m.getKey();
 
-    unsigned numArgs = matcher.getNumArgs();
+    unsigned numArgs = matcher.isVariadic() ? 1 : matcher.getNumArgs();
     std::vector<std::vector<ArgKind>> argKinds(numArgs);
 
     for (const ArgKind &kind : acceptedTypes) {
@@ -115,6 +115,9 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
       }
     }
 
+    if (matcher.isVariadic())
+      os << ",...";
+
     os << ")";
     typedText += "(";
 
diff --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp
index 1cb2d48f9d56f..61316cfd0d489 100644
--- a/mlir/lib/Query/Matcher/VariantValue.cpp
+++ b/mlir/lib/Query/Matcher/VariantValue.cpp
@@ -27,12 +27,66 @@ class VariantMatcher::SinglePayload : public VariantMatcher::Payload {
   DynMatcher matcher;
 };
 
+class VariantMatcher::VariadicOpPayload : public VariantMatcher::Payload {
+public:
+  VariadicOpPayload(DynMatcher::VariadicOperator varOp,
+                    std::vector<VariantMatcher> args)
+      : varOp(varOp), args(std::move(args)) {}
+
+  std::optional<DynMatcher> getDynMatcher() const override {
+    std::vector<DynMatcher> dynMatchers;
+    for (auto variantMatcher : args) {
+      std::optional<DynMatcher> dynMatcher = variantMatcher.getDynMatcher();
+      if (dynMatcher)
+        dynMatchers.push_back(dynMatcher.value());
+    }
+    auto result = DynMatcher::constructVariadic(varOp, dynMatchers);
+    return *result;
+  }
+
+  std::string getTypeAsString() const override {
+    std::string inner;
+    for (size_t i = 0, e = args.size(); i != e; ++i) {
+      if (i != 0)
+        inner += "&";
+      inner += args[i].getTypeAsString();
+    }
+    return inner;
+  }
+
+private:
+  const DynMatcher::VariadicOperator varOp;
+  const std::vector<VariantMatcher> args;
+};
+
 VariantMatcher::VariantMatcher() = default;
 
 VariantMatcher VariantMatcher::SingleMatcher(DynMatcher matcher) {
   return VariantMatcher(std::make_shared<SinglePayload>(std::move(matcher)));
 }
 
+VariantMatcher
+VariantMatcher::VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp,
+                                        ArrayRef<VariantMatcher> args) {
+  return VariantMatcher(
+      std::make_shared<VariadicOpPayload>(varOp, std::move(args)));
+}
+
+std::optional<DynMatcher> VariantMatcher::MatcherOps::constructVariadicOperator(
+    DynMatcher::VariadicOperator varOp,
+    ArrayRef<VariantMatcher> innerMatchers) const {
+  std::vector<DynMatcher> dynMatchers;
+  for (const auto &innerMatcher : innerMatchers) {
+    if (!innerMatcher.value)
+      return std::nullopt;
+    std::optional<DynMatcher> inner = innerMatcher.value->getDynMatcher();
+    if (!inner)
+      return std::nullopt;
+    dynMatchers.push_back(*inner);
+  }
+  return *DynMatcher::constructVariadic(varOp, dynMatchers);
+}
+
 std::optional<DynMatcher> VariantMatcher::getDynMatcher() const {
   return value...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented May 25, 2025

@llvm/pr-subscribers-mlir

Author: Denzel-Brian Budii (chios202)

Changes

Whereas backward-slice matching provides support to limit traversal by specifying the desired depth level, this pull request introduces support for limiting traversal with a nested matcher. It also adds support for variadic operators, including anyOf and allOf. Rather than simply stopping traversal when an operation named foo is encountered, you can now define a matcher that specifies different exit conditions.


Patch is 21.24 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/141423.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Query/Matcher/Marshallers.h (+61)
  • (modified) mlir/include/mlir/Query/Matcher/MatchersInternal.h (+103-4)
  • (modified) mlir/include/mlir/Query/Matcher/SliceMatchers.h (+63-5)
  • (modified) mlir/include/mlir/Query/Matcher/VariantValue.h (+10-1)
  • (modified) mlir/lib/Query/Matcher/CMakeLists.txt (+1)
  • (added) mlir/lib/Query/Matcher/MatchersInternal.cpp (+30)
  • (modified) mlir/lib/Query/Matcher/RegistryManager.cpp (+5-2)
  • (modified) mlir/lib/Query/Matcher/VariantValue.cpp (+54)
  • (modified) mlir/tools/mlir-query/mlir-query.cpp (+6)
diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h
index 012bf7b9ec4a9..f81e789f274e6 100644
--- a/mlir/include/mlir/Query/Matcher/Marshallers.h
+++ b/mlir/include/mlir/Query/Matcher/Marshallers.h
@@ -108,6 +108,9 @@ class MatcherDescriptor {
                                 const llvm::ArrayRef<ParserValue> args,
                                 Diagnostics *error) const = 0;
 
+  // If the matcher is variadic, it can take any number of arguments.
+  virtual bool isVariadic() const = 0;
+
   // Returns the number of arguments accepted by the matcher.
   virtual unsigned getNumArgs() const = 0;
 
@@ -140,6 +143,8 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
     return marshaller(matcherFunc, matcherName, nameRange, args, error);
   }
 
+  bool isVariadic() const override { return false; }
+
   unsigned getNumArgs() const override { return argKinds.size(); }
 
   void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
@@ -153,6 +158,54 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
   const std::vector<ArgKind> argKinds;
 };
 
+class VariadicOperatorMatcherDescriptor : public MatcherDescriptor {
+public:
+  using VarOp = DynMatcher::VariadicOperator;
+  VariadicOperatorMatcherDescriptor(unsigned minCount, unsigned maxCount,
+                                    VarOp varOp, StringRef matcherName)
+      : minCount(minCount), maxCount(maxCount), varOp(varOp),
+        matcherName(matcherName) {}
+
+  VariantMatcher create(SourceRange nameRange, ArrayRef<ParserValue> args,
+                        Diagnostics *error) const override {
+    if (args.size() < minCount || maxCount < args.size()) {
+      addError(error, nameRange, ErrorType::RegistryWrongArgCount,
+               {llvm::Twine("requires between "), llvm::Twine(minCount),
+                llvm::Twine(" and "), llvm::Twine(maxCount),
+                llvm::Twine(" args, got "), llvm::Twine(args.size())});
+      return VariantMatcher();
+    }
+
+    std::vector<VariantMatcher> innerArgs;
+    for (size_t i = 0, e = args.size(); i != e; ++i) {
+      const ParserValue &arg = args[i];
+      const VariantValue &value = arg.value;
+      if (!value.isMatcher()) {
+        addError(error, arg.range, ErrorType::RegistryWrongArgType,
+                 {llvm::Twine(i + 1), llvm::Twine("Matcher: "),
+                  llvm::Twine(value.getTypeAsString())});
+        return VariantMatcher();
+      }
+      innerArgs.push_back(value.getMatcher());
+    }
+    return VariantMatcher::VariadicOperatorMatcher(varOp, std::move(innerArgs));
+  }
+
+  bool isVariadic() const override { return true; }
+
+  unsigned getNumArgs() const override { return 0; }
+
+  void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
+    kinds.push_back(ArgKind(ArgKind::Matcher));
+  }
+
+private:
+  const unsigned minCount;
+  const unsigned maxCount;
+  const VarOp varOp;
+  const StringRef matcherName;
+};
+
 // Helper function to check if argument count matches expected count
 inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount,
                           llvm::ArrayRef<ParserValue> args,
@@ -224,6 +277,14 @@ makeMatcherAutoMarshall(ReturnType (*matcherFunc)(ArgTypes...),
       reinterpret_cast<void (*)()>(matcherFunc), matcherName, argKinds);
 }
 
+// Variadic operator overload.
+template <unsigned MinCount, unsigned MaxCount>
+std::unique_ptr<MatcherDescriptor>
+makeMatcherAutoMarshall(VariadicOperatorMatcherFunc<MinCount, MaxCount> func,
+                        StringRef matcherName) {
+  return std::make_unique<VariadicOperatorMatcherDescriptor>(
+      MinCount, MaxCount, func.varOp, matcherName);
+}
 } // namespace mlir::query::matcher::internal
 
 #endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index 183b2514e109f..904103407611e 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -8,11 +8,11 @@
 //
 // Implements the base layer of the matcher framework.
 //
-// Matchers are methods that return a Matcher which provides a method one of the
-// following methods: match(Operation *op), match(Operation *op,
-// SetVector<Operation *> &matchedOps)
+// Matchers are methods that return a Matcher which provide a method
+// `match(...)` method. The method's parameters define the context of the match.
+// Support includes simple (unary) matchers as well as matcher combinators.
+// (anyOf, allOf, etc.)
 //
-// The matcher functions are defined in include/mlir/IR/Matchers.h.
 // This file contains the wrapper classes needed to construct matchers for
 // mlir-query.
 //
@@ -25,6 +25,13 @@
 #include "llvm/ADT/IntrusiveRefCntPtr.h"
 
 namespace mlir::query::matcher {
+class DynMatcher;
+namespace internal {
+
+bool allOfVariadicOperator(Operation *op, ArrayRef<DynMatcher> innerMatchers);
+bool anyOfVariadicOperator(Operation *op, ArrayRef<DynMatcher> innerMatchers);
+
+} // namespace internal
 
 // Defaults to false if T has no match() method with the signature:
 // match(Operation* op).
@@ -84,6 +91,26 @@ class MatcherFnImpl : public MatcherInterface {
   MatcherFn matcherFn;
 };
 
+// VariadicMatcher takes a vector of Matchers and returns true if any Matchers
+// match the given operation.
+using VariadicOperatorFunction = bool (*)(Operation *op,
+                                          ArrayRef<DynMatcher> innerMatchers);
+
+template <VariadicOperatorFunction Func>
+class VariadicMatcher : public MatcherInterface {
+public:
+  VariadicMatcher(std::vector<DynMatcher> matchers) : matchers(matchers) {}
+
+  bool match(Operation *op) override { return Func(op, matchers); }
+  // Fallback case
+  bool match(Operation *op, SetVector<Operation *> &matchedOps) override {
+    return false;
+  }
+
+private:
+  std::vector<DynMatcher> matchers;
+};
+
 // Matcher wraps a MatcherInterface implementation and provides match()
 // methods that redirect calls to the underlying implementation.
 class DynMatcher {
@@ -92,6 +119,31 @@ class DynMatcher {
   DynMatcher(MatcherInterface *implementation)
       : implementation(implementation) {}
 
+  // Construct from a variadic function.
+  enum VariadicOperator {
+    // Matches operations for which all provided matchers match.
+    AllOf,
+    // Matches operations for which at least one of the provided matchers
+    // matches.
+    AnyOf
+  };
+
+  static std::unique_ptr<DynMatcher>
+  constructVariadic(VariadicOperator Op,
+                    std::vector<DynMatcher> innerMatchers) {
+    switch (Op) {
+    case AllOf:
+      return std::make_unique<DynMatcher>(
+          new VariadicMatcher<internal::allOfVariadicOperator>(
+              std::move(innerMatchers)));
+    case AnyOf:
+      return std::make_unique<DynMatcher>(
+          new VariadicMatcher<internal::anyOfVariadicOperator>(
+              std::move(innerMatchers)));
+    }
+    llvm_unreachable("Invalid Op value.");
+  }
+
   template <typename MatcherFn>
   static std::unique_ptr<DynMatcher>
   constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) {
@@ -113,6 +165,53 @@ class DynMatcher {
   std::string functionName;
 };
 
+// VariadicOperatorMatcher related types.
+template <typename... Ps>
+class VariadicOperatorMatcher {
+public:
+  VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, Ps &&...params)
+      : varOp(varOp), params(std::forward<Ps>(params)...) {}
+
+  operator std::unique_ptr<DynMatcher>() const & {
+    return DynMatcher::constructVariadic(
+        varOp, getMatchers(std::index_sequence_for<Ps...>()));
+  }
+
+  operator std::unique_ptr<DynMatcher>() && {
+    return DynMatcher::constructVariadic(
+        varOp, std::move(*this).getMatchers(std::index_sequence_for<Ps...>()));
+  }
+
+private:
+  // Helper method to unpack the tuple into a vector.
+  template <std::size_t... Is>
+  std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) const & {
+    return {DynMatcher(std::get<Is>(params))...};
+  }
+
+  template <std::size_t... Is>
+  std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) && {
+    return {DynMatcher(std::get<Is>(std::move(params)))...};
+  }
+
+  const DynMatcher::VariadicOperator varOp;
+  std::tuple<Ps...> params;
+};
+
+// Overloaded function object to generate VariadicOperatorMatcher objects from
+// arbitrary matchers.
+template <unsigned MinCount, unsigned MaxCount>
+struct VariadicOperatorMatcherFunc {
+  DynMatcher::VariadicOperator varOp;
+
+  template <typename... Ms>
+  VariadicOperatorMatcher<Ms...> operator()(Ms &&...Ps) const {
+    static_assert(MinCount <= sizeof...(Ms) && sizeof...(Ms) <= MaxCount,
+                  "invalid number of parameters for variadic matcher");
+    return VariadicOperatorMatcher<Ms...>(varOp, std::forward<Ms>(Ps)...);
+  }
+};
+
 } // namespace mlir::query::matcher
 
 #endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
diff --git a/mlir/include/mlir/Query/Matcher/SliceMatchers.h b/mlir/include/mlir/Query/Matcher/SliceMatchers.h
index 5bb8251672eb7..fec46d2ff814d 100644
--- a/mlir/include/mlir/Query/Matcher/SliceMatchers.h
+++ b/mlir/include/mlir/Query/Matcher/SliceMatchers.h
@@ -6,7 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file provides matchers for MLIRQuery that peform slicing analysis
+// This file defines slicing-analysis matchers that extend and abstract the
+// core implementations from `SliceAnalysis.h`.
 //
 //===----------------------------------------------------------------------===//
 
@@ -15,9 +16,9 @@
 
 #include "mlir/Analysis/SliceAnalysis.h"
 
-/// A matcher encapsulating `getBackwardSlice` method from SliceAnalysis.h.
-/// Additionally, it limits the slice computation to a certain depth level using
-/// a custom filter.
+/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
+/// if `innerMatcher` matches. The traversal stops once the desired depth level
+/// is reached.
 ///
 /// Example: starting from node 9, assuming the matcher
 /// computes the slice for the first two depth levels:
@@ -116,6 +117,51 @@ bool BackwardSliceMatcher<Matcher>::matches(
                            : backwardSlice.size() >= 1;
 }
 
+/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
+/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
+template <typename BaseMatcher, typename Filter>
+class PredicateBackwardSliceMatcher {
+public:
+  PredicateBackwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
+                                bool inclusive, bool omitBlockArguments,
+                                bool omitUsesFromAbove)
+      : innerMatcher(std::move(innerMatcher)),
+        filterMatcher(std::move(filterMatcher)), inclusive(inclusive),
+        omitBlockArguments(omitBlockArguments),
+        omitUsesFromAbove(omitUsesFromAbove) {}
+
+  bool match(Operation *rootOp, SetVector<Operation *> &backwardSlice) {
+    backwardSlice.clear();
+    BackwardSliceOptions options;
+    options.inclusive = inclusive;
+    options.omitUsesFromAbove = omitUsesFromAbove;
+    options.omitBlockArguments = omitBlockArguments;
+    if (innerMatcher.match(rootOp)) {
+      options.filter = [&](Operation *subOp) {
+        return !filterMatcher.match(subOp);
+      };
+      getBackwardSlice(rootOp, &backwardSlice, options);
+      return options.inclusive ? backwardSlice.size() > 1
+                               : backwardSlice.size() >= 1;
+    }
+    return false;
+  }
+
+private:
+  BaseMatcher innerMatcher;
+  Filter filterMatcher;
+  bool inclusive;
+  bool omitBlockArguments;
+  bool omitUsesFromAbove;
+};
+
+const matcher::VariadicOperatorMatcherFunc<1,
+                                           std::numeric_limits<unsigned>::max()>
+    anyOf = {matcher::DynMatcher::AnyOf};
+const matcher::VariadicOperatorMatcherFunc<1,
+                                           std::numeric_limits<unsigned>::max()>
+    allOf = {matcher::DynMatcher::AllOf};
+
 /// Matches transitive defs of a top-level operation up to N levels.
 template <typename Matcher>
 inline BackwardSliceMatcher<Matcher>
@@ -127,7 +173,7 @@ m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
                                        omitUsesFromAbove);
 }
 
-/// Matches all transitive defs of a top-level operation up to N levels
+/// Matches all transitive defs of a top-level operation up to N levels.
 template <typename Matcher>
 inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
                                                          int64_t maxDepth) {
@@ -136,6 +182,18 @@ inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
                                        false, false);
 }
 
+/// Matches all transitive defs of a top-level operation and stops where
+/// `filterMatcher` rejects.
+template <typename BaseMatcher, typename Filter>
+inline PredicateBackwardSliceMatcher<BaseMatcher, Filter>
+m_GetDefinitionsByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
+                            bool inclusive, bool omitBlockArguments,
+                            bool omitUsesFromAbove) {
+  return PredicateBackwardSliceMatcher<BaseMatcher, Filter>(
+      std::move(innerMatcher), std::move(filterMatcher), inclusive,
+      omitBlockArguments, omitUsesFromAbove);
+}
+
 } // namespace mlir::query::matcher
 
 #endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H
diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h
index 98c0a18e25101..1a47576de1841 100644
--- a/mlir/include/mlir/Query/Matcher/VariantValue.h
+++ b/mlir/include/mlir/Query/Matcher/VariantValue.h
@@ -26,7 +26,12 @@ enum class ArgKind { Boolean, Matcher, Signed, String };
 // A variant matcher object to abstract simple and complex matchers into a
 // single object type.
 class VariantMatcher {
-  class MatcherOps;
+  class MatcherOps {
+  public:
+    std::optional<DynMatcher>
+    constructVariadicOperator(DynMatcher::VariadicOperator varOp,
+                              ArrayRef<VariantMatcher> innerMatchers) const;
+  };
 
   // Payload interface to be specialized by each matcher type. It follows a
   // similar interface as VariantMatcher itself.
@@ -43,6 +48,9 @@ class VariantMatcher {
 
   // Clones the provided matcher.
   static VariantMatcher SingleMatcher(DynMatcher matcher);
+  static VariantMatcher
+  VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp,
+                          ArrayRef<VariantMatcher> args);
 
   // Makes the matcher the "null" matcher.
   void reset();
@@ -61,6 +69,7 @@ class VariantMatcher {
       : value(std::move(value)) {}
 
   class SinglePayload;
+  class VariadicOpPayload;
 
   std::shared_ptr<const Payload> value;
 };
diff --git a/mlir/lib/Query/Matcher/CMakeLists.txt b/mlir/lib/Query/Matcher/CMakeLists.txt
index 629479bf7adc1..ba202762fdfbb 100644
--- a/mlir/lib/Query/Matcher/CMakeLists.txt
+++ b/mlir/lib/Query/Matcher/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_library(MLIRQueryMatcher
   MatchFinder.cpp
+  MatchersInternal.cpp
   Parser.cpp
   RegistryManager.cpp
   VariantValue.cpp
diff --git a/mlir/lib/Query/Matcher/MatchersInternal.cpp b/mlir/lib/Query/Matcher/MatchersInternal.cpp
new file mode 100644
index 0000000000000..e3593aa001f31
--- /dev/null
+++ b/mlir/lib/Query/Matcher/MatchersInternal.cpp
@@ -0,0 +1,30 @@
+//===--- MatchersInternal.cpp----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements the base layer of the matcher framework.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Query/Matcher/MatchersInternal.h"
+
+namespace mlir::query::matcher {
+
+namespace internal {
+
+bool allOfVariadicOperator(Operation *op, ArrayRef<DynMatcher> innerMatchers) {
+  return llvm::all_of(innerMatchers, [op](const DynMatcher &matcher) {
+    return matcher.match(op);
+  });
+}
+
+bool anyOfVariadicOperator(Operation *op, ArrayRef<DynMatcher> innerMatchers) {
+  return llvm::any_of(innerMatchers, [op](const DynMatcher &matcher) {
+    return matcher.match(op);
+  });
+}
+} // namespace internal
+} // namespace mlir::query::matcher
diff --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp
index 4b511c5f009e7..08b610453b11a 100644
--- a/mlir/lib/Query/Matcher/RegistryManager.cpp
+++ b/mlir/lib/Query/Matcher/RegistryManager.cpp
@@ -64,7 +64,7 @@ std::vector<ArgKind> RegistryManager::getAcceptedCompletionTypes(
     unsigned argNumber = ctxEntry.second;
     std::vector<ArgKind> nextTypeSet;
 
-    if (argNumber < ctor->getNumArgs())
+    if (ctor->isVariadic() || argNumber < ctor->getNumArgs())
       ctor->getArgKinds(argNumber, nextTypeSet);
 
     typeSet.insert(nextTypeSet.begin(), nextTypeSet.end());
@@ -83,7 +83,7 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
     const internal::MatcherDescriptor &matcher = *m.getValue();
     llvm::StringRef name = m.getKey();
 
-    unsigned numArgs = matcher.getNumArgs();
+    unsigned numArgs = matcher.isVariadic() ? 1 : matcher.getNumArgs();
     std::vector<std::vector<ArgKind>> argKinds(numArgs);
 
     for (const ArgKind &kind : acceptedTypes) {
@@ -115,6 +115,9 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
       }
     }
 
+    if (matcher.isVariadic())
+      os << ",...";
+
     os << ")";
     typedText += "(";
 
diff --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp
index 1cb2d48f9d56f..61316cfd0d489 100644
--- a/mlir/lib/Query/Matcher/VariantValue.cpp
+++ b/mlir/lib/Query/Matcher/VariantValue.cpp
@@ -27,12 +27,66 @@ class VariantMatcher::SinglePayload : public VariantMatcher::Payload {
   DynMatcher matcher;
 };
 
+class VariantMatcher::VariadicOpPayload : public VariantMatcher::Payload {
+public:
+  VariadicOpPayload(DynMatcher::VariadicOperator varOp,
+                    std::vector<VariantMatcher> args)
+      : varOp(varOp), args(std::move(args)) {}
+
+  std::optional<DynMatcher> getDynMatcher() const override {
+    std::vector<DynMatcher> dynMatchers;
+    for (auto variantMatcher : args) {
+      std::optional<DynMatcher> dynMatcher = variantMatcher.getDynMatcher();
+      if (dynMatcher)
+        dynMatchers.push_back(dynMatcher.value());
+    }
+    auto result = DynMatcher::constructVariadic(varOp, dynMatchers);
+    return *result;
+  }
+
+  std::string getTypeAsString() const override {
+    std::string inner;
+    for (size_t i = 0, e = args.size(); i != e; ++i) {
+      if (i != 0)
+        inner += "&";
+      inner += args[i].getTypeAsString();
+    }
+    return inner;
+  }
+
+private:
+  const DynMatcher::VariadicOperator varOp;
+  const std::vector<VariantMatcher> args;
+};
+
 VariantMatcher::VariantMatcher() = default;
 
 VariantMatcher VariantMatcher::SingleMatcher(DynMatcher matcher) {
   return VariantMatcher(std::make_shared<SinglePayload>(std::move(matcher)));
 }
 
+VariantMatcher
+VariantMatcher::VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp,
+                                        ArrayRef<VariantMatcher> args) {
+  return VariantMatcher(
+      std::make_shared<VariadicOpPayload>(varOp, std::move(args)));
+}
+
+std::optional<DynMatcher> VariantMatcher::MatcherOps::constructVariadicOperator(
+    DynMatcher::VariadicOperator varOp,
+    ArrayRef<VariantMatcher> innerMatchers) const {
+  std::vector<DynMatcher> dynMatchers;
+  for (const auto &innerMatcher : innerMatchers) {
+    if (!innerMatcher.value)
+      return std::nullopt;
+    std::optional<DynMatcher> inner = innerMatcher.value->getDynMatcher();
+    if (!inner)
+      return std::nullopt;
+    dynMatchers.push_back(*inner);
+  }
+  return *DynMatcher::constructVariadic(varOp, dynMatchers);
+}
+
 std::optional<DynMatcher> VariantMatcher::getDynMatcher() const {
   return value...
[truncated]

@chios202 chios202 changed the title Improve MLIR-Query by adding matcher combinators Improve mlir-query by adding matcher combinators May 25, 2025
@chios202 chios202 force-pushed the Improve-Mlir-Query-with-matcher-combinators branch from 4680123 to 02f731d Compare May 25, 2025 18:22
bool omitUsesFromAbove;
};

const matcher::VariadicOperatorMatcherFunc<1,
Copy link
Contributor Author

@chios202 chios202 May 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm thinking this file could be renamed since it no longer contains only slicing matchers?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only one that is not slicing related correct? Lets just move it as the others are still slicing related all and this one has no dependencies on analysis (putting adjacent to DynMatcher in MatchersInternal.h)

Copy link
Contributor Author

@chios202 chios202 May 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved them, right below VariadicOperatorMatcherFunc because they need to see its definition. I could relocate their declarations adjacent to DynMatcher and then keep their definitions at the end of the file

@chios202 chios202 force-pushed the Improve-Mlir-Query-with-matcher-combinators branch 4 times, most recently from 089aecf to 11792a6 Compare May 26, 2025 14:16
Copy link
Member

@jpienaar jpienaar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a test/example?

bool omitUsesFromAbove;
};

const matcher::VariadicOperatorMatcherFunc<1,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only one that is not slicing related correct? Lets just move it as the others are still slicing related all and this one has no dependencies on analysis (putting adjacent to DynMatcher in MatchersInternal.h)

//
//===----------------------------------------------------------------------===//
//
// Implements the base layer of the matcher framework.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy pasta?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, I removed the description


std::string getTypeAsString() const override {
std::string inner;
for (size_t i = 0, e = args.size(); i != e; ++i) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

llvm::interleave could help here. Is spaces needed around the & ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion. Some space around & would be nicer too

@chios202 chios202 changed the title Improve mlir-query by adding matcher combinators [mlir] Improve mlir-query by adding matcher combinators May 27, 2025
@chios202 chios202 requested a review from jpienaar May 27, 2025 19:17
@chios202 chios202 force-pushed the Improve-Mlir-Query-with-matcher-combinators branch 2 times, most recently from ace9ebd to fec7c92 Compare June 1, 2025 15:20
@chios202
Copy link
Contributor Author

chios202 commented Jun 2, 2025

@jpienaar could you provide feedback on the latest changes?

@chios202 chios202 force-pushed the Improve-Mlir-Query-with-matcher-combinators branch 5 times, most recently from b5267f3 to 326c2a9 Compare June 4, 2025 19:22
chios202 added 4 commits June 5, 2025 07:18
	Limit backward-slice with nested matching
	Add variadic operators
	Add test cases for variadic matchers
	Relocate variadic matchers
@chios202 chios202 force-pushed the Improve-Mlir-Query-with-matcher-combinators branch from 326c2a9 to 1522646 Compare June 5, 2025 07:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants