Skip to content

Commit ae8a069

Browse files
committed
[MLIR][ODS] Fix inconsistent operand and arg index usage
1 parent e618a79 commit ae8a069

File tree

2 files changed

+20
-12
lines changed

2 files changed

+20
-12
lines changed

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1878,6 +1878,11 @@ def TestEitherOpB : TEST_Op<"either_op_b"> {
18781878
let results = (outs I32:$output);
18791879
}
18801880

1881+
def TestEitherOpC : TEST_Op<"either_op_c"> {
1882+
let arguments = (ins AnyI32Attr:$attr, AnyInteger:$arg0, AnyInteger:$arg1);
1883+
let results = (outs I32:$output);
1884+
}
1885+
18811886
def : Pat<(TestEitherOpA (either I32:$arg1, I16:$arg2), $x),
18821887
(TestEitherOpB $arg2, $x)>;
18831888

@@ -1889,6 +1894,9 @@ def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1, $_),
18891894
$x),
18901895
(TestEitherOpB $arg2, $x)>;
18911896

1897+
def : Pat<(TestEitherOpC ConstantAttr<I32Attr, "0">, (either $arg1, I32:$arg2)),
1898+
(TestEitherOpB $arg1, $arg2)>;
1899+
18921900
def TestEitherHelperOpA : TEST_Op<"either_helper_op_a"> {
18931901
let arguments = (ins I32:$arg0);
18941902
let results = (outs I32:$output);

mlir/tools/mlir-tblgen/RewriterGen.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -568,11 +568,11 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
568568
return formatv("castedOp{0}", depth);
569569
};
570570

571-
// The order of generating static matcher follows the topological order so
572-
// that for every dependent DagNode already have their static matcher
573-
// generated if needed. The reason we check if `getMatcherName(tree).empty()`
574-
// is when we are generating the static matcher for a DagNode itself. In this
575-
// case, we need to emit the function body rather than a function call.
571+
// Static matchers are generated in topological order so that all dependent
572+
// DagNodes have their static matcher generated beforehand if needed. We check
573+
// if `getMatcherName(tree).empty()` for when we are generating the static
574+
// matcher for a DagNode itself. In this case, we need to emit the function
575+
// body rather than a function call.
576576
if (staticMatcherHelper.useStaticMatcher(tree) &&
577577
!staticMatcherHelper.getMatcherName(tree).empty()) {
578578
emitStaticMatchCall(tree, opName);
@@ -658,7 +658,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
658658
if (isa<NamedTypeConstraint *>(opArg)) {
659659
auto operandName =
660660
formatv("{0}.getODSOperands({1})", castedName, nextOperand);
661-
emitOperandMatch(tree, castedName, operandName.str(), opArgIdx,
661+
emitOperandMatch(tree, castedName, operandName.str(), nextOperand,
662662
/*operandMatcher=*/tree.getArgAsLeaf(i),
663663
/*argName=*/tree.getArgName(i), opArgIdx,
664664
/*variadicSubIndex=*/std::nullopt);
@@ -680,7 +680,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
680680
int argIndex,
681681
std::optional<int> variadicSubIndex) {
682682
Operator &op = tree.getDialectOp(opMap);
683-
auto *operand = cast<NamedTypeConstraint *>(op.getArg(operandIndex));
683+
NamedTypeConstraint operand = op.getOperand(operandIndex);
684684

685685
// If a constraint is specified, we need to generate C++ statements to
686686
// check the constraint.
@@ -693,8 +693,8 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
693693
// Only need to verify if the matcher's type is different from the one
694694
// of op definition.
695695
Constraint constraint = operandMatcher.getAsConstraint();
696-
if (operand->constraint != constraint) {
697-
if (operand->isVariableLength()) {
696+
if (operand.constraint != constraint) {
697+
if (operand.isVariableLength()) {
698698
auto error = formatv(
699699
"further constrain op {0}'s variadic operand #{1} unsupported now",
700700
op.getOperationName(), argIndex);
@@ -706,7 +706,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
706706
verifier, opName, self.str(),
707707
formatv(
708708
"\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"",
709-
operand - op.operand_begin(), op.getOperationName(),
709+
operandIndex, op.getOperationName(),
710710
escapeString(constraint.getSummary()))
711711
.str());
712712
}
@@ -715,7 +715,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
715715
// Capture the value
716716
// `$_` is a special symbol to ignore op argument matching.
717717
if (!argName.empty() && argName != "_") {
718-
auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, operandIndex,
718+
auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, argIndex,
719719
variadicSubIndex);
720720
if (res == symbolInfoMap.end())
721721
PrintFatalError(loc, formatv("symbol not found: {0}", argName));
@@ -821,7 +821,7 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
821821
StringRef variadicTreeName = variadicArgTree.getSymbol();
822822
if (!variadicTreeName.empty()) {
823823
auto res =
824-
symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, operandIndex,
824+
symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, argIndex,
825825
/*variadicSubIndex=*/std::nullopt);
826826
if (res == symbolInfoMap.end())
827827
PrintFatalError(loc, formatv("symbol not found: {0}", variadicTreeName));

0 commit comments

Comments
 (0)