-
Notifications
You must be signed in to change notification settings - Fork 13.3k
[mlir] Integrate OpAsmTypeInterface with AsmPrinter #124700
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir] Integrate OpAsmTypeInterface with AsmPrinter #124700
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Hongren Zheng (ZenithalHourlyRate) ChangesSee https://discourse.llvm.org/t/rfc-introduce-opasm-type-attr-interface-for-pretty-print-in-asmprinter/83792 for detailed introduction. This is a follow up PR of #121187, by integrating OpAsmTypeInterface with AsmPrinter. There are a few conditions when OpAsmTypeInterface comes into play
Cc @River707 @jpienaar @ftynse for review. Full diff: https://github.com/llvm/llvm-project/pull/124700.diff 3 Files Affected:
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index fa4a1b4b72b024..a689e2b673a242 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1536,10 +1536,13 @@ StringRef maybeGetValueNameFromLoc(Value value, StringRef name) {
} // namespace
void SSANameState::numberValuesInRegion(Region ®ion) {
+ // indicate whether OpAsmOpInterface set a name
+ bool opAsmOpInterfaceUsed = false;
auto setBlockArgNameFn = [&](Value arg, StringRef name) {
assert(!valueIDs.count(arg) && "arg numbered multiple times");
assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == ®ion &&
"arg not defined in current region");
+ opAsmOpInterfaceUsed = true;
if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
name = maybeGetValueNameFromLoc(arg, name);
setValueName(arg, name);
@@ -1549,6 +1552,23 @@ void SSANameState::numberValuesInRegion(Region ®ion) {
if (Operation *op = region.getParentOp()) {
if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
+ if (!opAsmOpInterfaceUsed) {
+ // If the OpAsmOpInterface didn't set a name, and when
+ // all arguments have OpAsmTypeInterface, get names from the type
+ bool allHaveOpAsmTypeInterface =
+ llvm::all_of(region.getArguments(), [&](Value arg) {
+ return mlir::isa<OpAsmTypeInterface>(arg.getType());
+ });
+ if (allHaveOpAsmTypeInterface) {
+ for (auto arg : region.getArguments()) {
+ auto typeInterface = mlir::cast<OpAsmTypeInterface>(arg.getType());
+ auto setNameFn = [&](StringRef name) {
+ setBlockArgNameFn(arg, name);
+ };
+ typeInterface.getAsmName(setNameFn);
+ }
+ }
+ }
}
}
@@ -1598,9 +1618,12 @@ void SSANameState::numberValuesInBlock(Block &block) {
void SSANameState::numberValuesInOp(Operation &op) {
// Function used to set the special result names for the operation.
SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
+ // indicating whether OpAsmOpInterface set a name
+ bool opAsmOpInterfaceUsed = false;
auto setResultNameFn = [&](Value result, StringRef name) {
assert(!valueIDs.count(result) && "result numbered multiple times");
assert(result.getDefiningOp() == &op && "result not defined by 'op'");
+ opAsmOpInterfaceUsed = true;
if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
name = maybeGetValueNameFromLoc(result, name);
setValueName(result, name);
@@ -1630,6 +1653,23 @@ void SSANameState::numberValuesInOp(Operation &op) {
asmInterface.getAsmBlockNames(setBlockNameFn);
asmInterface.getAsmResultNames(setResultNameFn);
}
+ if (!opAsmOpInterfaceUsed) {
+ // If the OpAsmOpInterface didn't set a name, and when
+ // all results have OpAsmTypeInterface, get names from the type
+ bool allHaveOpAsmTypeInterface =
+ llvm::all_of(op.getResults(), [&](Value result) {
+ return mlir::isa<OpAsmTypeInterface>(result.getType());
+ });
+ if (allHaveOpAsmTypeInterface) {
+ for (auto result : op.getResults()) {
+ auto typeInterface = mlir::cast<OpAsmTypeInterface>(result.getType());
+ auto setNameFn = [&](StringRef name) {
+ setResultNameFn(result, name);
+ };
+ typeInterface.getAsmName(setNameFn);
+ }
+ }
+ }
}
unsigned numResults = op.getNumResults();
diff --git a/mlir/test/IR/op-asm-interface.mlir b/mlir/test/IR/op-asm-interface.mlir
index a9c199e3dc9736..fe73750ba0edf5 100644
--- a/mlir/test/IR/op-asm-interface.mlir
+++ b/mlir/test/IR/op-asm-interface.mlir
@@ -22,3 +22,39 @@ func.func @block_argument_name_from_op_asm_type_interface() {
}
return
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test OpAsmTypeInterface
+//===----------------------------------------------------------------------===//
+
+func.func @result_name_from_op_asm_type_interface_asmprinter() {
+ // CHECK-LABEL: @result_name_from_op_asm_type_interface_asmprinter
+ // CHECK: %op_asm_type_interface
+ %0 = "test.result_name_from_type_interface"() : () -> !test.op_asm_type_interface
+ return
+}
+
+// -----
+
+// i1 does not have OpAsmTypeInterface, should not get named.
+func.func @result_name_from_op_asm_type_interface_not_all() {
+ // CHECK-LABEL: @result_name_from_op_asm_type_interface_not_all
+ // CHECK-NOT: %op_asm_type_interface
+ // CHECK: %0:2
+ %0:2 = "test.result_name_from_type_interface"() : () -> (!test.op_asm_type_interface, i1)
+ return
+}
+
+// -----
+
+func.func @block_argument_name_from_op_asm_type_interface_asmprinter() {
+ // CHECK-LABEL: @block_argument_name_from_op_asm_type_interface_asmprinter
+ // CHECK: ^bb0(%op_asm_type_interface
+ test.block_argument_name_from_type_interface {
+ ^bb0(%arg0: !test.op_asm_type_interface):
+ "test.terminator"() : ()->()
+ }
+ return
+}
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index f37573c1351cec..c22363b14b1867 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -939,6 +939,25 @@ def BlockArgumentNameFromTypeOp
let assemblyFormat = "regions attr-dict-with-keyword";
}
+// This is used to test OpAsmTypeInterface::getAsmName's integration with AsmPrinter
+// for op result name when OpAsmOpInterface::getAsmResultNames is the default implementation
+// i.e. does nothing
+def ResultNameFromTypeInterfaceOp
+ : TEST_Op<"result_name_from_type_interface",
+ [OpAsmOpInterface]> {
+ let results = (outs Variadic<AnyType>:$r);
+}
+
+// This is used to test OpAsmTypeInterface::getAsmName's integration with AsmPrinter
+// for block argument name when OpAsmOpInterface::getAsmBlockArgumentNames is the default implementation
+// i.e. does nothing
+def BlockArgumentNameFromTypeInterfaceOp
+ : TEST_Op<"block_argument_name_from_type_interface",
+ [OpAsmOpInterface]> {
+ let regions = (region AnyRegion:$body);
+ let assemblyFormat = "regions attr-dict-with-keyword";
+}
+
// This is used to test the OpAsmOpInterface::getDefaultDialect() feature:
// operations nested in a region under this op will drop the "test." dialect
// prefix.
|
4206f58
to
0b82313
Compare
Ping for review |
mlir/lib/IR/AsmPrinter.cpp
Outdated
bool allHaveOpAsmTypeInterface = | ||
llvm::all_of(op.getResultTypes(), [&](Type type) { | ||
return mlir::isa<OpAsmTypeInterface>(type); | ||
}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why check if they all implement the interface?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For an op with multiple results, we could not suggest a meaningful result "group" name.
For example, in the following assembly the second result should not bear the name of type_name
. It is the responsibility of OpAsmOpInterface
to handle the grouping behavior.
%type_name:2 = test.some_op %arg : i32 -> (!test.type, i32)
We should rather suggest the following way to clearly separate them.
%type_name, %0 = test.some_op %arg : i32 -> (!test.type, i32)
But the default grouping behavior is not this way, so I only apply such separation when all of the results have OpAsmTypeInterface
.
0b82313
to
d6c9bbf
Compare
Comments addressed. |
Ping for review |
Ping for review |
See https://discourse.llvm.org/t/rfc-introduce-opasm-type-attr-interface-for-pretty-print-in-asmprinter/83792 for detailed introduction. This is a follow up PR of llvm#121187, by integrating OpAsmTypeInterface with AsmPrinter. There are a few conditions when OpAsmTypeInterface comes into play * There is no OpAsmOpInterface * Or OpAsmOpInterface::getAsmResultName/getBlockArgumentName does not invoke `setName` (i.e. the default impl) * All results have OpAsmTypeInterface (otherwise we can not handle result grouping behavior) Cc @River707 @jpienaar @ftynse for review.
See https://discourse.llvm.org/t/rfc-introduce-opasm-type-attr-interface-for-pretty-print-in-asmprinter/83792 for detailed introduction.
This is a follow up PR of #121187, by integrating OpAsmTypeInterface with AsmPrinter. There are a few conditions when OpAsmTypeInterface comes into play
setName
(i.e. the default impl)Cc @River707 @jpienaar @ftynse for review.