Skip to content

[MLIR][DRR] Fix inconsistent operand and arg index usage #139816

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

xl4624
Copy link

@xl4624 xl4624 commented May 14, 2025

Background issue: #139813

In emitEitherOperandMatch() we check if op.getArg(argIndex) is a NamedTypeConstraint:

} else if (isa<NamedTypeConstraint *>(op.getArg(argIndex))) {
      emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(),
                       operandIndex,
                       /*operandMatcher=*/eitherArgTree.getArgAsLeaf(i),
                       /*argName=*/eitherArgTree.getArgName(i), argIndex,
                       /*variadicSubIndex=*/std::nullopt);
      ++operandIndex;
}

but in emitOperandMatch() we cast on op.getArg(operandIndex), which is incorrect if the operation has attributes or other non-operand arguments before its operands.

Copy link

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels May 14, 2025
@llvmbot
Copy link
Member

llvmbot commented May 14, 2025

@llvm/pr-subscribers-mlir

Author: Xiaomin Liu (xl4624)

Changes

Background issue: #139813

In emitEitherOperandMatch() we check if op.getArg(argIndex) is a NamedTypeConstraint:

} else if (isa&lt;NamedTypeConstraint *&gt;(op.getArg(argIndex))) {
      emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(),
                       operandIndex,
                       /*operandMatcher=*/eitherArgTree.getArgAsLeaf(i),
                       /*argName=*/eitherArgTree.getArgName(i), argIndex,
                       /*variadicSubIndex=*/std::nullopt);
      ++operandIndex;
}

but in emitOperandMatch() we cast on op.getArg(operandIndex), which is incorrect if the operation has attributes or other non-operand arguments before its operands.


Full diff: https://github.com/llvm/llvm-project/pull/139816.diff

2 Files Affected:

  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (+8)
  • (modified) mlir/tools/mlir-tblgen/RewriterGen.cpp (+12-12)
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 3e461999e2730..5bcd66b85c364 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1878,6 +1878,11 @@ def TestEitherOpB : TEST_Op<"either_op_b"> {
   let results = (outs I32:$output);
 }
 
+def TestEitherOpC : TEST_Op<"either_op_b"> {
+  let arguments = (ins AnyI32Attr:$attr, AnyInteger:$arg0, AnyInteger:$arg1);
+  let results = (outs I32:$output);
+}
+
 def : Pat<(TestEitherOpA (either I32:$arg1, I16:$arg2), $x),
           (TestEitherOpB $arg2, $x)>;
 
@@ -1889,6 +1894,9 @@ def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1, $_),
                           $x),
           (TestEitherOpB $arg2, $x)>;
 
+def : Pat<(TestEitherOpC ConstantAttr<I32Attr, "0">, (either $arg1, I32:$arg2)),
+          (TestEitherOpB $arg1, $arg2)>;
+
 def TestEitherHelperOpA : TEST_Op<"either_helper_op_a"> {
   let arguments = (ins I32:$arg0);
   let results = (outs I32:$output);
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 58abcc2bee895..471baea5f5268 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -568,11 +568,11 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
     return formatv("castedOp{0}", depth);
   };
 
