-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[MLIR][DRR] Fix inconsistent operand and arg index usage #139816
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-mlir Author: Xiaomin Liu (xl4624) ChangesBackground issue: #139813 In emitEitherOperandMatch() we check if } else if (isa<NamedTypeConstraint *>(op.getArg(argIndex))) {
emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(),
operandIndex,
/*operandMatcher=*/eitherArgTree.getArgAsLeaf(i),
/*argName=*/eitherArgTree.getArgName(i), argIndex,
/*variadicSubIndex=*/std::nullopt);
++operandIndex;
} but in Full diff: https://github.com/llvm/llvm-project/pull/139816.diff 2 Files Affected:
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 3e461999e2730..5bcd66b85c364 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1878,6 +1878,11 @@ def TestEitherOpB : TEST_Op<"either_op_b"> {
let results = (outs I32:$output);
}
+def TestEitherOpC : TEST_Op<"either_op_b"> {
+ let arguments = (ins AnyI32Attr:$attr, AnyInteger:$arg0, AnyInteger:$arg1);
+ let results = (outs I32:$output);
+}
+
def : Pat<(TestEitherOpA (either I32:$arg1, I16:$arg2), $x),
(TestEitherOpB $arg2, $x)>;
@@ -1889,6 +1894,9 @@ def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1, $_),
$x),
(TestEitherOpB $arg2, $x)>;
+def : Pat<(TestEitherOpC ConstantAttr<I32Attr, "0">, (either $arg1, I32:$arg2)),
+ (TestEitherOpB $arg1, $arg2)>;
+
def TestEitherHelperOpA : TEST_Op<"either_helper_op_a"> {
let arguments = (ins I32:$arg0);
let results = (outs I32:$output);
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 58abcc2bee895..471baea5f5268 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -568,11 +568,11 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
return formatv("castedOp{0}", depth);
};
- // The order of generating static matcher follows the topological order so
- // that for every dependent DagNode already have their static matcher
- // generated if needed. The reason we check if `getMatcherName(tree).empty()`
- // is when we are generating the static matcher for a DagNode itself. In this
- // case, we need to emit the function body rather than a function call.
+ // Static matchersare generated in topological order so that all dependent
+ // DagNodes have their static matcher generated beforehand if needed. We check
+ // if `getMatcherName(tree).empty()` for when we are generating the static
+ // matcher for a DagNode itself. In this case, we need to emit the function
+ // body rather than a function call.
if (staticMatcherHelper.useStaticMatcher(tree) &&
!staticMatcherHelper.getMatcherName(tree).empty()) {
emitStaticMatchCall(tree, opName);
@@ -658,7 +658,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
if (isa<NamedTypeConstraint *>(opArg)) {
auto operandName =
formatv("{0}.getODSOperands({1})", castedName, nextOperand);
- emitOperandMatch(tree, castedName, operandName.str(), opArgIdx,
+ emitOperandMatch(tree, castedName, operandName.str(), nextOperand,
/*operandMatcher=*/tree.getArgAsLeaf(i),
/*argName=*/tree.getArgName(i), opArgIdx,
/*variadicSubIndex=*/std::nullopt);
@@ -680,7 +680,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
int argIndex,
std::optional<int> variadicSubIndex) {
Operator &op = tree.getDialectOp(opMap);
- auto *operand = cast<NamedTypeConstraint *>(op.getArg(operandIndex));
+ NamedTypeConstraint operand = op.getOperand(operandIndex);
// If a constraint is specified, we need to generate C++ statements to
// check the constraint.
@@ -693,8 +693,8 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
// Only need to verify if the matcher's type is different from the one
// of op definition.
Constraint constraint = operandMatcher.getAsConstraint();
- if (operand->constraint != constraint) {
- if (operand->isVariableLength()) {
+ if (operand.constraint != constraint) {
+ if (operand.isVariableLength()) {
auto error = formatv(
"further constrain op {0}'s variadic operand #{1} unsupported now",
op.getOperationName(), argIndex);
@@ -706,7 +706,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
verifier, opName, self.str(),
formatv(
"\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"",
- operand - op.operand_begin(), op.getOperationName(),
+ operandIndex, op.getOperationName(),
escapeString(constraint.getSummary()))
.str());
}
@@ -715,7 +715,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
// Capture the value
// `$_` is a special symbol to ignore op argument matching.
if (!argName.empty() && argName != "_") {
- auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, operandIndex,
+ auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, argIndex,
variadicSubIndex);
if (res == symbolInfoMap.end())
PrintFatalError(loc, formatv("symbol not found: {0}", argName));
@@ -821,7 +821,7 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
StringRef variadicTreeName = variadicArgTree.getSymbol();
if (!variadicTreeName.empty()) {
auto res =
- symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, operandIndex,
+ symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, argIndex,
/*variadicSubIndex=*/std::nullopt);
if (res == symbolInfoMap.end())
PrintFatalError(loc, formatv("symbol not found: {0}", variadicTreeName));
|
@llvm/pr-subscribers-mlir-core Author: Xiaomin Liu (xl4624) ChangesBackground issue: #139813 In emitEitherOperandMatch() we check if } else if (isa<NamedTypeConstraint *>(op.getArg(argIndex))) {
emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(),
operandIndex,
/*operandMatcher=*/eitherArgTree.getArgAsLeaf(i),
/*argName=*/eitherArgTree.getArgName(i), argIndex,
/*variadicSubIndex=*/std::nullopt);
++operandIndex;
} but in Full diff: https://github.com/llvm/llvm-project/pull/139816.diff 2 Files Affected:
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 3e461999e2730..5bcd66b85c364 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1878,6 +1878,11 @@ def TestEitherOpB : TEST_Op<"either_op_b"> {
let results = (outs I32:$output);
}
+def TestEitherOpC : TEST_Op<"either_op_b"> {
+ let arguments = (ins AnyI32Attr:$attr, AnyInteger:$arg0, AnyInteger:$arg1);
+ let results = (outs I32:$output);
+}
+
def : Pat<(TestEitherOpA (either I32:$arg1, I16:$arg2), $x),
(TestEitherOpB $arg2, $x)>;
@@ -1889,6 +1894,9 @@ def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1, $_),
$x),
(TestEitherOpB $arg2, $x)>;
+def : Pat<(TestEitherOpC ConstantAttr<I32Attr, "0">, (either $arg1, I32:$arg2)),
+ (TestEitherOpB $arg1, $arg2)>;
+
def TestEitherHelperOpA : TEST_Op<"either_helper_op_a"> {
let arguments = (ins I32:$arg0);
let results = (outs I32:$output);
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 58abcc2bee895..471baea5f5268 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -568,11 +568,11 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
return formatv("castedOp{0}", depth);
};
- // The order of generating static matcher follows the topological order so
- // that for every dependent DagNode already have their static matcher
- // generated if needed. The reason we check if `getMatcherName(tree).empty()`
- // is when we are generating the static matcher for a DagNode itself. In this
- // case, we need to emit the function body rather than a function call.
+ // Static matchersare generated in topological order so that all dependent
+ // DagNodes have their static matcher generated beforehand if needed. We check
+ // if `getMatcherName(tree).empty()` for when we are generating the static
+ // matcher for a DagNode itself. In this case, we need to emit the function
+ // body rather than a function call.
if (staticMatcherHelper.useStaticMatcher(tree) &&
!staticMatcherHelper.getMatcherName(tree).empty()) {
emitStaticMatchCall(tree, opName);
@@ -658,7 +658,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
if (isa<NamedTypeConstraint *>(opArg)) {
auto operandName =
formatv("{0}.getODSOperands({1})", castedName, nextOperand);
- emitOperandMatch(tree, castedName, operandName.str(), opArgIdx,
+ emitOperandMatch(tree, castedName, operandName.str(), nextOperand,
/*operandMatcher=*/tree.getArgAsLeaf(i),
/*argName=*/tree.getArgName(i), opArgIdx,
/*variadicSubIndex=*/std::nullopt);
@@ -680,7 +680,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
int argIndex,
std::optional<int> variadicSubIndex) {
Operator &op = tree.getDialectOp(opMap);
- auto *operand = cast<NamedTypeConstraint *>(op.getArg(operandIndex));
+ NamedTypeConstraint operand = op.getOperand(operandIndex);
// If a constraint is specified, we need to generate C++ statements to
// check the constraint.
@@ -693,8 +693,8 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
// Only need to verify if the matcher's type is different from the one
// of op definition.
Constraint constraint = operandMatcher.getAsConstraint();
- if (operand->constraint != constraint) {
- if (operand->isVariableLength()) {
+ if (operand.constraint != constraint) {
+ if (operand.isVariableLength()) {
auto error = formatv(
"further constrain op {0}'s variadic operand #{1} unsupported now",
op.getOperationName(), argIndex);
@@ -706,7 +706,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
verifier, opName, self.str(),
formatv(
"\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"",
- operand - op.operand_begin(), op.getOperationName(),
+ operandIndex, op.getOperationName(),
escapeString(constraint.getSummary()))
.str());
}
@@ -715,7 +715,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
// Capture the value
// `$_` is a special symbol to ignore op argument matching.
if (!argName.empty() && argName != "_") {
- auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, operandIndex,
+ auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, argIndex,
variadicSubIndex);
if (res == symbolInfoMap.end())
PrintFatalError(loc, formatv("symbol not found: {0}", argName));
@@ -821,7 +821,7 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
StringRef variadicTreeName = variadicArgTree.getSymbol();
if (!variadicTreeName.empty()) {
auto res =
- symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, operandIndex,
+ symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, argIndex,
/*variadicSubIndex=*/std::nullopt);
if (res == symbolInfoMap.end())
PrintFatalError(loc, formatv("symbol not found: {0}", variadicTreeName));
|
ae8a069
to
4b4efde
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks.
126a033
to
6edef55
Compare
@@ -1883,6 +1888,9 @@ def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1, $_), | |||
$x), | |||
(TestEitherOpB $arg2, $x)>; | |||
|
|||
def : Pat<(TestEitherOpC ConstantAttr<I32Attr, "0">, (either $arg1, I32:$arg2)), | |||
(TestEitherOpB $arg1, $arg2)>; | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happens with this without the changes to the RewriterGen.cpp file?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Without the changes this test fails on a cast assertion:
mlir-tblgen: /home/xiaomin/dev/llvm-project/llvm/include/llvm/Support/Casting.h:566: decltype(auto) llvm::cast(const From &) [To = mlir::tblgen::NamedTypeConstraint *, From = llvm::PointerUnion<mlir::tblgen::NamedAttribute *, mlir::tblgen::NamedProperty *, mlir::tblgen::NamedTypeConstraint *>]: Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' failed.
Coming from:
llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp
Lines 773 to 778 in 751e6c0
} else if (isa<NamedTypeConstraint *>(op.getArg(argIndex))) { | |
emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(), | |
operandIndex, | |
/*operandMatcher=*/eitherArgTree.getArgAsLeaf(i), | |
/*argName=*/eitherArgTree.getArgName(i), argIndex, | |
/*variadicSubIndex=*/std::nullopt); |
llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp
Lines 677 to 683 in bf1d4a0
void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName, | |
StringRef operandName, int operandIndex, | |
DagLeaf operandMatcher, StringRef argName, | |
int argIndex, | |
std::optional<int> variadicSubIndex) { | |
Operator &op = tree.getDialectOp(opMap); | |
auto *operand = cast<NamedTypeConstraint *>(op.getArg(operandIndex)); |
In emitEitherOperand()
, we check if op.getArg(argIndex)
is a NamedTypeConstraint
, but in emitOperandMatch
we cast op.getArg(operandIndex)
. In the test case above, operandIndex
and argIndex
get out of sync due to the Attribute being in the front which leads to a cast with the wrong index (specifically argIndex=1
referring to $arg1
and operandIndex=0
referring to the ConstantAttr
)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add a test for this rewrite though to ensure it works as expected?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a test for this in pattern.mlir
, also cleaned up surrounding tests to match the style of the rest of the file.
7dc9751
to
98fae34
Compare
[MLIR][DRR] Fix inconsistent operand and arg index usage
Background issue: #139813
In emitEitherOperandMatch() we check if
op.getArg(argIndex)
is aNamedTypeConstraint
:but in
emitOperandMatch()
we cast onop.getArg(operandIndex)
, which is incorrect if the operation has attributes or other non-operand arguments before its operands.