[mlir] Improve mlir-query tool by implementing getBackwardSlice matcher #115670

Merged

Contributor

@chios202 chios202 commented Nov 10, 2024

Improve mlir-query tool by implementing getBackwardSlice matcher

Note: the backwardSlice and forwardSlice algorithms are the same as the ones in mlir/lib/Analysis/SliceAnalysis.cpp
Example of the current matcher. The query was run against the file: mlir/test/mlir-query/complex-test.mlir

./mlir-query /home/dbudii/personal/llvm-project/mlir/test/mlir-query/complex-test.mlir -c "match getDefinitions(hasOpName(\"arith.addf\"),2)"

Match #1:

/home/dbudii/personal/llvm-project/mlir/test/mlir-query/complex-test.mlir:5:8:
  %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
       ^
/home/dbudii/personal/llvm-project/mlir/test/mlir-query/complex-test.mlir:7:10: note: "root" binds here
    %2 = arith.addf %in, %in : f32
         ^
Match #2:

/home/dbudii/personal/llvm-project/mlir/test/mlir-query/complex-test.mlir:10:16:
  %collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32>
               ^
/home/dbudii/personal/llvm-project/mlir/test/mlir-query/complex-test.mlir:13:11:
    %c2 = arith.constant 2 : index
          ^
/home/dbudii/personal/llvm-project/mlir/test/mlir-query/complex-test.mlir:14:18:
    %extracted = tensor.extract %collapsed[%c2] : tensor<25xf32>
                 ^
/home/dbudii/personal/llvm-project/mlir/test/mlir-query/complex-test.mlir:15:10: note: "root" binds here
    %2 = arith.addf %extracted, %extracted : f32
         ^
2 matches.

@chios202 chios202 changed the title MLIR-QUERY DefinitionsMatcher implementation & DAG [mlir] MLIR-QUERY DefinitionsMatcher implementation & DAG Nov 11, 2024
Member

@jpienaar jpienaar left a comment

Quick skim comments.

@dbudii dbudii force-pushed the DefiningOpsMatcher_DAG_initial_implementation branch 2 times, most recently from bafbd37 to 88a01fd Compare January 20, 2025 21:49
@chios202 chios202 marked this pull request as ready for review January 20, 2025 21:50
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Jan 20, 2025
@llvmbot
Member

llvmbot commented Jan 20, 2025

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Denzel-Brian Budii (chios202)

Changes

This pull request improves the mlir-query tool by implementing the getBackwardSlice and getForwardSlice matchers. It also adds SetQuery, which enables custom configuration for each query through options such as inclusive, omitUsesFromAbove, and omitBlockArguments.
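
The mechanism behind SetQuery can be sketched in standalone C++. This is a minimal illustration rather than the patch's code: `Session` and `SetCommand` are hypothetical stand-ins for QuerySession and SetQuery, and only the pointer-to-member idea is taken from the diff.

```cpp
#include <cassert>

// Hypothetical stand-in for QuerySession, reduced to the three boolean
// options named in the description (defaults copied from the diff).
struct Session {
  bool omitBlockArguments = false;
  bool omitUsesFromAbove = true;
  bool inclusive = true;
};

// Hypothetical stand-in for SetQuery<T>: it remembers which member of the
// session to change and the value to write there.
template <typename T>
struct SetCommand {
  SetCommand(T Session::*var, T value) : var(var), value(value) {}

  // Running the command writes the stored value through the member pointer.
  void run(Session &session) const { session.*var = value; }

  T Session::*var;
  T value;
};

// Example use: build a session in which the root op is excluded from slices.
Session makeExclusiveSession() {
  Session session;
  SetCommand<bool> command(&Session::inclusive, false);
  command.run(session);
  assert(!session.inclusive && "the command should have cleared the flag");
  return session;
}
```

Storing a pointer to member keeps one command type for every option of the same type, so adding a new boolean option would not need a new command class.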

Example of the current matcher. The query was run against the file: mlir/test/mlir-query/complex-test.mlir

mlir-query> match getDefinitions(hasOpName("arith.addf"),2)
Match #1:

/home/dbudii/personal/llvm-project-fork/mlir/test/mlir-query/complex-test.mlir:5:8: note: "root" binds here
  %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
       ^
/home/dbudii/personal/llvm-project-fork/mlir/test/mlir-query/complex-test.mlir:7:10: note: "root" binds here
    %2 = arith.addf %in, %in : f32
         ^
Match #2:

/home/dbudii/personal/llvm-project-fork/mlir/test/mlir-query/complex-test.mlir:10:16: note: "root" binds here
  %collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32>
               ^
/home/dbudii/personal/llvm-project-fork/mlir/test/mlir-query/complex-test.mlir:13:11: note: "root" binds here
    %c2 = arith.constant 2 : index
          ^
/home/dbudii/personal/llvm-project-fork/mlir/test/mlir-query/complex-test.mlir:14:18: note: "root" binds here
    %extracted = tensor.extract %collapsed[%c2] : tensor<25xf32>
                 ^
/home/dbudii/personal/llvm-project-fork/mlir/test/mlir-query/complex-test.mlir:15:10: note: "root" binds here
    %2 = arith.addf %extracted, %extracted : f32
         ^
mlir-query> 

Patch is 34.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/115670.diff

17 Files Affected:

  • (modified) mlir/include/mlir/IR/Matchers.h (+2-2)
  • (added) mlir/include/mlir/Query/Matcher/ExtraMatchers.h (+180)
  • (modified) mlir/include/mlir/Query/Matcher/Marshallers.h (+15)
  • (modified) mlir/include/mlir/Query/Matcher/MatchFinder.h (+39-11)
  • (modified) mlir/include/mlir/Query/Matcher/MatchersInternal.h (+49-11)
  • (modified) mlir/include/mlir/Query/Matcher/VariantValue.h (+11-1)
  • (modified) mlir/include/mlir/Query/Query.h (+48-1)
  • (modified) mlir/include/mlir/Query/QuerySession.h (+10-1)
  • (modified) mlir/lib/Query/Matcher/Parser.cpp (+48-6)
  • (modified) mlir/lib/Query/Matcher/RegistryManager.cpp (+2)
  • (modified) mlir/lib/Query/Matcher/VariantValue.cpp (+24)
  • (modified) mlir/lib/Query/Query.cpp (+26-20)
  • (modified) mlir/lib/Query/QueryParser.cpp (+76-1)
  • (modified) mlir/lib/Query/QueryParser.h (+1-1)
  • (added) mlir/test/mlir-query/complex-test.mlir (+22)
  • (modified) mlir/test/mlir-query/function-extraction.mlir (+2-2)
  • (modified) mlir/tools/mlir-query/mlir-query.cpp (+10-1)
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 1dce055db1b4a7..2204a68be26b10 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -59,7 +59,7 @@ struct NameOpMatcher {
   NameOpMatcher(StringRef name) : name(name) {}
   bool match(Operation *op) { return op->getName().getStringRef() == name; }
 
-  StringRef name;
+  std::string name;
 };
 
 /// The matcher that matches operations that have the specified attribute name.
@@ -67,7 +67,7 @@ struct AttrOpMatcher {
   AttrOpMatcher(StringRef attrName) : attrName(attrName) {}
   bool match(Operation *op) { return op->hasAttr(attrName); }
 
-  StringRef attrName;
+  std::string attrName;
 };
 
 /// The matcher that matches operations that have the `ConstantLike` trait, and
diff --git a/mlir/include/mlir/Query/Matcher/ExtraMatchers.h b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
new file mode 100644
index 00000000000000..57adc3241b0bef
--- /dev/null
+++ b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
@@ -0,0 +1,180 @@
+//===- ExtraMatchers.h - Various common matchers ---------------------*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides extra matchers that are very useful for mlir-query
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_EXTRAMATCHERS_H
+#define MLIR_IR_EXTRAMATCHERS_H
+
+#include "MatchFinder.h"
+#include "MatchersInternal.h"
+#include "mlir/IR/Region.h"
+#include "mlir/Query/Query.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+
+namespace query {
+
+namespace extramatcher {
+
+namespace detail {
+
+class BackwardSliceMatcher {
+public:
+  BackwardSliceMatcher(matcher::DynMatcher &&innerMatcher, unsigned hops)
+      : innerMatcher(std::move(innerMatcher)), hops(hops) {}
+
+private:
+  bool matches(Operation *op, SetVector<Operation *> &backwardSlice,
+               QueryOptions &options, unsigned tempHops) {
+
+    if (op->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
+      return false;
+    }
+
+    auto processValue = [&](Value value) {
+      if (tempHops == 0) {
+        return;
+      }
+      if (auto *definingOp = value.getDefiningOp()) {
+        if (backwardSlice.count(definingOp) == 0)
+          matches(definingOp, backwardSlice, options, tempHops - 1);
+      } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
+        if (options.omitBlockArguments)
+          return;
+        Block *block = blockArg.getOwner();
+
+        Operation *parentOp = block->getParentOp();
+
+        if (parentOp && backwardSlice.count(parentOp) == 0) {
+          assert(parentOp->getNumRegions() == 1 &&
+                 parentOp->getRegion(0).getBlocks().size() == 1);
+          matches(parentOp, backwardSlice, options, tempHops-1);
+        }
+      } else {
+        llvm_unreachable("No definingOp and not a block argument.");
+      }
+    };
+
+    if (!options.omitUsesFromAbove) {
+      llvm::for_each(op->getRegions(), [&](Region &region) {
+        SmallPtrSet<Region *, 4> descendents;
+        region.walk(
+            [&](Region *childRegion) { descendents.insert(childRegion); });
+        region.walk([&](Operation *op) {
+          for (OpOperand &operand : op->getOpOperands()) {
+            if (!descendents.contains(operand.get().getParentRegion()))
+              processValue(operand.get());
+          }
+        });
+      });
+    }
+
+    llvm::for_each(op->getOperands(), processValue);
+    backwardSlice.insert(op);
+    return true;
+  }
+
+public:
+  bool match(Operation *op, SetVector<Operation *> &backwardSlice,
+             QueryOptions &options) {
+    if (innerMatcher.match(op) && matches(op, backwardSlice, options, hops)) {
+      if (!options.inclusive) {
+        backwardSlice.remove(op);
+      }
+      return true;
+    }
+    return false;
+  }
+
+private:
+  matcher::DynMatcher innerMatcher;
+  unsigned hops;
+};
+
+class ForwardSliceMatcher {
+public:
+  ForwardSliceMatcher(matcher::DynMatcher &&innerMatcher, unsigned hops)
+      : innerMatcher(std::move(innerMatcher)), hops(hops) {}
+
+private:
+  bool matches(Operation *op, SetVector<Operation *> &forwardSlice,
+               QueryOptions &options, unsigned tempHops) {
+
+    if (tempHops == 0) {
+      forwardSlice.insert(op);
+      return true;
+    }
+
+    for (Region &region : op->getRegions())
+      for (Block &block : region)
+        for (Operation &blockOp : block)
+          if (forwardSlice.count(&blockOp) == 0)
+            matches(&blockOp, forwardSlice, options, tempHops - 1);
+    for (Value result : op->getResults()) {
+      for (Operation *userOp : result.getUsers())
+        if (forwardSlice.count(userOp) == 0)
+          matches(userOp, forwardSlice, options, tempHops - 1);
+    }
+
+    forwardSlice.insert(op);
+    return true;
+  }
+
+public:
+  bool match(Operation *op, SetVector<Operation *> &forwardSlice,
+             QueryOptions &options) {
+    if (innerMatcher.match(op) && matches(op, forwardSlice, options, hops)) {
+      if (!options.inclusive) {
+        forwardSlice.remove(op);
+      }
+      SmallVector<Operation *, 0> v(forwardSlice.takeVector());
+      forwardSlice.insert(v.rbegin(), v.rend());
+      return true;
+    }
+    return false;
+  }
+
+private:
+  matcher::DynMatcher innerMatcher;
+  unsigned hops;
+};
+
+} // namespace detail
+
+inline detail::BackwardSliceMatcher
+definedBy(mlir::query::matcher::DynMatcher innerMatcher) {
+  return detail::BackwardSliceMatcher(std::move(innerMatcher), 1);
+}
+
+inline detail::BackwardSliceMatcher
+getDefinitions(mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
+  return detail::BackwardSliceMatcher(std::move(innerMatcher), hops);
+}
+
+inline detail::ForwardSliceMatcher
+usedBy(mlir::query::matcher::DynMatcher innerMatcher) {
+  return detail::ForwardSliceMatcher(std::move(innerMatcher), 1);
+}
+
+inline detail::ForwardSliceMatcher
+getUses(mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
+  return detail::ForwardSliceMatcher(std::move(innerMatcher), hops);
+}
+
+} // namespace extramatcher
+
+} // namespace query
+
+} // namespace mlir
+
+#endif // MLIR_IR_EXTRAMATCHERS_H
diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h
index 6ed35ac0ddccc7..c775dbc5c86da0 100644
--- a/mlir/include/mlir/Query/Matcher/Marshallers.h
+++ b/mlir/include/mlir/Query/Matcher/Marshallers.h
@@ -50,6 +50,21 @@ struct ArgTypeTraits<llvm::StringRef> {
   }
 };
 
+template <>
+struct ArgTypeTraits<unsigned> {
+  static bool hasCorrectType(const VariantValue &value) {
+    return value.isUnsigned();
+  }
+
+  static unsigned get(const VariantValue &value) { return value.getUnsigned(); }
+
+  static ArgKind getKind() { return ArgKind::Unsigned; }
+
+  static std::optional<std::string> getBestGuess(const VariantValue &) {
+    return std::nullopt;
+  }
+};
+
 template <>
 struct ArgTypeTraits<DynMatcher> {
 
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index b008a21f53ae2a..2175db86a91bdf 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 //
 // This file contains the MatchFinder class, which is used to find operations
-// that match a given matcher.
+// that match a given matcher and print them.
 //
 //===----------------------------------------------------------------------===//
 
@@ -15,24 +15,52 @@
 #define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
 
 #include "MatchersInternal.h"
+#include "mlir/Query/QuerySession.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
 
 namespace mlir::query::matcher {
 
-// MatchFinder is used to find all operations that match a given matcher.
 class MatchFinder {
+private:
+  static void printMatch(llvm::raw_ostream &os, QuerySession &qs,
+                         mlir::Operation *op, const std::string &binding) {
+    auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
+    auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
+        qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
+    qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
+                                       "\"" + binding + "\" binds here");
+  };
+
 public:
-  // Returns all operations that match the given matcher.
-  static std::vector<Operation *> getMatches(Operation *root,
-                                             DynMatcher matcher) {
-    std::vector<Operation *> matches;
+  static std::vector<Operation *>
+  getMatches(Operation *root, QueryOptions &options, DynMatcher matcher,
+             llvm::raw_ostream &os, QuerySession &qs) {
+    unsigned matchCount = 0;
+    std::vector<Operation *> matchedOps;
+    SetVector<Operation *> tempStorage;
 
-    // Simple match finding with walk.
     root->walk([&](Operation *subOp) {
-      if (matcher.match(subOp))
-        matches.push_back(subOp);
-    });
+      if (matcher.match(subOp)) {
+        matchedOps.push_back(subOp);
+        os << "Match #" << ++matchCount << ":\n\n";
+        printMatch(os, qs, subOp, "root");
+      } else {
+        SmallVector<Operation *> printingOps;
 
-    return matches;
+        if (matcher.match(subOp, tempStorage, options)) {
+          os << "Match #" << ++matchCount << ":\n\n";
+          SmallVector<Operation *> printingOps(tempStorage.takeVector());
+          for (auto op : printingOps) {
+            printMatch(os, qs, op, "root");
+            matchedOps.push_back(op);
+          }
+          printingOps.clear();
+        }
+      }
+    });
+    return matchedOps;
   }
 };
 
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index 117f7d4edef9e3..b532b47be7d051 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -1,4 +1,3 @@
-//===- MatchersInternal.h - Structural query framework ----------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -10,28 +9,53 @@
 //
 // Matchers are methods that return a Matcher which provides a method
 // match(Operation *op)
+// match(Operation *op, SetVector<Operation *> &matchedOps, QueryOptions
+// &options)
 //
 // The matcher functions are defined in include/mlir/IR/Matchers.h.
 // This file contains the wrapper classes needed to construct matchers for
 // mlir-query.
 //
 //===----------------------------------------------------------------------===//
-
 #ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
 #define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
 
 #include "mlir/IR/Matchers.h"
 #include "llvm/ADT/IntrusiveRefCntPtr.h"
 
+namespace mlir {
+namespace query {
+struct QueryOptions;
+}
+} // namespace mlir
+
 namespace mlir::query::matcher {
+template <typename T, typename = void>
+struct has_simple_match : std::false_type {};
+
+template <typename T>
+struct has_simple_match<T, std::void_t<decltype(std::declval<T>().match(
+                               std::declval<Operation *>()))>>
+    : std::true_type {};
+
+template <typename T, typename = void>
+struct has_bound_match : std::false_type {};
+
+template <typename T>
+struct has_bound_match<T, std::void_t<decltype(std::declval<T>().match(
+                              std::declval<Operation *>(),
+                              std::declval<SetVector<Operation *> &>(),
+                              std::declval<QueryOptions &>()))>>
+    : std::true_type {};
 
 // Generic interface for matchers on an MLIR operation.
 class MatcherInterface
     : public llvm::ThreadSafeRefCountedBase<MatcherInterface> {
 public:
   virtual ~MatcherInterface() = default;
-
   virtual bool match(Operation *op) = 0;
+  virtual bool match(Operation *op, SetVector<Operation *> &matchedOps,
+                     QueryOptions &options) = 0;
 };
 
 // MatcherFnImpl takes a matcher function object and implements
@@ -40,14 +64,26 @@ template <typename MatcherFn>
 class MatcherFnImpl : public MatcherInterface {
 public:
   MatcherFnImpl(MatcherFn &matcherFn) : matcherFn(matcherFn) {}
-  bool match(Operation *op) override { return matcherFn.match(op); }
+
+  bool match(Operation *op) override {
+    if constexpr (has_simple_match<MatcherFn>::value)
+      return matcherFn.match(op);
+    return false;
+  }
+
+  bool match(Operation *op, SetVector<Operation *> &matchedOps,
+             QueryOptions &options) override {
+    if constexpr (has_bound_match<MatcherFn>::value)
+      return matcherFn.match(op, matchedOps, options);
+    return false;
+  }
 
 private:
   MatcherFn matcherFn;
 };
 
-// Matcher wraps a MatcherInterface implementation and provides a match()
-// method that redirects calls to the underlying implementation.
+// Matcher wraps a MatcherInterface implementation and provides match()
+// methods that redirect calls to the underlying implementation.
 class DynMatcher {
 public:
   // Takes ownership of the provided implementation pointer.
@@ -62,12 +98,14 @@ class DynMatcher {
   }
 
   bool match(Operation *op) const { return implementation->match(op); }
+  bool match(Operation *op, SetVector<Operation *> &matchedOps,
+             QueryOptions &options) const {
+    return implementation->match(op, matchedOps, options);
+  }
 
-  void setFunctionName(StringRef name) { functionName = name.str(); };
-
-  bool hasFunctionName() const { return !functionName.empty(); };
-
-  StringRef getFunctionName() const { return functionName; };
+  void setFunctionName(StringRef name) { functionName = name.str(); }
+  bool hasFunctionName() const { return !functionName.empty(); }
+  StringRef getFunctionName() const { return functionName; }
 
 private:
   llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h
index 449f8b3a01e021..6b57119df7a9bf 100644
--- a/mlir/include/mlir/Query/Matcher/VariantValue.h
+++ b/mlir/include/mlir/Query/Matcher/VariantValue.h
@@ -21,7 +21,7 @@
 namespace mlir::query::matcher {
 
 // All types that VariantValue can contain.
-enum class ArgKind { Matcher, String };
+enum class ArgKind { Matcher, String, Unsigned };
 
 // A variant matcher object to abstract simple and complex matchers into a
 // single object type.
@@ -81,6 +81,7 @@ class VariantValue {
   // Specific constructors for each supported type.
   VariantValue(const llvm::StringRef string);
   VariantValue(const VariantMatcher &matcher);
+  VariantValue(unsigned Unsigned);
 
   // String value functions.
   bool isString() const;
@@ -92,8 +93,15 @@ class VariantValue {
   const VariantMatcher &getMatcher() const;
   void setMatcher(const VariantMatcher &matcher);
 
+  // Unsigned value functions.
+  bool isUnsigned() const;
+  unsigned getUnsigned() const;
+  void setUnsigned(unsigned Unsigned);
+
   // String representation of the type of the value.
   std::string getTypeAsString() const;
+  explicit operator bool() const { return hasValue(); }
+  bool hasValue() const { return type != ValueType::Nothing; }
 
 private:
   void reset();
@@ -103,12 +111,14 @@ class VariantValue {
     Nothing,
     String,
     Matcher,
+    Unsigned,
   };
 
   // All supported value types.
   union AllValues {
     llvm::StringRef *String;
     VariantMatcher *Matcher;
+    unsigned Unsigned;
   };
 
   ValueType type;
diff --git a/mlir/include/mlir/Query/Query.h b/mlir/include/mlir/Query/Query.h
index 18f2172c9510a3..89d48773d2c3e6 100644
--- a/mlir/include/mlir/Query/Query.h
+++ b/mlir/include/mlir/Query/Query.h
@@ -17,7 +17,13 @@
 
 namespace mlir::query {
 
-enum class QueryKind { Invalid, NoOp, Help, Match, Quit };
+struct QueryOptions {
+  bool omitBlockArguments = false;
+  bool omitUsesFromAbove = true;
+  bool inclusive = true;
+};
+
+enum class QueryKind { Invalid, NoOp, Help, Match, Quit, Let, SetBool };
 
 class QuerySession;
 
@@ -103,6 +109,47 @@ struct MatchQuery : Query {
   }
 };
 
+struct LetQuery : Query {
+  LetQuery(llvm::StringRef name, const matcher::VariantValue &value)
+      : Query(QueryKind::Let), name(name), value(value) {}
+
+  llvm::LogicalResult run(llvm::raw_ostream &os,
+                          QuerySession &qs) const override;
+
+  std::string name;
+  matcher::VariantValue value;
+
+  static bool classof(const Query *query) {
+    return query->kind == QueryKind::Let;
+  }
+};
+
+template <typename T>
+struct SetQueryKind {};
+
+template <>
+struct SetQueryKind<bool> {
+  static const QueryKind value = QueryKind::SetBool;
+};
+template <typename T>
+struct SetQuery : Query {
+  SetQuery(T QuerySession::*var, T value)
+      : Query(SetQueryKind<T>::value), var(var), value(value) {}
+
+  llvm::LogicalResult run(llvm::raw_ostream &os,
+                          QuerySession &qs) const override {
+    qs.*var = value;
+    return mlir::success();
+  }
+
+  static bool classof(const Query *query) {
+    return query->kind == SetQueryKind<T>::value;
+  }
+
+  T QuerySession::*var;
+  T value;
+};
+
 } // namespace mlir::query
 
 #endif
diff --git a/mlir/include/mlir/Query/QuerySession.h b/mlir/include/mlir/Query/QuerySession.h
index fe552d750fc771..495358e8f36f94 100644
--- a/mlir/include/mlir/Query/QuerySession.h
+++ b/mlir/include/mlir/Query/QuerySession.h
@@ -9,14 +9,18 @@
 #ifndef MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H
 #define MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H
 
+#include "Matcher/VariantValue.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Query/Matcher/Registry.h"
 #include "llvm/ADT/StringMap.h"
 #include "llvm/Support/SourceMgr.h"
 
+namespace mlir::query::matcher {
+class Registry;
+}
+
 namespace mlir::query {
 
-class Registry;
 // Represents the state for a particular mlir-query session.
 class QuerySession {
 public:
@@ -33,6 +37,11 @@ class QuerySession {
   llvm::StringMap<matcher::VariantValue> namedValues;
   bool terminate = false;
 
+public:
+  bool omitBlockArguments = false;
+  bool omitUsesFromAbove = true;
+  bool inclusive = true;
+
 private:
   Operation *rootOp;
   llvm::SourceMgr &sourceMgr;
diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
index 3609e24f9939f7..726f1188d7e4c8 100644
--- a/mlir/lib/Query/Matcher/Parser.cpp
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -135,6 +135,18 @@ class Parser::CodeTokenizer {
     case '\'':
       consumeStringLiteral(&result);
       break;
+    case '0':
+    case '1':
+    case '2':
+    case '3':
+    case '4':
+    case '5':
+    case '6':
+    case '7':
+    case '8':
+    case '9':
+      consumeNumberLiteral(&result);
+      break;
     default:
       parseIdentifierOrInvalid(&result);
       break;
@@ -144,6 +156,30 @@ class Parser::CodeTokenizer {
     return result;
   }
 
+  void consumeNumberLiteral(TokenInfo *result) {
+    unsigned length = 1;
+    if (code.size() > 1) {
+      // Consume the 'x' or 'b' radix modifier, if present.
+      switch (tolower(code[1])) {
+      case 'x':
+      case 'b':
+        length = 2;
+      }
+    }
+    while (length < code.size() && isdigit(code[length]))
+      ++length;
+
+    result->text = code.take_front(length);
+    code = code.drop_front(length);
+
+    unsigned value;
+    if (!result->text.getAsInteger(0, value)) {
+      result->kind = TokenKind::Literal;
+      result->value = static_cast<unsigned>(value);
+      return;
+    }
+  }
+
   // Consume a string literal, handle escape sequences and missing closing
   // quote.
   void consumeStringLiteral(TokenInfo *result) {
@@ -257,13 +293,19 @@ bool Parser::parseIdentifierPrefixImpl(VariantValue *value) {
 
   if (tokenizer->nextTokenKind() != TokenKind::OpenParen) {
     // Parse as a named value.
-    auto namedValue =
-        namedValues ? namedValues->lookup(nameToken.text) : VariantValue();
+    if (auto namedValue = namedValues ? namedValues->lookup(nameToken.te...
[truncated]

@dbudii dbudii force-pushed the DefiningOpsMatcher_DAG_initial_implementation branch from 88a01fd to 0a12247 Compare January 25, 2025 13:45
@chios202 chios202 changed the title [mlir] MLIR-QUERY DefinitionsMatcher implementation & DAG [mlir] MLIR-QUERY slice-matchers implementation Jan 25, 2025
@dbudii dbudii force-pushed the DefiningOpsMatcher_DAG_initial_implementation branch from 0a12247 to 30aff7a Compare January 25, 2025 13:53
@jpienaar
Member

jpienaar commented Feb 3, 2025

I see the presubmit is failing on the newly added tests (mlir-query/complex-test.mlir failing)

return false;
}

auto processValue = [&](Value value) {
Member

Could mlir/include/mlir/Analysis/SliceAnalysis.h's methods be used here?

Contributor Author

The tempHops parameter, which specifies how far into the graph we search (how many levels), prevents us from reusing the code in mlir/include/mlir/Analysis/SliceAnalysis.h

Contributor Author

I think maxDepth would be more fitting for this parameter.
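
To make the role of this depth parameter concrete, here is a small standalone sketch of a hop-limited backward traversal. It assumes a toy graph instead of MLIR: `ToyOp`, `collectDefinitions` and `exampleSlice` are hypothetical names, and the graph only mimics the shape of the second arith.addf in complex-test.mlir.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy operation: a name plus the operations defining its operands.
struct ToyOp {
  std::string name;
  std::vector<ToyOp *> operands;
};

// Returns true if `op` is already part of `slice`.
static bool contains(const std::vector<ToyOp *> &slice, const ToyOp *op) {
  return std::find(slice.begin(), slice.end(), op) != slice.end();
}

// Adds to `slice` every definition reachable from `op` in at most `hopsLeft`
// steps, then `op` itself, so definitions are listed before their users.
void collectDefinitions(ToyOp *op, unsigned hopsLeft,
                        std::vector<ToyOp *> &slice) {
  assert(op && "expected a valid operation");
  if (hopsLeft > 0)
    for (ToyOp *def : op->operands)
      if (!contains(slice, def))
        collectDefinitions(def, hopsLeft - 1, slice);
  if (!contains(slice, op))
    slice.push_back(op);
}

// Builds generic -> collapse_shape -> extract -> addf (plus a constant index
// feeding extract) and returns the names in the slice rooted at addf.
std::vector<std::string> exampleSlice(unsigned hops) {
  ToyOp generic{"linalg.generic", {}};
  ToyOp collapse{"tensor.collapse_shape", {&generic}};
  ToyOp constant{"arith.constant", {}};
  ToyOp extract{"tensor.extract", {&collapse, &constant}};
  ToyOp addf{"arith.addf", {&extract, &extract}};

  std::vector<ToyOp *> slice;
  collectDefinitions(&addf, hops, slice);

  std::vector<std::string> names;
  for (ToyOp *op : slice)
    names.push_back(op->name);
  return names;
}
```

With two hops the toy linalg.generic is left out because it sits three definitions away from the root; with three hops it is included.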

Member

Looking at it, it may be possible (but less efficient) by keeping track of each operation's depth in a map and then having the filter return whether the minimum depth of any of the operation's arguments is greater than or equal to maxDepth.
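
A rough standalone sketch of this filter-plus-depth-map idea follows. It is a variation under stated assumptions, not MLIR's API: `ToyOp`, `backwardSlice` and `boundedBackwardSlice` are hypothetical, the toy traversal only imitates a slice routine that skips operations rejected by a filter, and the recorded depth is the smallest one seen so far in traversal order.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical toy operation: a name plus the operations defining its operands.
struct ToyOp {
  std::string name;
  std::vector<ToyOp *> operands;
};

using Filter = std::function<bool(ToyOp *)>;

// Unbounded backward slice: an operation rejected by `filter` is neither
// added to the slice nor traversed further.
void backwardSlice(ToyOp *op, const Filter &filter,
                   std::vector<ToyOp *> &slice) {
  assert(op && "expected a valid operation");
  if (filter && !filter(op))
    return;
  for (ToyOp *def : op->operands)
    if (std::find(slice.begin(), slice.end(), def) == slice.end())
      backwardSlice(def, filter, slice);
  if (std::find(slice.begin(), slice.end(), op) == slice.end())
    slice.push_back(op);
}

// Bounds the slice by recording depths in a map and rejecting, inside the
// filter, every operation whose recorded depth exceeds `maxDepth`.
std::vector<ToyOp *> boundedBackwardSlice(ToyOp *root, unsigned maxDepth) {
  std::unordered_map<ToyOp *, unsigned> depth;
  depth[root] = 0;

  Filter filter = [&](ToyOp *op) {
    // Every operation reaching the filter had its depth recorded by a user.
    unsigned current = depth.at(op);
    if (current > maxDepth)
      return false;
    for (ToyOp *def : op->operands) {
      auto it = depth.find(def);
      if (it == depth.end())
        depth[def] = current + 1;
      else
        it->second = std::min(it->second, current + 1);
    }
    return true;
  };

  std::vector<ToyOp *> slice;
  backwardSlice(root, filter, slice);
  return slice;
}
```

The traversal itself stays unchanged, which is the point of the suggestion: only the filter knows about the depth limit, at the cost of one map lookup per visited operand.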

@dbudii dbudii force-pushed the DefiningOpsMatcher_DAG_initial_implementation branch from 62593f0 to 7a9fbd6 Compare February 23, 2025 15:08
SmallVector<Operation *> printingOps(tempStorage.takeVector());
for (auto op : printingOps) {
if (printMatchingOps) {
printMatch(os, qs, op, "root");
Contributor Author

This should use the overloaded print method with no argument for the binding. I accidentally reverted this back to using the old method, but will update it accordingly with the next changes.

Member

SG

//
// getMatches walks the IR and prints operations as soon as it matches them
// if a matcher is to be further extracted into the function, then it does not
// print operations
Member

Is printing needed as a match is no longer just an Operation? How about returning a vector of "Match" which consists of the root operation match + environment? If we could decouple the matching from the printing of the matching, this would enable reuse in more places.

Contributor Author

Would we want to return a vector of vectors, each one containing the rootOp plus its matching environment, for printing? For reusability we could then return a vector containing sequences of rootOp + matching environment.

Contributor Author

It now follows a similar design to clang-query's MatchFinder class
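
The separation discussed in this thread can be sketched in standalone C++ as below. This is only an illustration under assumptions: `ToyOp`, `ToyMatch`, `findMatches` and `printMatches` are hypothetical names, the "environment" is simplified to the root's direct definitions, and none of it is the MatchFinder code from the patch.

```cpp
#include <cassert>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical toy operation: a name plus the operations defining its operands.
struct ToyOp {
  std::string name;
  std::vector<const ToyOp *> operands;
};

// One match: the root operation plus the other operations bound with it.
struct ToyMatch {
  const ToyOp *root;
  std::vector<const ToyOp *> environment;
};

// Matching step: no output, just data that other callers can reuse.
std::vector<ToyMatch> findMatches(const std::vector<const ToyOp *> &ops,
                                  const std::string &rootName) {
  std::vector<ToyMatch> matches;
  for (const ToyOp *op : ops)
    if (op->name == rootName)
      matches.push_back(ToyMatch{op, op->operands});
  return matches;
}

// Printing step: consumes the matches and knows nothing about matching.
void printMatches(std::ostream &os, const std::vector<ToyMatch> &matches) {
  unsigned count = 0;
  for (const ToyMatch &match : matches) {
    os << "Match #" << ++count << ":\n";
    for (const ToyOp *op : match.environment)
      os << "  " << op->name << "\n";
    os << "  " << match.root->name << " (root)\n";
  }
  os << matches.size() << " matches.\n";
}
```

Because `findMatches` returns plain data, a caller that wants to extract the matched operations into a function, or count them, can skip `printMatches` entirely.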

@chios202
Contributor Author

chios202 commented Mar 4, 2025

Edit: suitable solution:

backwardSlice.clear();
  llvm::DenseMap<Operation *, int64_t> opDepths;
  opDepths[rootOp] = 0;

  query::QueryOptions depthOptions = options;
  depthOptions.filter = [&](Operation *op) {
    if (opDepths[op] > maxDepth)
      return false;
    for (auto operand : op->getOperands()) {
      if (auto definingOp = operand.getDefiningOp()) {
        auto it = opDepths.find(definingOp);
        if (it == opDepths.end()) {
          opDepths[definingOp] = opDepths[op] + 1;
          if (opDepths[op] > maxDepth)
            return false;
        }
      } else {
       auto blockArgument = cast<BlockArgument>(operand));
        Operation *parentOp = blockArgument.getOwner()->getParentOp();
        auto it = opDepths.find(parentOp);
        if (it == opDepths.end()) {
          opDepths[parentOp] = opDepths[op] + 1;
          if (opDepths[op] > maxDepth)
            return false;
        }
      }
    }
    return true;
  };
  getBackwardSlice(rootOp, &backwardSlice, depthOptions);
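The depth-map filter above can be tried outside MLIR with a small, self-contained sketch. Everything here is an assumption made for illustration: `ToyOp`, `backwardSliceImpl`, `depthLimitedSlice`, and `exampleSlice` are hypothetical names, the traversal is a plain post-order walk rather than MLIR's `getBackwardSlice`, block arguments are not modelled, and the root is included in the result. Unlike the snippet above, the filter records the minimum depth at which each defining op is reached.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for mlir::Operation: a name plus the ops that
// define its operands.
struct ToyOp {
  std::string name;
  std::vector<const ToyOp *> operands;
};

using Filter = std::function<bool(const ToyOp *)>;

// Generic backward slice: post-order walk over defining ops, pruned by filter.
static void backwardSliceImpl(const ToyOp *op, const Filter &filter,
                              std::unordered_set<const ToyOp *> &visited,
                              std::vector<const ToyOp *> &slice) {
  if (visited.count(op) || !filter(op))
    return;
  visited.insert(op);
  for (const ToyOp *def : op->operands)
    backwardSliceImpl(def, filter, visited, slice);
  slice.push_back(op); // definitions are emitted before their users
}

// Depth-limited slice built from the generic walk plus a depth-tracking filter.
std::vector<std::string> depthLimitedSlice(const ToyOp *root,
                                           int64_t maxDepth) {
  assert(root && "root must be non-null");
  std::unordered_map<const ToyOp *, int64_t> depth{{root, 0}};
  Filter filter = [&](const ToyOp *op) {
    auto it = depth.find(op);
    if (it == depth.end())
      return false; // never reached within maxDepth
    int64_t next = it->second + 1;
    if (next <= maxDepth) {
      for (const ToyOp *def : op->operands) {
        auto [pos, inserted] = depth.emplace(def, next);
        if (!inserted && next < pos->second)
          pos->second = next; // keep the minimum depth seen so far
      }
    }
    return true;
  };
  std::unordered_set<const ToyOp *> visited;
  std::vector<const ToyOp *> slice;
  backwardSliceImpl(root, filter, visited, slice);
  std::vector<std::string> names;
  for (const ToyOp *op : slice)
    names.push_back(op->name);
  return names;
}

// Toy graph loosely mirroring the ops feeding the second arith.addf in
// complex-test.mlir.
std::vector<std::string> exampleSlice(int64_t maxDepth) {
  ToyOp generic{"linalg.generic", {}};
  ToyOp collapse{"tensor.collapse_shape", {&generic}};
  ToyOp constant{"arith.constant", {}};
  ToyOp extract{"tensor.extract", {&collapse, &constant}};
  ToyOp addf{"arith.addf", {&extract, &extract}};
  return depthLimitedSlice(&addf, maxDepth);
}
```

In this toy graph, `exampleSlice(2)` stops at `tensor.collapse_shape` and never reaches `linalg.generic`, which sits at depth 3 from the root.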

@dbudii dbudii force-pushed the DefiningOpsMatcher_DAG_initial_implementation branch from 7a9fbd6 to e6bc9b3 Compare March 29, 2025 15:19
@chios202
Contributor Author

@jpienaar could you take another look? mlir-query's MatchFinder now follows a design similar to clang-query's MatchFinder class, and we now reuse the existing algorithm logic with a custom filter.

qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
"\"" + binding + "\" binds here");
}

// TODO: Extract into a helper function that can be reused outside query
// context.
static Operation *extractFunction(std::vector<Operation *> &ops,
Contributor Author

Could this method be a candidate for the MatchFinder class?

Member

Yes, or even higher, such as a utility function in Analysis or Transform or so, as this is rather common.

Contributor Author

I could move it to MatchFinder. I could also create an issue for this and assign it to myself. I’d need to investigate where in Analysis or Transform would be the most suitable place.

@dbudii dbudii force-pushed the DefiningOpsMatcher_DAG_initial_implementation branch 2 times, most recently from c51a7b9 to 95bb64e Compare April 1, 2025 16:46
@chios202 chios202 requested a review from jpienaar April 22, 2025 12:50
@dbudii dbudii force-pushed the DefiningOpsMatcher_DAG_initial_implementation branch 3 times, most recently from 524318d to cad295c Compare April 22, 2025 14:31
@chios202 chios202 force-pushed the DefiningOpsMatcher_DAG_initial_implementation branch from cad295c to e07e1fe Compare April 22, 2025 16:33
@chios202
Contributor Author

I'm not sure why complex-test.mlir is failing in CI for Linux. Locally I use WSL, where it seems to pass. Also, I did not modify the printing logic.

@chios202 chios202 force-pushed the DefiningOpsMatcher_DAG_initial_implementation branch 5 times, most recently from 69472f0 to 5f940da Compare April 26, 2025 10:15
@chios202 chios202 force-pushed the DefiningOpsMatcher_DAG_initial_implementation branch 4 times, most recently from 562b1a1 to 9a16aed Compare May 4, 2025 12:20
@chios202
Contributor Author

chios202 commented May 4, 2025

@jpienaar could you provide feedback on the latest changes?

@chios202 chios202 force-pushed the DefiningOpsMatcher_DAG_initial_implementation branch 2 times, most recently from 654ac76 to d637929 Compare May 10, 2025 13:56
// RUN: mlir-query %s -c "m getAllDefinitions(hasOpName(\"arith.addf\"),2)" | FileCheck %s

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) {
Contributor Author

I've copied this test from mlir/test/IR/slice.mlir. I could at least rename it to "slicing-query" and/or change the contents to avoid duplication?

@chios202 chios202 force-pushed the DefiningOpsMatcher_DAG_initial_implementation branch from d637929 to cfabf87 Compare May 12, 2025 18:26
Member

@jpienaar jpienaar left a comment

I think this looks good overall. Thanks. Let's do the one rename you were also thinking about, and then I can approve. Do you need help landing it?

@@ -0,0 +1,85 @@
//===- ExtraMatchers.h - Various common matchers --------------------------===//
//
Member

How about just SliceMatchers? (says what it is and then we don't need to think about what is and what isn't extra)

chios202 added 3 commits May 13, 2025 08:32
Relocate backwardSlice matcher to Query specific headers
Remove unnecessary code
Make BackwardSlice matcher more generic
Capture values in tests
@chios202 chios202 force-pushed the DefiningOpsMatcher_DAG_initial_implementation branch 2 times, most recently from af3bb15 to 87e2e44 Compare May 13, 2025 09:02
@jpienaar jpienaar merged commit 9b63bdd into llvm:main May 13, 2025
7 checks passed

@chios202 Congratulations on having your first Pull Request (PR) merged into the LLVM Project!

Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR.

Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues.

How to do this, and the rest of the post-merge process, is covered in detail here.

If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again.

If you don't get any reports, no action is required from you. Your changes are working as expected, well done!

@chios202 chios202 changed the title [mlir] Improve mlir-query tool by implementing getBackwardSlice and getForwardSlice matchers [mlir] Improve mlir-query tool by implementing getBackwardSlice matcher May 13, 2025
Labels
mlir:core MLIR Core Infrastructure mlir:linalg mlir
3 participants