Skip to content

[mlir] Add the concept of ASM dialect aliases #86033

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

fabianmcg
Copy link
Contributor

@fabianmcg fabianmcg commented Mar 20, 2024

ASM dialect aliases provide a mechanism to provide arbitrary aliases to types and attributes in pretty form.

To use these aliases, users must complete several steps. For printing an alias, they need to:

  • Implement the method OpAsmDialectInterface::getAlias and return AliasResult::DialectAlias for the aliased types instances.
  • Implement the method OpAsmDialectInterface::printDialectAlias, printing the alias however the user sees fit.

For parsing an alias, the steps are:

  • Implement OpAsmDialectInterface::parseDialectAlias for the aliased types.

Users also must attach the interface OpAsmDialectInterface to the dialect creating the alias.

An example of this mechanism was added to the tests, specifically:

  • "test_dialect_alias:..." alias to #test.test_string<...>
  • tensor<3x!test.int<...>> alias to !test.tensor_int3<test.int<...>>

This change is needed to alias "!llvm.ptr" with "ptr".

Copy link

github-actions bot commented Mar 20, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@fabianmcg fabianmcg marked this pull request as ready for review March 20, 2024 23:52
@fabianmcg fabianmcg requested review from ftynse and zero9178 March 20, 2024 23:53
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Mar 20, 2024
@fabianmcg fabianmcg requested a review from joker-eph March 20, 2024 23:53
@llvmbot
Copy link
Member

llvmbot commented Mar 20, 2024

@llvm/pr-subscribers-mlir-ods
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Fabian Mora (fabianmcg)

Changes

ASM dialect aliases provide a mechanism to provide arbitrary aliases to types and attributes in pretty form.

To use these aliases, users must complete several steps. For printing an alias, they need to:

  • Implement the method OpAsmDialectInterface::getAlias and return AliasResult::DialectAlias for the aliased types instances.
  • Implement the method OpAsmDialectInterface::printDialectAlias, printing the alias however the user sees fit.

For parsing an alias, the steps are:

  • Implement OpAsmDialectInterface::parseDialectAlias for the aliased types.

Users also must attach the interface OpAsmDialectInterface to the dialect creating the alias.

An example of this mechanism was added to the tests, specifically: "test_dialect_alias:..." alias to #test.test_string&lt;...&gt; tensor&lt;3x!test.int&lt;...&gt;&gt; alias to !test.tensor_int3&lt;test.int&lt;...&gt;&gt;

This change is needed to alias "!llvm.ptr" with "ptr".


Full diff: https://github.com/llvm/llvm-project/pull/86033.diff

5 Files Affected:

  • (modified) mlir/include/mlir/IR/OpImplementation.h (+27-1)
  • (modified) mlir/lib/AsmParser/DialectSymbolParser.cpp (+12)
  • (modified) mlir/lib/IR/AsmPrinter.cpp (+95-9)
  • (modified) mlir/test/IR/print-attr-type-aliases.mlir (+15)
  • (modified) mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp (+67)
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 5333d7446df5ca..5c3d93f8e0cff6 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -1715,7 +1715,10 @@ class OpAsmDialectInterface
     OverridableAlias,
     /// An alias was provided and it should be used
     /// (no other hooks will be checked).
-    FinalAlias
+    FinalAlias,
+    /// A dialect alias was provided and it will be used
+    /// (no other hooks will be checked).
+    DialectAlias
   };
 
   /// Hooks for getting an alias identifier alias for a given symbol, that is
@@ -1729,6 +1732,29 @@ class OpAsmDialectInterface
     return AliasResult::NoAlias;
   }
 