-  // The order of generating static matcher follows the topological order so
-  // that for every dependent DagNode already have their static matcher
-  // generated if needed. The reason we check if `getMatcherName(tree).empty()`
-  // is when we are generating the static matcher for a DagNode itself. In this
-  // case, we need to emit the function body rather than a function call.
+  // Static matchersare generated in topological order so that all dependent
+  // DagNodes have their static matcher generated beforehand if needed. We check
+  // if `getMatcherName(tree).empty()` for when we are generating the static
+  // matcher for a DagNode itself. In this case, we need to emit the function
+  // body rather than a function call.
   if (staticMatcherHelper.useStaticMatcher(tree) &&
       !staticMatcherHelper.getMatcherName(tree).empty()) {
     emitStaticMatchCall(tree, opName);
@@ -658,7 +658,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
     if (isa<NamedTypeConstraint *>(opArg)) {
       auto operandName =
           formatv("{0}.getODSOperands({1})", castedName, nextOperand);
-      emitOperandMatch(tree, castedName, operandName.str(), opArgIdx,
+      emitOperandMatch(tree, castedName, operandName.str(), nextOperand,
                        /*operandMatcher=*/tree.getArgAsLeaf(i),
                        /*argName=*/tree.getArgName(i), opArgIdx,
                        /*variadicSubIndex=*/std::nullopt);
@@ -680,7 +680,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
                                       int argIndex,
                                       std::optional<int> variadicSubIndex) {
   Operator &op = tree.getDialectOp(opMap);
-  auto *operand = cast<NamedTypeConstraint *>(op.getArg(operandIndex));
+  NamedTypeConstraint operand = op.getOperand(operandIndex);
 
   // If a constraint is specified, we need to generate C++ statements to
   // check the constraint.
@@ -693,8 +693,8 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
     // Only need to verify if the matcher's type is different from the one
     // of op definition.
     Constraint constraint = operandMatcher.getAsConstraint();
-    if (operand->constraint != constraint) {
-      if (operand->isVariableLength()) {
+    if (operand.constraint != constraint) {
+      if (operand.isVariableLength()) {
         auto error = formatv(
             "further constrain op {0}'s variadic operand #{1} unsupported now",
             op.getOperationName(), argIndex);
@@ -706,7 +706,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
           verifier, opName, self.str(),
           formatv(
               "\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"",
-              operand - op.operand_begin(), op.getOperationName(),
+              operandIndex, op.getOperationName(),
               escapeString(constraint.getSummary()))
               .str());
     }
@@ -715,7 +715,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
   // Capture the value
   // `$_` is a special symbol to ignore op argument matching.
   if (!argName.empty() && argName != "_") {
-    auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, operandIndex,
+    auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, argIndex,
                                              variadicSubIndex);
     if (res == symbolInfoMap.end())
       PrintFatalError(loc, formatv("symbol not found: {0}", argName));
@@ -821,7 +821,7 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
   StringRef variadicTreeName = variadicArgTree.getSymbol();
   if (!variadicTreeName.empty()) {
     auto res =
-        symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, operandIndex,
+        symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, argIndex,
                                       /*variadicSubIndex=*/std::nullopt);
     if (res == symbolInfoMap.end())
       PrintFatalError(loc, formatv("symbol not found: {0}", variadicTreeName));

@llvmbot
Copy link
Member

llvmbot commented May 14, 2025

@llvm/pr-subscribers-mlir-core

Author: Xiaomin Liu (xl4624)

Changes

Background issue: #139813

In emitEitherOperandMatch() we check if op.getArg(argIndex) is a NamedTypeConstraint:

} else if (isa&lt;NamedTypeConstraint *&gt;(op.getArg(argIndex))) {
      emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(),
                       operandIndex,
                       /*operandMatcher=*/eitherArgTree.getArgAsLeaf(i),
                       /*argName=*/eitherArgTree.getArgName(i), argIndex,
                       /*variadicSubIndex=*/std::nullopt);
      ++operandIndex;
}

but in emitOperandMatch() we cast on op.getArg(operandIndex), which is incorrect if the operation has attributes or other non-operand arguments before its operands.


Full diff: https://github.com/llvm/llvm-project/pull/139816.diff

2 Files Affected:

  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (+8)
  • (modified) mlir/tools/mlir-tblgen/RewriterGen.cpp (+12-12)
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 3e461999e2730..5bcd66b85c364 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1878,6 +1878,11 @@ def TestEitherOpB : TEST_Op<"either_op_b"> {
   let results = (outs I32:$output);
 }
 
+def TestEitherOpC : TEST_Op<"either_op_b"> {
+  let arguments = (ins AnyI32Attr:$attr, AnyInteger:$arg0, AnyInteger:$arg1);
+  let results = (outs I32:$output);
+}
+
 def : Pat<(TestEitherOpA (either I32:$arg1, I16:$arg2), $x),
           (TestEitherOpB $arg2, $x)>;
 
@@ -1889,6 +1894,9 @@ def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1, $_),
                           $x),
           (TestEitherOpB $arg2, $x)>;
 
+def : Pat<(TestEitherOpC ConstantAttr<I32Attr, "0">, (either $arg1, I32:$arg2)),
+          (TestEitherOpB $arg1, $arg2)>;
+
 def TestEitherHelperOpA : TEST_Op<"either_helper_op_a"> {
   let arguments = (ins I32:$arg0);
   let results = (outs I32:$output);
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 58abcc2bee895..471baea5f5268 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -568,11 +568,11 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
     return formatv("castedOp{0}", depth);
   };
 
-  // The order of generating static matcher follows the topological order so
-  // that for every dependent DagNode already have their static matcher
-  // generated if needed. The reason we check if `getMatcherName(tree).empty()`
-  // is when we are generating the static matcher for a DagNode itself. In this
-  // case, we need to emit the function body rather than a function call.
+  // Static matchersare generated in topological order so that all dependent
+  // DagNodes have their static matcher generated beforehand if needed. We check
+  // if `getMatcherName(tree).empty()` for when we are generating the static
+  // matcher for a DagNode itself. In this case, we need to emit the function
+  // body rather than a function call.
   if (staticMatcherHelper.useStaticMatcher(tree) &&
       !staticMatcherHelper.getMatcherName(tree).empty()) {
     emitStaticMatchCall(tree, opName);
@@ -658,7 +658,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
     if (isa<NamedTypeConstraint *>(opArg)) {
       auto operandName =
           formatv("{0}.getODSOperands({1})", castedName, nextOperand);
-      emitOperandMatch(tree, castedName, operandName.str(), opArgIdx,
+      emitOperandMatch(tree, castedName, operandName.str(), nextOperand,
                        /*operandMatcher=*/tree.getArgAsLeaf(i),
                        /*argName=*/tree.getArgName(i), opArgIdx,
                        /*variadicSubIndex=*/std::nullopt);
@@ -680,7 +680,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
                                       int argIndex,
                                       std::optional<int> variadicSubIndex) {
   Operator &op = tree.getDialectOp(opMap);
-  auto *operand = cast<NamedTypeConstraint *>(op.getArg(operandIndex));
+  NamedTypeConstraint operand = op.getOperand(operandIndex);
 
   // If a constraint is specified, we need to generate C++ statements to
   // check the constraint.
@@ -693,8 +693,8 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
     // Only need to verify if the matcher's type is different from the one
     // of op definition.
     Constraint constraint = operandMatcher.getAsConstraint();
-    if (operand->constraint != constraint) {
-      if (operand->isVariableLength()) {
+    if (operand.constraint != constraint) {
+      if (operand.isVariableLength()) {
         auto error = formatv(
             "further constrain op {0}'s variadic operand #{1} unsupported now",
             op.getOperationName(), argIndex);
@@ -706,7 +706,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
           verifier, opName, self.str(),
           formatv(
               "\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"",
-              operand - op.operand_begin(), op.getOperationName(),
+              operandIndex, op.getOperationName(),
               escapeString(constraint.getSummary()))
               .str());
     }
@@ -715,7 +715,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
   // Capture the value
   // `$_` is a special symbol to ignore op argument matching.
   if (!argName.empty() && argName != "_") {
-    auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, operandIndex,
+    auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, argIndex,
                                              variadicSubIndex);
     if (res == symbolInfoMap.end())
       PrintFatalError(loc, formatv("symbol not found: {0}", argName));
@@ -821,7 +821,7 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
   StringRef variadicTreeName = variadicArgTree.getSymbol();
   if (!variadicTreeName.empty()) {
     auto res =
-        symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, operandIndex,
+        symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, argIndex,
                                       /*variadicSubIndex=*/std::nullopt);
     if (res == symbolInfoMap.end())
       PrintFatalError(loc, formatv("symbol not found: {0}", variadicTreeName));

@xl4624 xl4624 force-pushed the mlir-tblgen branch 2 times, most recently from ae8a069 to 4b4efde Compare May 14, 2025 01:00
Copy link
Contributor

@loganchien loganchien left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks.

@xl4624 xl4624 changed the title [MLIR][ODS] Fix inconsistent operand and arg index usage [MLIR][DRR] Fix inconsistent operand and arg index usage May 14, 2025
@xl4624 xl4624 force-pushed the mlir-tblgen branch 2 times, most recently from 126a033 to 6edef55 Compare May 17, 2025 03:38
@@ -1883,6 +1888,9 @@ def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1, $_),
$x),
(TestEitherOpB $arg2, $x)>;

def : Pat<(TestEitherOpC ConstantAttr<I32Attr, "0">, (either $arg1, I32:$arg2)),
(TestEitherOpB $arg1, $arg2)>;

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens with this without the changes to the RewriterGen.cpp file?

Copy link
Author

@xl4624 xl4624 May 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without the changes this test fails on a cast assertion:

mlir-tblgen: /home/xiaomin/dev/llvm-project/llvm/include/llvm/Support/Casting.h:566: decltype(auto) llvm::cast(const From &) [To = mlir::tblgen::NamedTypeConstraint *, From = llvm::PointerUnion<mlir::tblgen::NamedAttribute *, mlir::tblgen::NamedProperty *, mlir::tblgen::NamedTypeConstraint *>]: Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' failed.

Coming from:

} else if (isa<NamedTypeConstraint *>(op.getArg(argIndex))) {
emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(),
operandIndex,
/*operandMatcher=*/eitherArgTree.getArgAsLeaf(i),
/*argName=*/eitherArgTree.getArgName(i), argIndex,
/*variadicSubIndex=*/std::nullopt);

void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
StringRef operandName, int operandIndex,
DagLeaf operandMatcher, StringRef argName,
int argIndex,
std::optional<int> variadicSubIndex) {
Operator &op = tree.getDialectOp(opMap);
auto *operand = cast<NamedTypeConstraint *>(op.getArg(operandIndex));

In emitEitherOperand(), we check if op.getArg(argIndex) is a NamedTypeConstraint, but in emitOperandMatch we cast op.getArg(operandIndex). In the test case above, operandIndex and argIndex get out of sync due to the Attribute being in the front which leads to a cast with the wrong index (specifically argIndex=1 referring to $arg1 and operandIndex=0 referring to the ConstantAttr)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a test for this rewrite though to ensure it works as expected?

Copy link
Author

@xl4624 xl4624 May 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a test for this in pattern.mlir, also cleaned up surrounding tests to match the style of the rest of the file.

@xl4624 xl4624 force-pushed the mlir-tblgen branch 6 times, most recently from 7dc9751 to 98fae34 Compare May 19, 2025 06:59
[MLIR][DRR] Fix inconsistent operand and arg index usage
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants