Skip to content

[MLIR][DRR] Fix inconsistent operand and arg index usage #139816

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1872,6 +1872,11 @@ def TestEitherOpB : TEST_Op<"either_op_b"> {
let results = (outs I32:$output);
}

def TestEitherOpC : TEST_Op<"either_op_c"> {
let arguments = (ins AnyI32Attr:$attr, AnyInteger:$arg0, AnyInteger:$arg1);
let results = (outs I32:$output);
}

def : Pat<(TestEitherOpA (either I32:$arg1, I16:$arg2), $x),
(TestEitherOpB $arg2, $x)>;

Expand All @@ -1883,6 +1888,9 @@ def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1, $_),
$x),
(TestEitherOpB $arg2, $x)>;

def : Pat<(TestEitherOpC ConstantAttr<I32Attr, "0">, (either $arg1, I32:$arg2)),
(TestEitherOpB $arg1, $arg2)>;

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens with this without the changes to the RewriterGen.cpp file?

Copy link
Author

@xl4624 xl4624 May 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without the changes this test fails on a cast assertion:

mlir-tblgen: /home/xiaomin/dev/llvm-project/llvm/include/llvm/Support/Casting.h:566: decltype(auto) llvm::cast(const From &) [To = mlir::tblgen::NamedTypeConstraint *, From = llvm::PointerUnion<mlir::tblgen::NamedAttribute *, mlir::tblgen::NamedProperty *, mlir::tblgen::NamedTypeConstraint *>]: Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' failed.

Coming from:

} else if (isa<NamedTypeConstraint *>(op.getArg(argIndex))) {
emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(),
operandIndex,
/*operandMatcher=*/eitherArgTree.getArgAsLeaf(i),
/*argName=*/eitherArgTree.getArgName(i), argIndex,
/*variadicSubIndex=*/std::nullopt);