+  /// Hooks for parsing a dialect alias. The method returns success if the
+  /// dialect has an alias for the symbol, otherwise it must return failure.
+  /// If there was an error during parsing, this method should return success
+  /// and set the attribute to null.
+  virtual LogicalResult parseDialectAlias(DialectAsmParser &parser,
+                                          Attribute &attr, Type type) const {
+    return failure();
+  }
+  virtual LogicalResult parseDialectAlias(DialectAsmParser &parser,
+                                          Type &type) const {
+    return failure();
+  }
+  /// Hooks for printing a dialect alias.
+  virtual void printDialectAlias(DialectAsmPrinter &printer,
+                                 Attribute attr) const {
+    llvm_unreachable("Dialect must implement `printDialectAlias` when defining "
+                     "dialect aliases");
+  }
+  virtual void printDialectAlias(DialectAsmPrinter &printer, Type type) const {
+    llvm_unreachable("Dialect must implement `printDialectAlias` when defining "
+                     "dialect aliases");
+  }
+
   //===--------------------------------------------------------------------===//
   // Resources
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp
index 80cce7e6ae43f5..9261ef2fb3eb95 100644
--- a/mlir/lib/AsmParser/DialectSymbolParser.cpp
+++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp
@@ -269,6 +269,12 @@ Attribute Parser::parseExtendedAttr(Type type) {
 
           // Parse the attribute.
           CustomDialectAsmParser customParser(symbolData, *this);
+          if (auto iface = dyn_cast<OpAsmDialectInterface>(dialect)) {
+            Attribute attr{};
+            if (succeeded(iface->parseDialectAlias(customParser, attr, type)))
+              return attr;
+            resetToken(symbolData.data());
+          }
           Attribute attr = dialect->parseAttribute(customParser, attrType);
           resetToken(curLexerPos);
           return attr;
@@ -310,6 +316,12 @@ Type Parser::parseExtendedType() {
 
           // Parse the type.
           CustomDialectAsmParser customParser(symbolData, *this);
+          if (auto iface = dyn_cast<OpAsmDialectInterface>(dialect)) {
+            Type type{};
+            if (succeeded(iface->parseDialectAlias(customParser, type)))
+              return type;
+            resetToken(symbolData.data());
+          }
           Type type = dialect->parseType(customParser);
           resetToken(curLexerPos);
           return type;
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 456cf6a2c27783..7aabc360517ac0 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -542,7 +542,9 @@ class AliasInitializer {
         aliasOS(aliasBuffer) {}
 
   void initialize(Operation *op, const OpPrintingFlags &printerFlags,
-                  llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias);
+                  llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias,
+                  llvm::DenseMap<const void *, const OpAsmDialectInterface *>
+                      &attrTypeToDialectAlias);
 
   /// Visit the given attribute to see if it has an alias. `canBeDeferred` is
   /// set to true if the originator of this attribute can resolve the alias
@@ -570,6 +572,10 @@ class AliasInitializer {
     InProgressAliasInfo(StringRef alias, bool isType, bool canBeDeferred)
         : alias(alias), aliasDepth(1), isType(isType),
           canBeDeferred(canBeDeferred) {}
+    InProgressAliasInfo(const OpAsmDialectInterface *aliasDialect, bool isType,
+                        bool canBeDeferred)
+        : alias(std::nullopt), aliasDepth(1), isType(isType),
+          canBeDeferred(canBeDeferred), aliasDialect(aliasDialect) {}
 
     bool operator<(const InProgressAliasInfo &rhs) const {
       // Order first by depth, then by attr/type kind, and then by name.
@@ -577,6 +583,8 @@ class AliasInitializer {
         return aliasDepth < rhs.aliasDepth;
       if (isType != rhs.isType)
         return isType;
+      if (aliasDialect != rhs.aliasDialect)
+        return aliasDialect < rhs.aliasDialect;
       return alias < rhs.alias;
     }
 
@@ -592,6 +600,8 @@ class AliasInitializer {
     bool canBeDeferred : 1;
     /// Indices for child aliases.
     SmallVector<size_t> childIndices;
+    /// Dialect interface used to print the alias.
+    const OpAsmDialectInterface *aliasDialect{};
   };
 
   /// Visit the given attribute or type to see if it has an alias.
@@ -617,7 +627,9 @@ class AliasInitializer {
   /// symbol to a given alias.
   static void initializeAliases(
       llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols,
-      llvm::MapVector<const void *, SymbolAlias> &symbolToAlias);
+      llvm::MapVector<const void *, SymbolAlias> &symbolToAlias,
+      llvm::DenseMap<const void *, const OpAsmDialectInterface *>
+          &attrTypeToDialectAlias);
 
   /// The set of asm interfaces within the context.
   DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;
@@ -1027,7 +1039,9 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
 /// symbol to a given alias.
 void AliasInitializer::initializeAliases(
     llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols,
-    llvm::MapVector<const void *, SymbolAlias> &symbolToAlias) {
+    llvm::MapVector<const void *, SymbolAlias> &symbolToAlias,
+    llvm::DenseMap<const void *, const OpAsmDialectInterface *>
+        &attrTypeToDialectAlias) {
   SmallVector<std::pair<const void *, InProgressAliasInfo>, 0>
       unprocessedAliases = visitedSymbols.takeVector();
   llvm::stable_sort(unprocessedAliases, [](const auto &lhs, const auto &rhs) {
@@ -1036,8 +1050,12 @@ void AliasInitializer::initializeAliases(
 
   llvm::StringMap<unsigned> nameCounts;
   for (auto &[symbol, aliasInfo] : unprocessedAliases) {
-    if (!aliasInfo.alias)
+    if (!aliasInfo.alias && !aliasInfo.aliasDialect)
       continue;
+    if (aliasInfo.aliasDialect) {
+      attrTypeToDialectAlias.insert({symbol, aliasInfo.aliasDialect});
+      continue;
+    }
     StringRef alias = *aliasInfo.alias;
     unsigned nameIndex = nameCounts[alias]++;
     symbolToAlias.insert(
@@ -1048,7 +1066,9 @@ void AliasInitializer::initializeAliases(
 
 void AliasInitializer::initialize(
     Operation *op, const OpPrintingFlags &printerFlags,
-    llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias) {
+    llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias,
+    llvm::DenseMap<const void *, const OpAsmDialectInterface *>
+        &attrTypeToDialectAlias) {
   // Use a dummy printer when walking the IR so that we can collect the
   // attributes/types that will actually be used during printing when
   // considering aliases.
@@ -1056,7 +1076,7 @@ void AliasInitializer::initialize(
   aliasPrinter.printCustomOrGenericOp(op);
 
   // Initialize the aliases.
-  initializeAliases(aliases, attrTypeToAlias);
+  initializeAliases(aliases, attrTypeToAlias, attrTypeToDialectAlias);
 }
 
 template <typename T, typename... PrintArgs>
@@ -1113,9 +1133,14 @@ template <typename T>
 void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
                                      bool canBeDeferred) {
   SmallString<32> nameBuffer;
+  const OpAsmDialectInterface *dialectAlias = nullptr;
   for (const auto &interface : interfaces) {
     OpAsmDialectInterface::AliasResult result =
         interface.getAlias(symbol, aliasOS);
+    if (result == OpAsmDialectInterface::AliasResult::DialectAlias) {
+      dialectAlias = &interface;
+      break;
+    }
     if (result == OpAsmDialectInterface::AliasResult::NoAlias)
       continue;
     nameBuffer = std::move(aliasBuffer);
@@ -1123,6 +1148,11 @@ void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
     if (result == OpAsmDialectInterface::AliasResult::FinalAlias)
       break;
   }
+  if (dialectAlias) {
+    alias = InProgressAliasInfo(
+        dialectAlias, /*isType=*/std::is_base_of_v<Type, T>, canBeDeferred);
+    return;
+  }
 
   if (nameBuffer.empty())
     return;
@@ -1157,6 +1187,13 @@ class AliasState {
   /// Returns success if an alias was printed, failure otherwise.
   LogicalResult getAlias(Type ty, raw_ostream &os) const;
 
+  /// Get a dialect alias for the given attribute if it has one or return
+  /// nullptr.
+  const OpAsmDialectInterface *getDialectAlias(Attribute attr) const;
+
+  /// Get a dialect alias for the given type if it has one or return nullptr.
+  const OpAsmDialectInterface *getDialectAlias(Type ty) const;
+
   /// Print all of the referenced aliases that can not be resolved in a deferred
   /// manner.
   void printNonDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) {
@@ -1177,6 +1214,10 @@ class AliasState {
   /// Mapping between attribute/type and alias.
   llvm::MapVector<const void *, SymbolAlias> attrTypeToAlias;
 
+  /// Mapping between attribute/type and alias dialect interfaces.
+  llvm::DenseMap<const void *, const OpAsmDialectInterface *>
+      attrTypeToDialectAlias;
+
   /// An allocator used for alias names.
   llvm::BumpPtrAllocator aliasAllocator;
 };
@@ -1186,7 +1227,8 @@ void AliasState::initialize(
     Operation *op, const OpPrintingFlags &printerFlags,
     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
   AliasInitializer initializer(interfaces, aliasAllocator);
-  initializer.initialize(op, printerFlags, attrTypeToAlias);
+  initializer.initialize(op, printerFlags, attrTypeToAlias,
+                         attrTypeToDialectAlias);
 }
 
 LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const {
@@ -1206,6 +1248,20 @@ LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
   return success();
 }
 
+const OpAsmDialectInterface *AliasState::getDialectAlias(Attribute attr) const {
+  auto it = attrTypeToDialectAlias.find(attr.getAsOpaquePointer());
+  if (it == attrTypeToDialectAlias.end())
+    return nullptr;
+  return it->second;
+}
+
+const OpAsmDialectInterface *AliasState::getDialectAlias(Type ty) const {
+  auto it = attrTypeToDialectAlias.find(ty.getAsOpaquePointer());
+  if (it == attrTypeToDialectAlias.end())
+    return nullptr;
+  return it->second;
+}
+
 void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
                               bool isDeferred) {
   auto filterFn = [=](const auto &aliasIt) {
@@ -2189,11 +2245,41 @@ static void printElidedElementsAttr(raw_ostream &os) {
 }
 
 LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
-  return state.getAliasState().getAlias(attr, os);
+  if (succeeded(state.getAliasState().getAlias(attr, os)))
+    return success();
+  const OpAsmDialectInterface *iface =
+      state.getAliasState().getDialectAlias(attr);
+  if (!iface)
+    return failure();
+  // Ask the dialect to serialize the attribute to a string.
+  std::string attrName;
+  {
+    llvm::raw_string_ostream attrNameStr(attrName);
+    Impl subPrinter(attrNameStr, state);
+    DialectAsmPrinter printer(subPrinter);
+    iface->printDialectAlias(printer, attr);
+  }
+  printDialectSymbol(os, "#", iface->getDialect()->getNamespace(), attrName);
+  return success();
 }
 
 LogicalResult AsmPrinter::Impl::printAlias(Type type) {
-  return state.getAliasState().getAlias(type, os);
+  if (succeeded(state.getAliasState().getAlias(type, os)))
+    return success();
+  const OpAsmDialectInterface *iface =
+      state.getAliasState().getDialectAlias(type);
+  if (!iface)
+    return failure();
+  // Ask the dialect to serialize the type to a string.
+  std::string typeName;
+  {
+    llvm::raw_string_ostream typeNameStr(typeName);
+    Impl subPrinter(typeNameStr, state);
+    DialectAsmPrinter printer(subPrinter);
+    iface->printDialectAlias(printer, type);
+  }
+  printDialectSymbol(os, "!", iface->getDialect()->getNamespace(), typeName);
+  return success();
 }
 
 void AsmPrinter::Impl::printAttribute(Attribute attr,
diff --git a/mlir/test/IR/print-attr-type-aliases.mlir b/mlir/test/IR/print-attr-type-aliases.mlir
index 162eacd0022832..6065573072e638 100644
--- a/mlir/test/IR/print-attr-type-aliases.mlir
+++ b/mlir/test/IR/print-attr-type-aliases.mlir
@@ -21,6 +21,12 @@
 // CHECK-DAG: !test_tuple = tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>
 "test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>)
 
+// CHECK-DAG: #test.test_string<"hello">, #test.test_string<" world">
+"test.op"() {alias_test = ["test_dialect_alias:hello", "test_dialect_alias: world"]} : () -> ()
+
+// CHECK-DAG: #test.test_string<"hello">, #test.test_string<" world">
+"test.op"() {alias_test = [#test.test_string<"hello">, #test.test_string<" world">]} : () -> ()
+
 // CHECK-DAG: #test_encoding = "alias_test:tensor_encoding"
 // CHECK-DAG: tensor<32xf32, #test_encoding>
 "test.op"() : () -> tensor<32xf32, "alias_test:tensor_encoding">
@@ -29,6 +35,15 @@
 // CHECK-DAG: tensor<32x!test_ui8_>
 "test.op"() : () -> tensor<32x!test.int<unsigned, 8>>
 
+// CHECK-DAG: !test.tensor_int3<!test_ui8_>
+"test.op"() : () -> tensor<3x!test.int<unsigned, 8>>
+
+// CHECK-DAG: !test.tensor_int3<!test.int<signed, 8>>
+"test.op"() : () -> !test.tensor_int3<!test.int<signed, 8>>
+
+// CHECK-DAG: tensor<3xi3>
+"test.op"() : () -> !test.tensor_int3<i3>
+
 // CHECK-DAG: #loc = loc("nested")
 // CHECK-DAG: #loc1 = loc("test.mlir":10:8)
 // CHECK-DAG: #loc2 = loc(fused<#loc>[#loc1])
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 66578b246afab1..8ef3cb82fe159e 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -193,6 +193,11 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
                   StringRef("test_alias_conflict0_"))
             .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
             .Default(std::nullopt);
+
+    // Create a dialect alias for strings starting with "test_dialect_alias:"
+    if (strAttr.getValue().starts_with("test_dialect_alias:"))
+      return AliasResult::DialectAlias;
+
     if (!aliasName)
       return AliasResult::NoAlias;
 
@@ -200,6 +205,33 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
     return AliasResult::FinalAlias;
   }
 
+  void printDialectAlias(DialectAsmPrinter &printer,
+                         Attribute attr) const override {
+    if (StringAttr strAttr = dyn_cast<StringAttr>(attr)) {
+      // Drop "test_dialect_alias:" from the front of the string
+      StringRef value = strAttr.getValue();
+      value.consume_front("test_dialect_alias:");
+      printer << "test_string<\"" << value << "\">";
+    }
+  }
+
+  LogicalResult parseDialectAlias(DialectAsmParser &parser, Attribute &attr,
+                                  Type type) const override {
+    return AsmParser::KeywordSwitch<LogicalResult>(parser)
+        // Alias !test.test_string<"..."> to StringAttr
+        .Case("test_string",
+              [&](llvm::StringRef, llvm::SMLoc) {
+                std::string str;
+                if (parser.parseLess() || parser.parseString(&str) ||
+                    parser.parseGreater())
+                  return success();
+                attr = parser.getBuilder().getStringAttr("test_dialect_alias:" +
+                                                         str);
+                return success();
+              })
+        .Default([&](StringRef keyword, SMLoc) { return failure(); });
+  }
+
   AliasResult getAlias(Type type, raw_ostream &os) const final {
     if (auto tupleType = dyn_cast<TupleType>(type)) {
       if (tupleType.size() > 0 &&
@@ -229,9 +261,44 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
       os << recAliasType.getName();
       return AliasResult::FinalAlias;
     }
+    // Create a dialect alias for tensor<3x!test.int<...>>
+    if (auto tensorTy = dyn_cast<TensorType>(type);
+        tensorTy && isa<TestIntegerType>(tensorTy.getElementType()) &&
+        tensorTy.hasRank()) {
+      ArrayRef<int64_t> shape = tensorTy.getShape();
+      if (shape.size() == 1 && shape[0] == 3)
+        return AliasResult::DialectAlias;
+    }
     return AliasResult::NoAlias;
   }
 
+  void printDialectAlias(DialectAsmPrinter &printer, Type type) const override {
+    if (auto tensorTy = dyn_cast<TensorType>(type);
+        tensorTy && isa<TestIntegerType>(tensorTy.getElementType()) &&
+        tensorTy.hasRank()) {
+      // Alias tensor<3x!test.int<...>> to !test.tensor_int3<!test.int<...>>
+      ArrayRef<int64_t> shape = tensorTy.getShape();
+      if (shape.size() == 1 && shape[0] == 3)
+        printer << "tensor_int3" << "<" << tensorTy.getElementType() << ">";
+    }
+  }
+
+  LogicalResult parseDialectAlias(DialectAsmParser &parser,
+                                  Type &type) const override {
+    return AsmParser::KeywordSwitch<LogicalResult>(parser)
+        // Alias !test.tensor_int3<IntType> to tensor<3xIntType>
+        .Case("tensor_int3",
+              [&](llvm::StringRef, llvm::SMLoc) {
+                if (parser.parseLess() || parser.parseType(type) ||
+                    parser.parseGreater())
+                  type = nullptr;
+                if (isa<TestIntegerType>(type) || isa<IntegerType>(type))
+                  type = RankedTensorType::get({3}, type);
+                return success();
+              })
+        .Default([&](StringRef keyword, SMLoc) { return failure(); });
+  }
+
   //===------------------------------------------------------------------===//
   // Resources
   //===------------------------------------------------------------------===//

@joker-eph
Copy link
Collaborator

Thanks! I'll try to get to this over the weekend.

@@ -21,6 +21,12 @@
// CHECK-DAG: !test_tuple = tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>
"test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>)

// CHECK-DAG: #test.test_string<"hello">, #test.test_string<" world">
"test.op"() {alias_test = ["test_dialect_alias:hello", "test_dialect_alias: world"]} : () -> ()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this allowing an arbitrary string prefix that any dialect could set? E.g., dialect foo can choose "ba_dialect_alias:" and ba can choose "baz_dialect@"?

Copy link
Contributor Author

@fabianmcg fabianmcg Mar 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically, yes. However, what it truly allows is hijacking any attribute and type's representation with a custom printer and parsers by a dialect.
For example, test could decide
#test.ones<2, f32> alias to [1.: f32, 1.: f32].

Edit: The internal representation of "test_dialect_alias:hello" is still "test_dialect_alias:hello", #test.test_string<"hello"> is just pretty printing the string.

@zero9178
Copy link
Member

I am curious about the motivation. Have you considered other options such as the equivalent of getDefaultDialect of OpAsmInterface?
On first glance this seems like a too powerful option that could cause a great deal of confusion (and potentially ambigiouty?) that proably requires an RFC. A less powerful implementation that satisfies your use-case would probably be nicer. Is my understanding correct that the printer currently allows arbitrary strings to be printed?
For some of the tests I can see a less powerful dialect (or ODS?) hook that'd allow something like !test.tensor_int3<IntType> to map to tensor<3xIntType>. In fact, this is already possible today if a user overwrites Dialect::parseAttribute instead of using the default ODS generated implementation.

@joker-eph
Copy link
Collaborator

I am curious about the motivation

This is a spin-off from the pointer dialect proposal, see my post here: https://discourse.llvm.org/t/rfc-ptr-dialect-modularizing-ptr-ops-in-the-llvm-dialect/75142/10

@fabianmcg
Copy link
Contributor Author

On first glance this seems like a too powerful option

There's a toned down version of this idea, where only attributes and types implementing the (Attr|Type)DialectAliasInterface are allowed to alias.

However, that version requires the attribute or type to internally hold which dialect to use during printing (ie. not all objects can be aliased, for example StringAttr wouldn't accept aliases) and parsing requires overriding the default parsers.
The main benefit of the toned down version is it produces less overhead during parsing and printing. Nonetheless, it's cumbersome to implement.

@zero9178
Copy link
Member

zero9178 commented Mar 21, 2024

I am curious about the motivation

This is a spin-off from the pointer dialect proposal, see my post here: https://discourse.llvm.org/t/rfc-ptr-dialect-modularizing-ptr-ops-in-the-llvm-dialect/75142/10

I see thank you! Is the following a correct summary? The point of introducing this is to make it possible that !ptr.ptr<#llvm.address_space<1>> can be printed and parsed as e.g. !llvm.ptr<1> and the current implementation aims to do so by allowing a dialect to take over the printing and parsing of any attribute/type by returning DialectAlias from getAlias. This is starting to make more sense to me then.

I'll have to think further about the implementation, but the motivation at first glance did not match up with the code, causing a bit of confusion.

@fabianmcg
Copy link
Contributor Author

fabianmcg commented Mar 21, 2024

Is the following a correct summary? The point of introducing this is to make it possible that !ptr.ptr<#llvm.address_space<1>> can be printed and parsed as e.g. !llvm.ptr<1> and the current implementation aims to do so by allowing a dialect to take over the printing and parsing of any attribute/type by returning DialectAlias from getAlias.

Yes. However, for parsing one doesn't need to return DialectAlias from getAlias, that's only needed for printing.

Here's a prototype of the toned down version:

@fabianmcg
Copy link
Contributor Author

fabianmcg commented Mar 24, 2024

After making the change llvm.ptr -> ptr.ptr<#llvm.address_space<0..>>, I ran the tests again and timed them, and this version of dialect aliases is introducing around 17% of overhead.
I'll switch to the toned-down version in which only types implementing the interface TypeWithDialectAlias can alias, unless someone has a better idea.

@joker-eph
Copy link
Collaborator

Can you share the benchmark? (push to a branch somewhere)

@fabianmcg
Copy link
Contributor Author

fabianmcg commented Mar 25, 2024

Can you share the benchmark? (push to a branch somewhere)

The benchmark is just timing ninja check-mlir on release mode. Here's the branch with the ptr dialect and switching the types:
https://github.com/fabianmcg/llvm-project/tree/ptr-dev

I'm observing on my side inconsistent timings with the change with no good explanation as to why it's happening: Sometimes there's no performance difference, sometimes the time of the change is considerably slower.

@joker-eph
Copy link
Collaborator

ninja check-mlir isn't a good benchmark: it's running a lot of things in parallel, does I/O, and thus will be very noisy.
You likely need some single targeted benchmark to do A/B comparison really.

@fabianmcg
Copy link
Contributor Author

Test

I used python to create a synthetic test:

for i in range(0, 10000):
    print("func.func @func{0}(%A: memref<?xi32, {0}>, %i: index) {{".format(i))
    for j in range(0, 20):
        print("  %{0} = memref.load %A[%i] : memref<?xi32, {1}>".format(j, i))
        print("  memref.store %{0}, %A[%i] : memref<?xi32, {1}>".format(j, i))
    print("  return")
    print("}")

Which then I transformed to LLVM with:

mlir-opt test.mlir  --finalize-memref-to-llvm --convert-func-to-llvm --canonicalize -o llvm.mlir

To get something that looks like:

    %0 = llvm.getelementptr %arg1[%arg5] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    %1 = ptr.load %0 : !llvm.ptr -> i32
    %2 = llvm.getelementptr %arg1[%arg5] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    ptr.store %1, %2 : i32, !llvm.ptr

and

    %0 = llvm.getelementptr %arg1[%arg5] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    %1 = llvm.load %0 : !llvm.ptr -> i32
    %2 = llvm.getelementptr %arg1[%arg5] : (!llvm.ptr, i64) -> !llvm.ptr, i32
    llvm.store %1, %2 : i32, !llvm.ptr

for a total of 830003 lines.
Then I tested parse and printing times with:

mlir-opt llvm.mlir --mlir-timing -o llvm-other.mlir

Before the change:

  Total Execution Time: 7.3231 seconds

  ----Wall Time----  ----Name----
    6.1180 ( 83.5%)  Parser
    1.0906 ( 14.9%)  Output
    0.1144 (  1.6%)  Rest
    7.3231 (100.0%)  Total

After the change:

  Total Execution Time: 5.3205 seconds

  ----Wall Time----  ----Name----
    3.0505 ( 57.3%)  Parser
    2.1249 ( 39.9%)  Output
    0.1451 (  2.7%)  Rest
    5.3205 (100.0%)  Total

Results

Parsing takes less time after change, without a good reason as to why. But printing does take 8% more after the change.

Side note

While I was moving flang tests to use Ptr, I realized that this method has another flaw. Aliases are not always computed, for example, when emitting errors it prints the full version of the type.

@zero9178
Copy link
Member

Side note

While I was moving flang tests to use Ptr, I realized that this method has another flaw. Aliases are not always computed, for example, when emitting errors it prints the full version of the type.

An alternative design would maybe be to make this more of a first class citizen and have dialects regsiter dialect aliases the same way dialects register types, attributes and operations (e.g. the LLVM dialect registers an alias for the pointer type in the pointer dialect).
The current implementation conflates the concepts of "dialect alias" (a way to create an alternative syntax within one dialect for an attribute or type in another) with "attribute and type aliases" (a way to assign identifiers to types and attributes for syntax sugar), which are two different concepts to me with different needs.

Having the dialect register would allow a future patch to add support for even generating these with ODS to make sure that e.g. documentation for !llvm.ptr remains as is, to reuse the assemblyFormat infrastructure and maybe even generate a LLVM::LLVMPointerType class inheriting from Ptr::PointerType for strong typing. Main motivation for this suggestion is that otherwise we might have a regression in documentation where newcomers would find a !llvm.ptr in their IR but fail to find any documentation of it because 1) there is no vaguely related sounding type in the mlir::LLVM dialect and 2) no documentation of it at https://mlir.llvm.org/docs/Dialects/LLVM/.

@fabianmcg
Copy link
Contributor Author

a way to assign identifiers to types and attributes for syntax sugar

This is what we want. Because !llvm.ptr<3> is just syntax sugar for !ptr.ptr<#llvm.address_space<3>>. When interacting with !llvm.ptr<3> one is still using !ptr.ptr methods.

Having said that, I like the idea of making them a first-class concept, however, that involves changing AbstractType, MLIRContext and other core classes, and that goes outside the original proposal and requires further discussion.

Also worth noting, the toned down version doesn't suffer from any of this issues, as the type or attr directly specifies which dialect to use during printing.

@joker-eph
Copy link
Collaborator

Having said that, I like the idea of making them a first-class concept, however, that involves changing AbstractType, MLIRContext and other core classes, and that goes outside the original proposal and requires further discussion.

The way I read @zero9178 post, there wouldn't be any change to any of the Core Storage or anything: it is still an alias. You just keep the ODS as a driver to declare the alias so you keep the documentation of the alias, and also auto-generate the c++ code for the aliasing dialect.
For example you had to write class LLVMPointerType : public ptr::PtrType by hand while it could be generated by TableGen.

@fabianmcg
Copy link
Contributor Author

From what I understood in @zero9178 comment the idea is to register them as aliases somewhere. But, I might be missing something.

Also, AsmPrinter::Impl::printDialectType works by retrieving the dialect from the type and then uses it to print the type using the dialect's printing hook. It should be possible to register an alias, but we would need AbstractType to hold at the least an extra TypeID and for aliases to define a TypeID, otherwise the context will complain about the type being already registered.

ASM dialect aliases provide a mechanism to provide arbitrary aliases to types
and attributes in pretty form.

To use these aliases, users must complete several steps. For printing an alias,
they need to:
 - Implement the method `OpAsmDialectInterface::getAlias` and return
   `AliasResult::DialectAlias` for the aliased types instances.
 - Implement the method `OpAsmDialectInterface::printDialectAlias`, printing the
   alias however the user sees fit.

For parsing an alias, the steps are:
 - Implement `OpAsmDialectInterface::parseDialectAlias` for the aliased types.

Users also must attach the interface `OpAsmDialectInterface` to the dialect
creating the alias.

An example of this mechanism was added to the tests, specifically:
`"test_dialect_alias:..."` alias to `#test.test_string<...>`
`tensor<3x!test.int<...>>` alias to `!test.tensor_int3<test.int<...>>`

This change is needed to alias "!llvm.ptr" with "ptr".
Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're both correct. My concern is with the usability and the understandability of the IR and I think having it be supported in ODS to look as much as normal types and attributes with all the supporting infrastructure would be a big help.
That said, I don't want to demand ODS support in this PR but rather having it as a future goal.

a way to assign identifiers to types and attributes for syntax sugar
This is what we want. Because !llvm.ptr<3> is just syntax sugar for !ptr.ptr<#llvm.address_space<3>>. When interacting > with !llvm.ptr<3> one is still using !ptr.ptr methods.

The current aliasing infrastructure supported by getAlias does not change the syntax of attributes and types, but just their uses so to say, and is fully transparent to any dialect. As you noticed when using dump, if these concepts are conflated then it becomes more difficult to change the behaviour of either "dialect aliases" or "attribute/type aliases" without affecting the other. I don't really see a lot of overlap between the two besides both being syntax sugar and having "alias" in their name.

What do you think about having an API like Dialect::addDialectAlias<Ptr::PointerType, LLVM::LLVMPointerType>() function, where LLVM::LLVMPointerType optionally subclasses Ptr::PointerType and implements classof, print and parse?

The classof is used so that isa calls work and after registration can also be used by the printer to check whether a Ptr::PointerType is actually a LLVM::LLVMPointerType that should be printed as such. Whether and how much this changes core, the context etc I see more as an implementation detail at this point of the design stage and not a constraint (after all we all love changing core 🙂) and something that automatically requires larger discussions compared to the current stage. As mehdi said, the LLVM::LLVMPointerType could then be ODS generated in the future. By requiring and using print and parse methods it'd also work with existing MLIR features such as unqualified printing, something that I don't think works in the current implementation?

If you think about it, the concept of a class acting as the dialect alias is not a new concept in MLIR core either as the builtin dialect does the same with BoolAttr, FlatSymbolRefAttr etc. Its only really the custom printing support that is missing. This could allow having addDialectAlias<IntegerAttr, BoolAttr> with BoolAttr being ODS generated in the future 🙂

If that doesn't sound good to you I'd already be happy with just not using getAlias to implement this. I admittedly don't care too much about the performance of printing and parsing (as long as its not horrible), and mostly about the usability and understandability of the APIs and IR. Making concepts first class helps in this regard IMO.

@fabianmcg
Copy link
Contributor Author

With the current proposal it should be possible to support ODS, however, the core issue of aliases not always being printed would remain.

After thinking about it a bit more, I think registering them is not going to help (I might be wrong).
Let's agree in the following conditions:

  1. If type=!llvm.ptr, then:
    • isa<LLVMPointerType>(type) == isa<PtrType>(type) and TypeID::get<PtrType>() == type.getTypeID()
  2. If type!=!llvm.ptr, then:
    • isa<LLVMPointerType>(type) != isa<PtrType>(type) and TypeID::get<PtrType>() == type.getTypeID()
  3. We also need that in all circumstances regardless of their origin:
    • !ptr.ptr<#llvm.address_space> = !llvm.ptr
      This condition refers to LLVMPointerType::get(0) == PtrType::get(LLVMAddressSpace::get(0)).

If we want all three conditions, then, as types: an instance of LLVMPointerType must be indistinguishable from a PtrType. The only thing that could tell that something is a LLVMPointerType is LLVMPointerType itself.

Which is what I'm currently doing manually, see: LLVMTypes.cpp#L256-L294

Hence, what would registration do?

  • It wouldn't affect printing because the type alone has no way to know it's a LLVMPointerType and testing to check if its, would be too time consuming. And using an extra TypeID or extra information is not possible without violating the 3rd condition.

Thus, I think this proposal might be as good as it gets (I might be wrong).

But again, it should be possible to generate ODS for the current proposal, for implementing a subclass with classof, parse, print and custom getters.

];
}

//===----------------------------------------------------------------------===//
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also worth nothing, that for LLVMPointerType this interface (TypeAsmAliasTypeInterface) solves the issue of printing in all circumstances.

Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the current proposal it should be possible to support ODS, however, the core issue of aliases not always being printed would remain.

After thinking about it a bit more, I think registering them is not going to help (I might be wrong). Let's agree in the following conditions:

  1. If type=!llvm.ptr, then:

    • isa<LLVMPointerType>(type) == isa<PtrType>(type) and TypeID::get<PtrType>() == type.getTypeID()
  2. If type!=!llvm.ptr, then:

    • isa<LLVMPointerType>(type) != isa<PtrType>(type) and TypeID::get<PtrType>() == type.getTypeID()
  3. We also need that in all circumstances regardless of their origin:

    • !ptr.ptr<#llvm.address_space> = !llvm.ptr
      This condition refers to LLVMPointerType::get(0) == PtrType::get(LLVMAddressSpace::get(0)).

If we want all three conditions, then, as types: an instance of LLVMPointerType must be indistinguishable from a PtrType. The only thing that could tell that something is a LLVMPointerType is LLVMPointerType itself.

Which is what I'm currently doing manually, see: LLVMTypes.cpp#L256-L294

That's great thank you!

Hence, what would registration do?

  • It wouldn't affect printing because the type alone has no way to know it's a LLVMPointerType and testing to check if its, would be too time consuming. And using an extra TypeID or extra information is not possible without violating the 3rd condition.

The addDialectAlias would simply be an alternative way to later give the printer the information that it has to check whether a Ptr::PointerType is an instance of LLVMPointerType or not prior to deciding which dialect to use when printing. That said, I am actually pretty happy with your interface and like the fact that it is opt-in, even if that makes it less powerful. I can also see how its more performant.

Left some comments about the implementation about the current approach then. It looks to me like it could be simplified if only allowing a different dialect to be used for printing and parsing through the interface.

Comment on lines +1735 to +1757
/// Hooks for parsing a dialect alias. The method returns success if the
/// dialect has an alias for the symbol, otherwise it must return failure.
/// If there was an error during parsing, this method should return success
/// and set the attribute to null.
virtual LogicalResult parseDialectAlias(DialectAsmParser &parser,
Attribute &attr, Type type) const {
return failure();
}
virtual LogicalResult parseDialectAlias(DialectAsmParser &parser,
Type &type) const {
return failure();
}
/// Hooks for printing a dialect alias. The method returns success if the
/// dialect has an alias for the symbol, otherwise it must return failure.
virtual LogicalResult printDialectAlias(DialectAsmPrinter &printer,
Attribute attr) const {
return failure();
}
virtual LogicalResult printDialectAlias(DialectAsmPrinter &printer,
Type type) const {
return failure();
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are these APIs required? When parsing and the parser sees a !llvm, it will call dialect->parseType right?
Similarily for printing, after the dialect that should be used for printing has been determined through the interface call, couldn't one just call dialect->printType?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. I introduced this methods just to make it easier to have aliases without having to implement the parseType function in the dialect. However, if the plan is to have this backed this by ODS at some point, then there's no need for them.

Copy link
Member

@zero9178 zero9178 Mar 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd personally prefer not having these unless required. We can always add these once there is a use-case. I don't think they add much clarity.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the plan is eventually doing ODS, then I agree these should be removed.

Comment on lines +1719 to +1721
/// A dialect alias was provided and it will be used
/// (no other hooks will be checked).
DialectAlias
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this still required? Can't the printer just check whether the type/attribute implements the interface, if it does call getAliasDialect, and then call its print* method?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DialectAlias is different from the TypeAsmAliasTypeInterface.

  • With DialectAlias you can alias any type or attribute, however, the type might not always be printed using the alias.
  • TypeAsmAliasTypeInterface, you can only alias objects that allow aliases. So it has more constraints in what it can do.

For example, with the first mechanism you can alias IntegerType::get(2^exp) to p2_int<exp>, and keep i3 as i3, however the second mechanism allows aliasing the type to only one dialect (unless the type implements a specific alias rule).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there actually a use case for this and do you need that powerful of a mechanism for the pointer dialect? In the interest of landing this PR with the interface mechanism I'd just drop that mechanism and implement it in another PR of really needed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really. I was just trying to solve the problem in general, and I'm open drop it. However, I think there's no rush with landing this, as Euro LLVM will happen soon, hence all the other Ptr patches are not likely to land soon.

I'll implement the DenseMap approach to talk with real numbers. If you're interested, we can also talk more about it in the meeting.

@fabianmcg
Copy link
Contributor Author

The addDialectAlias would simply be an alternative way to later give the printer the information that it has to check whether a Ptr::PointerType is an instance of LLVMPointerType or not prior to deciding which dialect to use when printing

The problem here is that we would need something like:

DenseMap<TypeID, SmallVector<Dialect*>> aliasedTypes;
// Or
DenseMap<TypeID, SmallVector<llvm::function_ref<Dialect*(Type)>>> aliasedTypes;

In AsmPrinter and every time a type gets printed we would need to check it, which would kill performance. This issue appears because only the alias knows how to distinguish itself from the base type.

@zero9178
Copy link
Member

zero9178 commented Mar 28, 2024

The addDialectAlias would simply be an alternative way to later give the printer the information that it has to check whether a Ptr::PointerType is an instance of LLVMPointerType or not prior to deciding which dialect to use when printing

The problem here is that we would need something like:

DenseMap<TypeID, SmallVector<Dialect*>> aliasedTypes;
// Or
DenseMap<TypeID, SmallVector<llvm::function_ref<Dialect*(Type)>>> aliasedTypes;

In AsmPrinter and every time a type gets printed we would need to check it, which would kill performance. This issue appears because only the alias knows how to distinguish itself from the base type.

The latter version is how I'd roughly had thought that's how it'd be done. I doubt it'd have killed performance or even been a lot different in that regard to your current version, given that the number of registered aliases would be very low in a program (so not a lot of elements in the SmallVector. E.g. just one for LLVMPointerType) and the map lookup isn't really more expensive than a cast/dyn_cast/isa to an interface either. Real performance numbers would be the useful of course, but I don't see any reason why it'd be more expensive than the current approach unless you have a huge number of registered alias types.

@fabianmcg
Copy link
Contributor Author

You're correct the lookup itself wouldn't kill performance, as it is equivalent to cast, isa..., however, searching through the vector does have the potential to kill performance.

Maybe, we agree that that's the price to use aliases, and that in general users shouldn't use many. However, that's a design decision we need to make.

@zero9178
Copy link
Member

You're correct the lookup itself wouldn't kill performance, as it is equivalent to cast, isa..., however, searching through the vector does have the potential to kill performance.

Maybe, we agree that that's the price to use aliases, and that in general users shouldn't use many. However, that's a design decision we need to make.

It certainly would be the price if someone were to register thousands of aliases. But we're getting side tracked here, I think the interface variant is fine and if you prefer it then let's go with that.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:ods mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants