-
Notifications
You must be signed in to change notification settings - Fork 13.3k
[mlir] Add the concept of ASM dialect aliases #86033
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter. |
a00f2cb
to
e32dad6
Compare
@llvm/pr-subscribers-mlir-ods @llvm/pr-subscribers-mlir-core Author: Fabian Mora (fabianmcg) ChangesASM dialect aliases provide a mechanism to provide arbitrary aliases to types and attributes in pretty form. To use these aliases, users must complete several steps. For printing an alias, they need to:
For parsing an alias, the steps are:
Users also must attach the interface An example of this mechanism was added to the tests, specifically: This change is needed to alias "!llvm.ptr" with "ptr". Full diff: https://github.com/llvm/llvm-project/pull/86033.diff 5 Files Affected:
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 5333d7446df5ca..5c3d93f8e0cff6 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -1715,7 +1715,10 @@ class OpAsmDialectInterface
OverridableAlias,
/// An alias was provided and it should be used
/// (no other hooks will be checked).
- FinalAlias
+ FinalAlias,
+ /// A dialect alias was provided and it will be used
+ /// (no other hooks will be checked).
+ DialectAlias
};
/// Hooks for getting an alias identifier alias for a given symbol, that is
@@ -1729,6 +1732,29 @@ class OpAsmDialectInterface
return AliasResult::NoAlias;
}
+ /// Hooks for parsing a dialect alias. The method returns success if the
+ /// dialect has an alias for the symbol, otherwise it must return failure.
+ /// If there was an error during parsing, this method should return success
+ /// and set the attribute to null.
+ virtual LogicalResult parseDialectAlias(DialectAsmParser &parser,
+ Attribute &attr, Type type) const {
+ return failure();
+ }
+ virtual LogicalResult parseDialectAlias(DialectAsmParser &parser,
+ Type &type) const {
+ return failure();
+ }
+ /// Hooks for printing a dialect alias.
+ virtual void printDialectAlias(DialectAsmPrinter &printer,
+ Attribute attr) const {
+ llvm_unreachable("Dialect must implement `printDialectAlias` when defining "
+ "dialect aliases");
+ }
+ virtual void printDialectAlias(DialectAsmPrinter &printer, Type type) const {
+ llvm_unreachable("Dialect must implement `printDialectAlias` when defining "
+ "dialect aliases");
+ }
+
//===--------------------------------------------------------------------===//
// Resources
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp
index 80cce7e6ae43f5..9261ef2fb3eb95 100644
--- a/mlir/lib/AsmParser/DialectSymbolParser.cpp
+++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp
@@ -269,6 +269,12 @@ Attribute Parser::parseExtendedAttr(Type type) {
// Parse the attribute.
CustomDialectAsmParser customParser(symbolData, *this);
+ if (auto iface = dyn_cast<OpAsmDialectInterface>(dialect)) {
+ Attribute attr{};
+ if (succeeded(iface->parseDialectAlias(customParser, attr, type)))
+ return attr;
+ resetToken(symbolData.data());
+ }
Attribute attr = dialect->parseAttribute(customParser, attrType);
resetToken(curLexerPos);
return attr;
@@ -310,6 +316,12 @@ Type Parser::parseExtendedType() {
// Parse the type.
CustomDialectAsmParser customParser(symbolData, *this);
+ if (auto iface = dyn_cast<OpAsmDialectInterface>(dialect)) {
+ Type type{};
+ if (succeeded(iface->parseDialectAlias(customParser, type)))
+ return type;
+ resetToken(symbolData.data());
+ }
Type type = dialect->parseType(customParser);
resetToken(curLexerPos);
return type;
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 456cf6a2c27783..7aabc360517ac0 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -542,7 +542,9 @@ class AliasInitializer {
aliasOS(aliasBuffer) {}
void initialize(Operation *op, const OpPrintingFlags &printerFlags,
- llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias);
+ llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias,
+ llvm::DenseMap<const void *, const OpAsmDialectInterface *>
+ &attrTypeToDialectAlias);
/// Visit the given attribute to see if it has an alias. `canBeDeferred` is
/// set to true if the originator of this attribute can resolve the alias
@@ -570,6 +572,10 @@ class AliasInitializer {
InProgressAliasInfo(StringRef alias, bool isType, bool canBeDeferred)
: alias(alias), aliasDepth(1), isType(isType),
canBeDeferred(canBeDeferred) {}
+ InProgressAliasInfo(const OpAsmDialectInterface *aliasDialect, bool isType,
+ bool canBeDeferred)
+ : alias(std::nullopt), aliasDepth(1), isType(isType),
+ canBeDeferred(canBeDeferred), aliasDialect(aliasDialect) {}
bool operator<(const InProgressAliasInfo &rhs) const {
// Order first by depth, then by attr/type kind, and then by name.
@@ -577,6 +583,8 @@ class AliasInitializer {
return aliasDepth < rhs.aliasDepth;
if (isType != rhs.isType)
return isType;
+ if (aliasDialect != rhs.aliasDialect)
+ return aliasDialect < rhs.aliasDialect;
return alias < rhs.alias;
}
@@ -592,6 +600,8 @@ class AliasInitializer {
bool canBeDeferred : 1;
/// Indices for child aliases.
SmallVector<size_t> childIndices;
+ /// Dialect interface used to print the alias.
+ const OpAsmDialectInterface *aliasDialect{};
};
/// Visit the given attribute or type to see if it has an alias.
@@ -617,7 +627,9 @@ class AliasInitializer {
/// symbol to a given alias.
static void initializeAliases(
llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols,
- llvm::MapVector<const void *, SymbolAlias> &symbolToAlias);
+ llvm::MapVector<const void *, SymbolAlias> &symbolToAlias,
+ llvm::DenseMap<const void *, const OpAsmDialectInterface *>
+ &attrTypeToDialectAlias);
/// The set of asm interfaces within the context.
DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;
@@ -1027,7 +1039,9 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
/// symbol to a given alias.
void AliasInitializer::initializeAliases(
llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols,
- llvm::MapVector<const void *, SymbolAlias> &symbolToAlias) {
+ llvm::MapVector<const void *, SymbolAlias> &symbolToAlias,
+ llvm::DenseMap<const void *, const OpAsmDialectInterface *>
+ &attrTypeToDialectAlias) {
SmallVector<std::pair<const void *, InProgressAliasInfo>, 0>
unprocessedAliases = visitedSymbols.takeVector();
llvm::stable_sort(unprocessedAliases, [](const auto &lhs, const auto &rhs) {
@@ -1036,8 +1050,12 @@ void AliasInitializer::initializeAliases(
llvm::StringMap<unsigned> nameCounts;
for (auto &[symbol, aliasInfo] : unprocessedAliases) {
- if (!aliasInfo.alias)
+ if (!aliasInfo.alias && !aliasInfo.aliasDialect)
continue;
+ if (aliasInfo.aliasDialect) {
+ attrTypeToDialectAlias.insert({symbol, aliasInfo.aliasDialect});
+ continue;
+ }
StringRef alias = *aliasInfo.alias;
unsigned nameIndex = nameCounts[alias]++;
symbolToAlias.insert(
@@ -1048,7 +1066,9 @@ void AliasInitializer::initializeAliases(
void AliasInitializer::initialize(
Operation *op, const OpPrintingFlags &printerFlags,
- llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias) {
+ llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias,
+ llvm::DenseMap<const void *, const OpAsmDialectInterface *>
+ &attrTypeToDialectAlias) {
// Use a dummy printer when walking the IR so that we can collect the
// attributes/types that will actually be used during printing when
// considering aliases.
@@ -1056,7 +1076,7 @@ void AliasInitializer::initialize(
aliasPrinter.printCustomOrGenericOp(op);
// Initialize the aliases.
- initializeAliases(aliases, attrTypeToAlias);
+ initializeAliases(aliases, attrTypeToAlias, attrTypeToDialectAlias);
}
template <typename T, typename... PrintArgs>
@@ -1113,9 +1133,14 @@ template <typename T>
void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
bool canBeDeferred) {
SmallString<32> nameBuffer;
+ const OpAsmDialectInterface *dialectAlias = nullptr;
for (const auto &interface : interfaces) {
OpAsmDialectInterface::AliasResult result =
interface.getAlias(symbol, aliasOS);
+ if (result == OpAsmDialectInterface::AliasResult::DialectAlias) {
+ dialectAlias = &interface;
+ break;
+ }
if (result == OpAsmDialectInterface::AliasResult::NoAlias)
continue;
nameBuffer = std::move(aliasBuffer);
@@ -1123,6 +1148,11 @@ void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
if (result == OpAsmDialectInterface::AliasResult::FinalAlias)
break;
}
+ if (dialectAlias) {
+ alias = InProgressAliasInfo(
+ dialectAlias, /*isType=*/std::is_base_of_v<Type, T>, canBeDeferred);
+ return;
+ }
if (nameBuffer.empty())
return;
@@ -1157,6 +1187,13 @@ class AliasState {
/// Returns success if an alias was printed, failure otherwise.
LogicalResult getAlias(Type ty, raw_ostream &os) const;
+ /// Get a dialect alias for the given attribute if it has one or return
+ /// nullptr.
+ const OpAsmDialectInterface *getDialectAlias(Attribute attr) const;
+
+ /// Get a dialect alias for the given type if it has one or return nullptr.
+ const OpAsmDialectInterface *getDialectAlias(Type ty) const;
+
/// Print all of the referenced aliases that can not be resolved in a deferred
/// manner.
void printNonDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) {
@@ -1177,6 +1214,10 @@ class AliasState {
/// Mapping between attribute/type and alias.
llvm::MapVector<const void *, SymbolAlias> attrTypeToAlias;
+ /// Mapping between attribute/type and alias dialect interfaces.
+ llvm::DenseMap<const void *, const OpAsmDialectInterface *>
+ attrTypeToDialectAlias;
+
/// An allocator used for alias names.
llvm::BumpPtrAllocator aliasAllocator;
};
@@ -1186,7 +1227,8 @@ void AliasState::initialize(
Operation *op, const OpPrintingFlags &printerFlags,
DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
AliasInitializer initializer(interfaces, aliasAllocator);
- initializer.initialize(op, printerFlags, attrTypeToAlias);
+ initializer.initialize(op, printerFlags, attrTypeToAlias,
+ attrTypeToDialectAlias);
}
LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const {
@@ -1206,6 +1248,20 @@ LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
return success();
}
+const OpAsmDialectInterface *AliasState::getDialectAlias(Attribute attr) const {
+ auto it = attrTypeToDialectAlias.find(attr.getAsOpaquePointer());
+ if (it == attrTypeToDialectAlias.end())
+ return nullptr;
+ return it->second;
+}
+
+const OpAsmDialectInterface *AliasState::getDialectAlias(Type ty) const {
+ auto it = attrTypeToDialectAlias.find(ty.getAsOpaquePointer());
+ if (it == attrTypeToDialectAlias.end())
+ return nullptr;
+ return it->second;
+}
+
void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
bool isDeferred) {
auto filterFn = [=](const auto &aliasIt) {
@@ -2189,11 +2245,41 @@ static void printElidedElementsAttr(raw_ostream &os) {
}
LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
- return state.getAliasState().getAlias(attr, os);
+ if (succeeded(state.getAliasState().getAlias(attr, os)))
+ return success();
+ const OpAsmDialectInterface *iface =
+ state.getAliasState().getDialectAlias(attr);
+ if (!iface)
+ return failure();
+ // Ask the dialect to serialize the attribute to a string.
+ std::string attrName;
+ {
+ llvm::raw_string_ostream attrNameStr(attrName);
+ Impl subPrinter(attrNameStr, state);
+ DialectAsmPrinter printer(subPrinter);
+ iface->printDialectAlias(printer, attr);
+ }
+ printDialectSymbol(os, "#", iface->getDialect()->getNamespace(), attrName);
+ return success();
}
LogicalResult AsmPrinter::Impl::printAlias(Type type) {
- return state.getAliasState().getAlias(type, os);
+ if (succeeded(state.getAliasState().getAlias(type, os)))
+ return success();
+ const OpAsmDialectInterface *iface =
+ state.getAliasState().getDialectAlias(type);
+ if (!iface)
+ return failure();
+ // Ask the dialect to serialize the type to a string.
+ std::string typeName;
+ {
+ llvm::raw_string_ostream typeNameStr(typeName);
+ Impl subPrinter(typeNameStr, state);
+ DialectAsmPrinter printer(subPrinter);
+ iface->printDialectAlias(printer, type);
+ }
+ printDialectSymbol(os, "!", iface->getDialect()->getNamespace(), typeName);
+ return success();
}
void AsmPrinter::Impl::printAttribute(Attribute attr,
diff --git a/mlir/test/IR/print-attr-type-aliases.mlir b/mlir/test/IR/print-attr-type-aliases.mlir
index 162eacd0022832..6065573072e638 100644
--- a/mlir/test/IR/print-attr-type-aliases.mlir
+++ b/mlir/test/IR/print-attr-type-aliases.mlir
@@ -21,6 +21,12 @@
// CHECK-DAG: !test_tuple = tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>
"test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>)
+// CHECK-DAG: #test.test_string<"hello">, #test.test_string<" world">
+"test.op"() {alias_test = ["test_dialect_alias:hello", "test_dialect_alias: world"]} : () -> ()
+
+// CHECK-DAG: #test.test_string<"hello">, #test.test_string<" world">
+"test.op"() {alias_test = [#test.test_string<"hello">, #test.test_string<" world">]} : () -> ()
+
// CHECK-DAG: #test_encoding = "alias_test:tensor_encoding"
// CHECK-DAG: tensor<32xf32, #test_encoding>
"test.op"() : () -> tensor<32xf32, "alias_test:tensor_encoding">
@@ -29,6 +35,15 @@
// CHECK-DAG: tensor<32x!test_ui8_>
"test.op"() : () -> tensor<32x!test.int<unsigned, 8>>
+// CHECK-DAG: !test.tensor_int3<!test_ui8_>
+"test.op"() : () -> tensor<3x!test.int<unsigned, 8>>
+
+// CHECK-DAG: !test.tensor_int3<!test.int<signed, 8>>
+"test.op"() : () -> !test.tensor_int3<!test.int<signed, 8>>
+
+// CHECK-DAG: tensor<3xi3>
+"test.op"() : () -> !test.tensor_int3<i3>
+
// CHECK-DAG: #loc = loc("nested")
// CHECK-DAG: #loc1 = loc("test.mlir":10:8)
// CHECK-DAG: #loc2 = loc(fused<#loc>[#loc1])
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 66578b246afab1..8ef3cb82fe159e 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -193,6 +193,11 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
StringRef("test_alias_conflict0_"))
.Case("alias_test:tensor_encoding", StringRef("test_encoding"))
.Default(std::nullopt);
+
+ // Create a dialect alias for strings starting with "test_dialect_alias:"
+ if (strAttr.getValue().starts_with("test_dialect_alias:"))
+ return AliasResult::DialectAlias;
+
if (!aliasName)
return AliasResult::NoAlias;
@@ -200,6 +205,33 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
return AliasResult::FinalAlias;
}
+ void printDialectAlias(DialectAsmPrinter &printer,
+ Attribute attr) const override {
+ if (StringAttr strAttr = dyn_cast<StringAttr>(attr)) {
+ // Drop "test_dialect_alias:" from the front of the string
+ StringRef value = strAttr.getValue();
+ value.consume_front("test_dialect_alias:");
+ printer << "test_string<\"" << value << "\">";
+ }
+ }
+
+ LogicalResult parseDialectAlias(DialectAsmParser &parser, Attribute &attr,
+ Type type) const override {
+ return AsmParser::KeywordSwitch<LogicalResult>(parser)
+ // Alias !test.test_string<"..."> to StringAttr
+ .Case("test_string",
+ [&](llvm::StringRef, llvm::SMLoc) {
+ std::string str;
+ if (parser.parseLess() || parser.parseString(&str) ||
+ parser.parseGreater())
+ return success();
+ attr = parser.getBuilder().getStringAttr("test_dialect_alias:" +
+ str);
+ return success();
+ })
+ .Default([&](StringRef keyword, SMLoc) { return failure(); });
+ }
+
AliasResult getAlias(Type type, raw_ostream &os) const final {
if (auto tupleType = dyn_cast<TupleType>(type)) {
if (tupleType.size() > 0 &&
@@ -229,9 +261,44 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
os << recAliasType.getName();
return AliasResult::FinalAlias;
}
+ // Create a dialect alias for tensor<3x!test.int<...>>
+ if (auto tensorTy = dyn_cast<TensorType>(type);
+ tensorTy && isa<TestIntegerType>(tensorTy.getElementType()) &&
+ tensorTy.hasRank()) {
+ ArrayRef<int64_t> shape = tensorTy.getShape();
+ if (shape.size() == 1 && shape[0] == 3)
+ return AliasResult::DialectAlias;
+ }
return AliasResult::NoAlias;
}
+ void printDialectAlias(DialectAsmPrinter &printer, Type type) const override {
+ if (auto tensorTy = dyn_cast<TensorType>(type);
+ tensorTy && isa<TestIntegerType>(tensorTy.getElementType()) &&
+ tensorTy.hasRank()) {
+ // Alias tensor<3x!test.int<...>> to !test.tensor_int3<!test.int<...>>
+ ArrayRef<int64_t> shape = tensorTy.getShape();
+ if (shape.size() == 1 && shape[0] == 3)
+ printer << "tensor_int3" << "<" << tensorTy.getElementType() << ">";
+ }
+ }
+
+ LogicalResult parseDialectAlias(DialectAsmParser &parser,
+ Type &type) const override {
+ return AsmParser::KeywordSwitch<LogicalResult>(parser)
+ // Alias !test.tensor_int3<IntType> to tensor<3xIntType>
+ .Case("tensor_int3",
+ [&](llvm::StringRef, llvm::SMLoc) {
+ if (parser.parseLess() || parser.parseType(type) ||
+ parser.parseGreater())
+ type = nullptr;
+ if (isa<TestIntegerType>(type) || isa<IntegerType>(type))
+ type = RankedTensorType::get({3}, type);
+ return success();
+ })
+ .Default([&](StringRef keyword, SMLoc) { return failure(); });
+ }
+
//===------------------------------------------------------------------===//
// Resources
//===------------------------------------------------------------------===//
|
Thanks! I'll try to get to this over the weekend. |
@@ -21,6 +21,12 @@ | |||
// CHECK-DAG: !test_tuple = tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla> | |||
"test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>) | |||
|
|||
// CHECK-DAG: #test.test_string<"hello">, #test.test_string<" world"> | |||
"test.op"() {alias_test = ["test_dialect_alias:hello", "test_dialect_alias: world"]} : () -> () |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this allowing an arbitrary string prefix that any dialect could set? E.g., dialect foo can choose "ba_dialect_alias:" and ba can choose "baz_dialect@"?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Technically, yes. However, what it truly allows is hijacking any attribute and type's representation with a custom printer and parsers by a dialect.
For example, test could decide
#test.ones<2, f32>
alias to [1.: f32, 1.: f32]
.
Edit: The internal representation of "test_dialect_alias:hello"
is still "test_dialect_alias:hello"
, #test.test_string<"hello">
is just pretty printing the string.
I am curious about the motivation. Have you considered other options such as the equivalent of |
This is a spin-off from the pointer dialect proposal, see my post here: https://discourse.llvm.org/t/rfc-ptr-dialect-modularizing-ptr-ops-in-the-llvm-dialect/75142/10 |
There's a toned down version of this idea, where only attributes and types implementing the However, that version requires the attribute or type to internally hold which dialect to use during printing (ie. not all objects can be aliased, for example |
I see thank you! Is the following a correct summary? The point of introducing this is to make it possible that I'll have to think further about the implementation, but the motivation at first glance did not match up with the code, causing a bit of confusion. |
Yes. However, for parsing one doesn't need to return Here's a prototype of the toned down version: |
After making the change |
Can you share the benchmark? (push to a branch somewhere) |
The benchmark is just timing I'm observing on my side inconsistent timings with the change with no good explanation as to why it's happening: Sometimes there's no performance difference, sometimes the time of the change is considerably slower. |
|
TestI used python to create a synthetic test: for i in range(0, 10000):
print("func.func @func{0}(%A: memref<?xi32, {0}>, %i: index) {{".format(i))
for j in range(0, 20):
print(" %{0} = memref.load %A[%i] : memref<?xi32, {1}>".format(j, i))
print(" memref.store %{0}, %A[%i] : memref<?xi32, {1}>".format(j, i))
print(" return")
print("}") Which then I transformed to LLVM with: mlir-opt test.mlir --finalize-memref-to-llvm --convert-func-to-llvm --canonicalize -o llvm.mlir To get something that looks like: %0 = llvm.getelementptr %arg1[%arg5] : (!llvm.ptr, i64) -> !llvm.ptr, i32
%1 = ptr.load %0 : !llvm.ptr -> i32
%2 = llvm.getelementptr %arg1[%arg5] : (!llvm.ptr, i64) -> !llvm.ptr, i32
ptr.store %1, %2 : i32, !llvm.ptr and %0 = llvm.getelementptr %arg1[%arg5] : (!llvm.ptr, i64) -> !llvm.ptr, i32
%1 = llvm.load %0 : !llvm.ptr -> i32
%2 = llvm.getelementptr %arg1[%arg5] : (!llvm.ptr, i64) -> !llvm.ptr, i32
llvm.store %1, %2 : i32, !llvm.ptr for a total of mlir-opt llvm.mlir --mlir-timing -o llvm-other.mlir Before the change:
After the change:
ResultsParsing takes less time after change, without a good reason as to why. But printing does take 8% more after the change. Side noteWhile I was moving flang tests to use Ptr, I realized that this method has another flaw. Aliases are not always computed, for example, when emitting errors it prints the full version of the type. |
An alternative design would maybe be to make this more of a first class citizen and have dialects regsiter dialect aliases the same way dialects register types, attributes and operations (e.g. the LLVM dialect registers an alias for the pointer type in the pointer dialect). Having the dialect register would allow a future patch to add support for even generating these with ODS to make sure that e.g. documentation for |
This is what we want. Because Having said that, I like the idea of making them a first-class concept, however, that involves changing Also worth noting, the toned down version doesn't suffer from any of this issues, as the type or attr directly specifies which dialect to use during printing. |
The way I read @zero9178 post, there wouldn't be any change to any of the Core Storage or anything: it is still an alias. You just keep the ODS as a driver to declare the alias so you keep the documentation of the alias, and also auto-generate the c++ code for the aliasing dialect. |
From what I understood in @zero9178 comment the idea is to register them as aliases somewhere. But, I might be missing something. Also, |
ASM dialect aliases provide a mechanism to provide arbitrary aliases to types and attributes in pretty form. To use these aliases, users must complete several steps. For printing an alias, they need to: - Implement the method `OpAsmDialectInterface::getAlias` and return `AliasResult::DialectAlias` for the aliased types instances. - Implement the method `OpAsmDialectInterface::printDialectAlias`, printing the alias however the user sees fit. For parsing an alias, the steps are: - Implement `OpAsmDialectInterface::parseDialectAlias` for the aliased types. Users also must attach the interface `OpAsmDialectInterface` to the dialect creating the alias. An example of this mechanism was added to the tests, specifically: `"test_dialect_alias:..."` alias to `#test.test_string<...>` `tensor<3x!test.int<...>>` alias to `!test.tensor_int3<test.int<...>>` This change is needed to alias "!llvm.ptr" with "ptr".
e32dad6
to
2a52522
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're both correct. My concern is with the usability and the understandability of the IR and I think having it be supported in ODS to look as much as normal types and attributes with all the supporting infrastructure would be a big help.
That said, I don't want to demand ODS support in this PR but rather having it as a future goal.
a way to assign identifiers to types and attributes for syntax sugar
This is what we want. Because !llvm.ptr<3> is just syntax sugar for !ptr.ptr<#llvm.address_space<3>>. When interacting > with !llvm.ptr<3> one is still using !ptr.ptr methods.
The current aliasing infrastructure supported by getAlias
does not change the syntax of attributes and types, but just their uses so to say, and is fully transparent to any dialect. As you noticed when using dump
, if these concepts are conflated then it becomes more difficult to change the behaviour of either "dialect aliases" or "attribute/type aliases" without affecting the other. I don't really see a lot of overlap between the two besides both being syntax sugar and having "alias" in their name.
What do you think about having an API like Dialect::addDialectAlias<Ptr::PointerType, LLVM::LLVMPointerType>()
function, where LLVM::LLVMPointerType
optionally subclasses Ptr::PointerType
and implements classof
, print
and parse
?
The classof
is used so that isa
calls work and after registration can also be used by the printer to check whether a Ptr::PointerType
is actually a LLVM::LLVMPointerType
that should be printed as such. Whether and how much this changes core, the context etc I see more as an implementation detail at this point of the design stage and not a constraint (after all we all love changing core 🙂) and something that automatically requires larger discussions compared to the current stage. As mehdi said, the LLVM::LLVMPointerType
could then be ODS generated in the future. By requiring and using print
and parse
methods it'd also work with existing MLIR features such as unqualified printing, something that I don't think works in the current implementation?
If you think about it, the concept of a class acting as the dialect alias is not a new concept in MLIR core either as the builtin dialect does the same with BoolAttr
, FlatSymbolRefAttr
etc. Its only really the custom printing support that is missing. This could allow having addDialectAlias<IntegerAttr, BoolAttr>
with BoolAttr
being ODS generated in the future 🙂
If that doesn't sound good to you I'd already be happy with just not using getAlias
to implement this. I admittedly don't care too much about the performance of printing and parsing (as long as its not horrible), and mostly about the usability and understandability of the APIs and IR. Making concepts first class helps in this regard IMO.
With the current proposal it should be possible to support ODS, however, the core issue of aliases not always being printed would remain. After thinking about it a bit more, I think registering them is not going to help (I might be wrong).
If we want all three conditions, then, as types: an instance of Which is what I'm currently doing manually, see: LLVMTypes.cpp#L256-L294 Hence, what would registration do?
Thus, I think this proposal might be as good as it gets (I might be wrong). But again, it should be possible to generate ODS for the current proposal, for implementing a subclass with |
]; | ||
} | ||
|
||
//===----------------------------------------------------------------------===// |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also worth nothing, that for LLVMPointerType
this interface (TypeAsmAliasTypeInterface
) solves the issue of printing in all circumstances.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With the current proposal it should be possible to support ODS, however, the core issue of aliases not always being printed would remain.
After thinking about it a bit more, I think registering them is not going to help (I might be wrong). Let's agree in the following conditions:
If
type=!llvm.ptr
, then:
isa<LLVMPointerType>(type) == isa<PtrType>(type)
andTypeID::get<PtrType>() == type.getTypeID()
If
type!=!llvm.ptr
, then:
isa<LLVMPointerType>(type) != isa<PtrType>(type)
andTypeID::get<PtrType>() == type.getTypeID()
We also need that in all circumstances regardless of their origin:
!ptr.ptr<#llvm.address_space> = !llvm.ptr
This condition refers toLLVMPointerType::get(0) == PtrType::get(LLVMAddressSpace::get(0))
.If we want all three conditions, then, as types: an instance of
LLVMPointerType
must be indistinguishable from aPtrType
. The only thing that could tell that something is aLLVMPointerType
isLLVMPointerType
itself.Which is what I'm currently doing manually, see: LLVMTypes.cpp#L256-L294
That's great thank you!
Hence, what would registration do?
- It wouldn't affect printing because the type alone has no way to know it's a
LLVMPointerType
and testing to check if its, would be too time consuming. And using an extraTypeID
or extra information is not possible without violating the 3rd condition.
The addDialectAlias
would simply be an alternative way to later give the printer the information that it has to check whether a Ptr::PointerType
is an instance of LLVMPointerType
or not prior to deciding which dialect to use when printing. That said, I am actually pretty happy with your interface and like the fact that it is opt-in, even if that makes it less powerful. I can also see how its more performant.
Left some comments about the implementation about the current approach then. It looks to me like it could be simplified if only allowing a different dialect to be used for printing and parsing through the interface.
/// Hooks for parsing a dialect alias. The method returns success if the | ||
/// dialect has an alias for the symbol, otherwise it must return failure. | ||
/// If there was an error during parsing, this method should return success | ||
/// and set the attribute to null. | ||
virtual LogicalResult parseDialectAlias(DialectAsmParser &parser, | ||
Attribute &attr, Type type) const { | ||
return failure(); | ||
} | ||
virtual LogicalResult parseDialectAlias(DialectAsmParser &parser, | ||
Type &type) const { | ||
return failure(); | ||
} | ||
/// Hooks for printing a dialect alias. The method returns success if the | ||
/// dialect has an alias for the symbol, otherwise it must return failure. | ||
virtual LogicalResult printDialectAlias(DialectAsmPrinter &printer, | ||
Attribute attr) const { | ||
return failure(); | ||
} | ||
virtual LogicalResult printDialectAlias(DialectAsmPrinter &printer, | ||
Type type) const { | ||
return failure(); | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are these APIs required? When parsing and the parser sees a !llvm
, it will call dialect->parseType
right?
Similarily for printing, after the dialect that should be used for printing has been determined through the interface call, couldn't one just call dialect->printType
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. I introduced this methods just to make it easier to have aliases without having to implement the parseType
function in the dialect. However, if the plan is to have this backed this by ODS at some point, then there's no need for them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd personally prefer not having these unless required. We can always add these once there is a use-case. I don't think they add much clarity.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the plan is eventually doing ODS, then I agree these should be removed.
/// A dialect alias was provided and it will be used | ||
/// (no other hooks will be checked). | ||
DialectAlias |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this still required? Can't the printer just check whether the type/attribute implements the interface, if it does call getAliasDialect
, and then call its print*
method?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DialectAlias
is different from the TypeAsmAliasTypeInterface
.
- With
DialectAlias
you can alias any type or attribute, however, the type might not always be printed using the alias. TypeAsmAliasTypeInterface
, you can only alias objects that allow aliases. So it has more constraints in what it can do.
For example, with the first mechanism you can alias IntegerType::get(2^exp)
to p2_int<exp>
, and keep i3
as i3
, however the second mechanism allows aliasing the type to only one dialect (unless the type implements a specific alias rule).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there actually a use case for this and do you need that powerful of a mechanism for the pointer dialect? In the interest of landing this PR with the interface mechanism I'd just drop that mechanism and implement it in another PR of really needed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not really. I was just trying to solve the problem in general, and I'm open drop it. However, I think there's no rush with landing this, as Euro LLVM will happen soon, hence all the other Ptr patches are not likely to land soon.
I'll implement the DenseMap approach to talk with real numbers. If you're interested, we can also talk more about it in the meeting.
The problem here is that we would need something like:
In |
The latter version is how I'd roughly had thought that's how it'd be done. I doubt it'd have killed performance or even been a lot different in that regard to your current version, given that the number of registered aliases would be very low in a program (so not a lot of elements in the SmallVector. E.g. just one for |
You're correct the lookup itself wouldn't kill performance, as it is equivalent to Maybe, we agree that that's the price to use aliases, and that in general users shouldn't use many. However, that's a design decision we need to make. |
It certainly would be the price if someone were to register thousands of aliases. But we're getting side tracked here, I think the interface variant is fine and if you prefer it then let's go with that. |
ASM dialect aliases provide a mechanism to provide arbitrary aliases to types and attributes in pretty form.
To use these aliases, users must complete several steps. For printing an alias, they need to:
OpAsmDialectInterface::getAlias
and returnAliasResult::DialectAlias
for the aliased types instances.OpAsmDialectInterface::printDialectAlias
, printing the alias however the user sees fit.For parsing an alias, the steps are:
OpAsmDialectInterface::parseDialectAlias
for the aliased types.Users also must attach the interface
OpAsmDialectInterface
to the dialect creating the alias.An example of this mechanism was added to the tests, specifically:
"test_dialect_alias:..."
alias to#test.test_string<...>
tensor<3x!test.int<...>>
alias to!test.tensor_int3<test.int<...>>
This change is needed to alias "!llvm.ptr" with "ptr".