void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
StringRef operandName, int operandIndex,
DagLeaf operandMatcher, StringRef argName,
int argIndex,
std::optional<int> variadicSubIndex) {
Operator &op = tree.getDialectOp(opMap);
auto *operand = cast<NamedTypeConstraint *>(op.getArg(operandIndex));

In emitEitherOperand(), we check if op.getArg(argIndex) is a NamedTypeConstraint, but in emitOperandMatch we cast op.getArg(operandIndex). In the test case above, operandIndex and argIndex get out of sync due to the Attribute being in the front which leads to a cast with the wrong index (specifically argIndex=1 referring to $arg1 and operandIndex=0 referring to the ConstantAttr)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a test for this rewrite though to ensure it works as expected?

Copy link
Author

@xl4624 xl4624 May 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a test for this in pattern.mlir, also cleaned up surrounding tests to match the style of the rest of the file.

def TestEitherHelperOpA : TEST_Op<"either_helper_op_a"> {
let arguments = (ins I32:$arg0);
let results = (outs I32:$output);
Expand Down
31 changes: 25 additions & 6 deletions mlir/test/mlir-tblgen/pattern.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -609,17 +609,17 @@ func.func @redundantTest(%arg0: i32) -> i32 {
// Test either directive
//===----------------------------------------------------------------------===//

// CHECK: @either_dag_leaf_only
func.func @either_dag_leaf_only_1(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
// CHECK-LABEL: @eitherDagLeafOnly
func.func @eitherDagLeafOnly(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
// CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32
%0 = "test.either_op_a"(%arg0, %arg1, %arg2) : (i32, i16, i8) -> i32
// CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32
%1 = "test.either_op_a"(%arg1, %arg0, %arg2) : (i16, i32, i8) -> i32
return
}

// CHECK: @either_dag_leaf_dag_node
func.func @either_dag_leaf_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
// CHECK-LABEL: @eitherDagLeafDagNode
func.func @eitherDagLeafDagNode(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
%0 = "test.either_op_b"(%arg0, %arg0) : (i32, i32) -> i32
// CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32
%1 = "test.either_op_a"(%0, %arg1, %arg2) : (i32, i16, i8) -> i32
Expand All @@ -628,8 +628,8 @@ func.func @either_dag_leaf_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> ()
return
}

// CHECK: @either_dag_node_dag_node
func.func @either_dag_node_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
// CHECK-LABEL: @eitherDagNodeDagNode
func.func @eitherDagNodeDagNode(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
%0 = "test.either_op_b"(%arg0, %arg0) : (i32, i32) -> i32
%1 = "test.either_op_b"(%arg1, %arg1) : (i16, i16) -> i32
// CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32
Expand All @@ -639,24 +639,38 @@ func.func @either_dag_node_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> ()
return
}

// CHECK-LABEL: @testEitherOpWithAttr
func.func @testEitherOpWithAttr(%arg0 : i32, %arg1 : i16) -> () {
// CHECK: "test.either_op_b"(%arg1, %arg0) : (i16, i32) -> i32
%0 = "test.either_op_c"(%arg0, %arg1) {attr = 0 : i32} : (i32, i16) -> i32
// CHECK: "test.either_op_b"(%arg1, %arg0) : (i16, i32) -> i32
%1 = "test.either_op_c"(%arg1, %arg0) {attr = 0 : i32} : (i16, i32) -> i32
// CHECK: "test.either_op_c"(%arg0, %arg1) <{attr = 1 : i32}> : (i32, i16) -> i32
%2 = "test.either_op_c"(%arg0, %arg1) {attr = 1 : i32} : (i32, i16) -> i32
return
}

//===----------------------------------------------------------------------===//
// Test that ops without type deduction can be created with type builders.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @explicitReturnTypeTest
func.func @explicitReturnTypeTest(%arg0 : i64) -> i8 {
%0 = "test.source_op"(%arg0) {tag = 11 : i32} : (i64) -> i8
// CHECK: "test.op_x"(%arg0) : (i64) -> i32
// CHECK: "test.op_x"(%0) : (i32) -> i8
return %0 : i8
}

// CHECK-LABEL: @returnTypeBuilderTest
func.func @returnTypeBuilderTest(%arg0 : i1) -> i8 {
%0 = "test.source_op"(%arg0) {tag = 22 : i32} : (i1) -> i8
// CHECK: "test.op_x"(%arg0) : (i1) -> i1
// CHECK: "test.op_x"(%0) : (i1) -> i8
return %0 : i8
}

// CHECK-LABEL: @multipleReturnTypeBuildTest
func.func @multipleReturnTypeBuildTest(%arg0 : i1) -> i1 {
%0 = "test.source_op"(%arg0) {tag = 33 : i32} : (i1) -> i1
// CHECK: "test.one_to_two"(%arg0) : (i1) -> (i64, i32)
Expand All @@ -666,13 +680,15 @@ func.func @multipleReturnTypeBuildTest(%arg0 : i1) -> i1 {
return %0 : i1
}

// CHECK-LABEL: @copyValueType
func.func @copyValueType(%arg0 : i8) -> i32 {
%0 = "test.source_op"(%arg0) {tag = 44 : i32} : (i8) -> i32
// CHECK: "test.op_x"(%arg0) : (i8) -> i8
// CHECK: "test.op_x"(%0) : (i8) -> i32
return %0 : i32
}

// CHECK-LABEL: @multipleReturnTypeDifferent
func.func @multipleReturnTypeDifferent(%arg0 : i1) -> i64 {
%0 = "test.source_op"(%arg0) {tag = 55 : i32} : (i1) -> i64
// CHECK: "test.one_to_two"(%arg0) : (i1) -> (i1, i64)
Expand All @@ -684,6 +700,7 @@ func.func @multipleReturnTypeDifferent(%arg0 : i1) -> i64 {
// Test that multiple trailing directives can be mixed in patterns.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @returnTypeAndLocation
func.func @returnTypeAndLocation(%arg0 : i32) -> i1 {
%0 = "test.source_op"(%arg0) {tag = 66 : i32} : (i32) -> i1
// CHECK: "test.op_x"(%arg0) : (i32) -> i32 loc("loc1")
Expand All @@ -696,6 +713,7 @@ func.func @returnTypeAndLocation(%arg0 : i32) -> i1 {
// Test that patterns can create ConstantStrAttr
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @testConstantStrAttr
func.func @testConstantStrAttr() -> () {
// CHECK: test.has_str_value {value = "foo"}
test.no_str_value {value = "bar"}
Expand All @@ -706,6 +724,7 @@ func.func @testConstantStrAttr() -> () {
// Test that patterns with variadics propagate sizes
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @testVariadic
func.func @testVariadic(%arg_0: i32, %arg_1: i32, %brg: i64,
%crg_0: f32, %crg_1: f32, %crg_2: f32, %crg_3: f32) -> () {
// CHECK: "test.variadic_rewrite_dst_op"(%arg2, %arg3, %arg4, %arg5, %arg6, %arg0, %arg1) <{operandSegmentSizes = array<i32: 1, 4, 2>}> : (i64, f32, f32, f32, f32, i32, i32) -> ()
Expand Down
14 changes: 7 additions & 7 deletions mlir/tools/mlir-tblgen/RewriterGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
if (isa<NamedTypeConstraint *>(opArg)) {
auto operandName =
formatv("{0}.getODSOperands({1})", castedName, nextOperand);
emitOperandMatch(tree, castedName, operandName.str(), opArgIdx,
emitOperandMatch(tree, castedName, operandName.str(), nextOperand,
/*operandMatcher=*/tree.getArgAsLeaf(i),
/*argName=*/tree.getArgName(i), opArgIdx,
/*variadicSubIndex=*/std::nullopt);
Expand All @@ -680,7 +680,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
int argIndex,
std::optional<int> variadicSubIndex) {
Operator &op = tree.getDialectOp(opMap);
auto *operand = cast<NamedTypeConstraint *>(op.getArg(operandIndex));
NamedTypeConstraint operand = op.getOperand(operandIndex);

// If a constraint is specified, we need to generate C++ statements to
// check the constraint.
Expand All @@ -693,8 +693,8 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
// Only need to verify if the matcher's type is different from the one
// of op definition.
Constraint constraint = operandMatcher.getAsConstraint();
if (operand->constraint != constraint) {
if (operand->isVariableLength()) {
if (operand.constraint != constraint) {
if (operand.isVariableLength()) {
auto error = formatv(
"further constrain op {0}'s variadic operand #{1} unsupported now",
op.getOperationName(), argIndex);
Expand All @@ -706,7 +706,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
verifier, opName, self.str(),
formatv(
"\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"",
operand - op.operand_begin(), op.getOperationName(),
operandIndex, op.getOperationName(),
escapeString(constraint.getSummary()))
.str());
}
Expand All @@ -715,7 +715,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
// Capture the value
// `$_` is a special symbol to ignore op argument matching.
if (!argName.empty() && argName != "_") {
auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, operandIndex,
auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, argIndex,
variadicSubIndex);
if (res == symbolInfoMap.end())
PrintFatalError(loc, formatv("symbol not found: {0}", argName));
Expand Down Expand Up @@ -821,7 +821,7 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
StringRef variadicTreeName = variadicArgTree.getSymbol();
if (!variadicTreeName.empty()) {
auto res =
symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, operandIndex,
symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, argIndex,
/*variadicSubIndex=*/std::nullopt);
if (res == symbolInfoMap.end())
PrintFatalError(loc, formatv("symbol not found: {0}", variadicTreeName));
Expand Down