Skip to content

Commit 055872a

Browse files
[mlir] Integrate OpAsmTypeInterface with AsmPrinter (#124700)
See https://discourse.llvm.org/t/rfc-introduce-opasm-type-attr-interface-for-pretty-print-in-asmprinter/83792 for detailed introduction. This is a follow up PR of #121187, by integrating OpAsmTypeInterface with AsmPrinter. There are a few conditions when OpAsmTypeInterface comes into play * There is no OpAsmOpInterface * Or OpAsmOpInterface::getAsmResultName/getBlockArgumentName does not invoke `setName` (i.e. the default impl) * All results have OpAsmTypeInterface (otherwise we can not handle result grouping behavior) Cc @River707 @jpienaar @ftynse for review.
1 parent 49453bf commit 055872a

File tree

3 files changed

+85
-0
lines changed

3 files changed

+85
-0
lines changed

mlir/lib/IR/AsmPrinter.cpp

+30
Original file line numberDiff line numberDiff line change
@@ -1536,10 +1536,13 @@ StringRef maybeGetValueNameFromLoc(Value value, StringRef name) {
15361536
} // namespace
15371537

15381538
void SSANameState::numberValuesInRegion(Region &region) {
1539+
// Indicates whether OpAsmOpInterface set a name.
1540+
bool opAsmOpInterfaceUsed = false;
15391541
auto setBlockArgNameFn = [&](Value arg, StringRef name) {
15401542
assert(!valueIDs.count(arg) && "arg numbered multiple times");
15411543
assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == &region &&
15421544
"arg not defined in current region");
1545+
opAsmOpInterfaceUsed = true;
15431546
if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
15441547
name = maybeGetValueNameFromLoc(arg, name);
15451548
setValueName(arg, name);
@@ -1549,6 +1552,15 @@ void SSANameState::numberValuesInRegion(Region &region) {
15491552
if (Operation *op = region.getParentOp()) {
15501553
if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
15511554
asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
1555+
// If the OpAsmOpInterface didn't set a name, get name from the type.
1556+
if (!opAsmOpInterfaceUsed) {
1557+
for (BlockArgument arg : region.getArguments()) {
1558+
if (auto interface = dyn_cast<OpAsmTypeInterface>(arg.getType())) {
1559+
interface.getAsmName(
1560+
[&](StringRef name) { setBlockArgNameFn(arg, name); });
1561+
}
1562+
}
1563+
}
15521564
}
15531565
}
15541566

@@ -1598,9 +1610,12 @@ void SSANameState::numberValuesInBlock(Block &block) {
15981610
void SSANameState::numberValuesInOp(Operation &op) {
15991611
// Function used to set the special result names for the operation.
16001612
SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
1613+
// Indicates whether OpAsmOpInterface set a name.
1614+
bool opAsmOpInterfaceUsed = false;
16011615
auto setResultNameFn = [&](Value result, StringRef name) {
16021616
assert(!valueIDs.count(result) && "result numbered multiple times");
16031617
assert(result.getDefiningOp() == &op && "result not defined by 'op'");
1618+
opAsmOpInterfaceUsed = true;
16041619
if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
16051620
name = maybeGetValueNameFromLoc(result, name);
16061621
setValueName(result, name);
@@ -1630,6 +1645,21 @@ void SSANameState::numberValuesInOp(Operation &op) {
16301645
asmInterface.getAsmBlockNames(setBlockNameFn);
16311646
asmInterface.getAsmResultNames(setResultNameFn);
16321647
}
1648+
if (!opAsmOpInterfaceUsed) {
1649+
// If the OpAsmOpInterface didn't set a name, and all results have
1650+
// OpAsmTypeInterface, get names from types.
1651+
bool allHaveOpAsmTypeInterface =
1652+
llvm::all_of(op.getResultTypes(), [&](Type type) {
1653+
return isa<OpAsmTypeInterface>(type);
1654+
});
1655+
if (allHaveOpAsmTypeInterface) {
1656+
for (OpResult result : op.getResults()) {
1657+
auto interface = cast<OpAsmTypeInterface>(result.getType());
1658+
interface.getAsmName(
1659+
[&](StringRef name) { setResultNameFn(result, name); });
1660+
}
1661+
}
1662+
}
16331663
}
16341664

16351665
unsigned numResults = op.getNumResults();

mlir/test/IR/op-asm-interface.mlir

+36
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,39 @@ func.func @block_argument_name_from_op_asm_type_interface() {
2222
}
2323
return
2424
}
25+
26+
// -----
27+
28+
//===----------------------------------------------------------------------===//
29+
// Test OpAsmTypeInterface
30+
//===----------------------------------------------------------------------===//
31+
32+
func.func @result_name_from_op_asm_type_interface_asmprinter() {
33+
// CHECK-LABEL: @result_name_from_op_asm_type_interface_asmprinter
34+
// CHECK: %op_asm_type_interface
35+
%0 = "test.result_name_from_type_interface"() : () -> !test.op_asm_type_interface
36+
return
37+
}
38+
39+
// -----
40+
41+
// i1 does not have OpAsmTypeInterface, should not get named.
42+
func.func @result_name_from_op_asm_type_interface_not_all() {
43+
// CHECK-LABEL: @result_name_from_op_asm_type_interface_not_all
44+
// CHECK-NOT: %op_asm_type_interface
45+
// CHECK: %0:2
46+
%0:2 = "test.result_name_from_type_interface"() : () -> (!test.op_asm_type_interface, i1)
47+
return
48+
}
49+
50+
// -----
51+
52+
func.func @block_argument_name_from_op_asm_type_interface_asmprinter() {
53+
// CHECK-LABEL: @block_argument_name_from_op_asm_type_interface_asmprinter
54+
// CHECK: ^bb0(%op_asm_type_interface
55+
test.block_argument_name_from_type_interface {
56+
^bb0(%arg0: !test.op_asm_type_interface):
57+
"test.terminator"() : ()->()
58+
}
59+
return
60+
}

mlir/test/lib/Dialect/Test/TestOps.td

+19
Original file line numberDiff line numberDiff line change
@@ -955,6 +955,25 @@ def BlockArgumentNameFromTypeOp
955955
let assemblyFormat = "regions attr-dict-with-keyword";
956956
}
957957

958+
// This is used to test OpAsmTypeInterface::getAsmName's integration with AsmPrinter
959+
// for op result name when OpAsmOpInterface::getAsmResultNames is the default implementation
960+
// i.e. does nothing.
961+
def ResultNameFromTypeInterfaceOp
962+
: TEST_Op<"result_name_from_type_interface",
963+
[OpAsmOpInterface]> {
964+
let results = (outs Variadic<AnyType>:$r);
965+
}
966+
967+
// This is used to test OpAsmTypeInterface::getAsmName's integration with AsmPrinter
968+
// for block argument name when OpAsmOpInterface::getAsmBlockArgumentNames is the default implementation
969+
// i.e. does nothing.
970+
def BlockArgumentNameFromTypeInterfaceOp
971+
: TEST_Op<"block_argument_name_from_type_interface",
972+
[OpAsmOpInterface]> {
973+
let regions = (region AnyRegion:$body);
974+
let assemblyFormat = "regions attr-dict-with-keyword";
975+
}
976+
958977
// This is used to test the OpAsmOpInterface::getDefaultDialect() feature:
959978
// operations nested in a region under this op will drop the "test." dialect
960979
// prefix.

0 commit comments

Comments
 (0)