[mlir][Transforms] Dialect conversion: Context-aware type conversions #140434

Open: wants to merge 3 commits into base: main

matthias-springer
Member

This commit adds support for context-aware type conversions: type conversion rules that can return different types depending on the IR.

There is no change for existing (context-unaware) type conversion rules:

// Example: Convert any integer type to f32.
converter.addConversion([](IntegerType t) {
  return Float32Type::get(t.getContext());
});

There is now an additional overload to register context-aware type conversion rules:

// Example: Type conversion rule for integers, depending on the context:
// Get the defining op of `v`, read its "increment" attribute and return an
// integer with a bitwidth that is increased by "increment".
converter.addConversion([](Value v) -> std::optional<Type> {
  auto intType = dyn_cast<IntegerType>(v.getType());
  if (!intType)
    return std::nullopt;
  Operation *op = v.getDefiningOp();
  if (!op)
    return std::nullopt;
  auto incrementAttr = op->getAttrOfType<IntegerAttr>("increment");
  if (!incrementAttr)
    return std::nullopt;
  return IntegerType::get(v.getContext(),
                          intType.getWidth() + incrementAttr.getInt());
});

For performance reasons, the type converter caches the results of type conversions. This is no longer possible when there are context-aware type conversions, because each conversion could compute a different type depending on the context. There is no performance degradation when there are only context-unaware type conversions.
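
To illustrate why a cache keyed by type stops being sound once a rule may inspect the IR, here is a minimal standalone sketch. It does not use MLIR: `ToyType`, `ToyValue`, `ToyConverter`, and `ruleInvocations` are hypothetical stand-ins invented for this example, and the `hasContextAwareRules` flag only mirrors the role that `hasContextAwareTypeConversions` plays in this patch.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-ins (not MLIR): a "type" is an integer bit width, and a
// "value" is a type plus the "increment" attribute of its defining op.
using ToyType = int;
struct ToyValue {
  ToyType type;
  int increment;
};

class ToyConverter {
public:
  // Context-unaware rule: it sees only the type, so its result can be cached.
  void addTypeRule(std::function<std::optional<ToyType>(ToyType)> rule) {
    assert(rule && "expected a callable rule");
    rules.push_back([rule](const ToyValue &v) { return rule(v.type); });
  }

  // Context-aware rule: it sees the whole value, so caching is switched off.
  void addValueRule(
      std::function<std::optional<ToyType>(const ToyValue &)> rule) {
    assert(rule && "expected a callable rule");
    hasContextAwareRules = true;
    rules.push_back(std::move(rule));
  }

  std::optional<ToyType> convert(const ToyValue &v) {
    const bool useCache = !hasContextAwareRules;
    if (useCache) {
      auto hit = cache.find(v.type);
      if (hit != cache.end())
        return hit->second;
    }
    // Try the most recently registered rule first.
    for (auto it = rules.rbegin(); it != rules.rend(); ++it) {
      ++ruleInvocations;
      if (std::optional<ToyType> result = (*it)(v)) {
        if (useCache)
          cache.emplace(v.type, *result);
        return result;
      }
    }
    return std::nullopt;
  }

  // Counts rule calls so the effect of the cache is observable.
  int ruleInvocations = 0;

private:
  using Rule = std::function<std::optional<ToyType>(const ToyValue &)>;
  std::vector<Rule> rules;
  std::unordered_map<ToyType, ToyType> cache;
  bool hasContextAwareRules = false;
};
```

With only a type rule registered, converting two values of the same type invokes the rule once and serves the second request from the cache. With a value rule registered, two values of width 37 with increments 1 and 2 must produce 38 and 39, which a type-keyed cache could not distinguish, so every request runs the rules again.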

Note: This commit just adds context-aware type conversions to the dialect conversion framework. There are many existing patterns that still call converter.convertType(someValue.getType()). These should be gradually updated in subsequent commits to call converter.convertType(someValue).
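
The practical difference between the two call styles can be sketched with a small standalone model. It is not MLIR code: `ToyType`, `ToyValue`, `ToyConverter`, and `makeExampleConverter` are hypothetical names, and the model only assumes what the patch describes, namely that rules receive either a type or a value, that a context-aware rule declines when it is handed a bare type, and that the most recently added rule is tried first.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <utility>
#include <variant>
#include <vector>

// Hypothetical stand-ins (not MLIR): a "type" is an integer bit width; a
// "value" carries its type plus the "increment" attribute of its defining op.
using ToyType = int;
struct ToyValue {
  ToyType type;
  int increment;
};

// Every rule receives either a bare type or a full value.
using TypeOrValue = std::variant<ToyType, ToyValue>;
using Rule = std::function<std::optional<ToyType>(const TypeOrValue &)>;

class ToyConverter {
public:
  void addRule(Rule rule) {
    assert(rule && "expected a callable rule");
    rules.push_back(std::move(rule));
  }

  // Type-based overload: context-aware rules see no value and decline.
  std::optional<ToyType> convertType(ToyType t) const {
    return run(TypeOrValue(t));
  }
  // Value-based overload: every rule gets a chance to answer.
  std::optional<ToyType> convertType(const ToyValue &v) const {
    return run(TypeOrValue(v));
  }

private:
  std::optional<ToyType> run(const TypeOrValue &input) const {
    // Most recently registered rule first; std::nullopt means
    // "not my case, ask the next rule".
    for (auto it = rules.rbegin(); it != rules.rend(); ++it)
      if (std::optional<ToyType> result = (*it)(input))
        return result;
    return std::nullopt;
  }

  std::vector<Rule> rules;
};

// Builds a converter with one context-unaware and one context-aware rule.
inline ToyConverter makeExampleConverter() {
  ToyConverter converter;
  // Context-unaware fallback: widen everything to 64 bits.
  converter.addRule(
      [](const TypeOrValue &) -> std::optional<ToyType> { return 64; });
  // Context-aware rule: width + increment, only when a value is available.
  converter.addRule([](const TypeOrValue &in) -> std::optional<ToyType> {
    if (const ToyValue *v = std::get_if<ToyValue>(&in))
      return v->type + v->increment;
    return std::nullopt;
  });
  return converter;
}
```

In this model, a pattern that still calls `convertType(v.type)` for a value of width 37 and increment 2 gets 64 from the fallback rule, while `convertType(v)` gets 39 from the context-aware rule. That silent divergence is the reason call sites need to be migrated to the value-based overload.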

@llvmbot added the mlir:core (MLIR Core Infrastructure), mlir, and mlir:scf labels on May 18, 2025
@llvmbot
Member

llvmbot commented May 18, 2025

@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes


Patch is 20.60 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140434.diff

6 Files Affected:

  • (modified) mlir/docs/DialectConversion.md (+29-18)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+82-13)
  • (modified) mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp (+2-2)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+43-2)
  • (modified) mlir/test/Transforms/test-legalize-type-conversion.mlir (+18)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+26-3)
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index cf577eca5b9a6..61872d10670dc 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -235,6 +235,15 @@ target types. If the source type is converted to itself, we say it is a "legal"
 type. Type conversions are specified via the `addConversion` method described
 below.
 
+There are two kind of conversion functions: context-aware and context-unaware
+conversions. A context-unaware conversion function converts a `Type` into a
+`Type`. A context-aware conversion function converts a `Value` into a type. The
+latter allows users to customize type conversion rules based on the IR.
+
+Note: When there is at least one context-aware type conversion function, the
+result of type conversions can no longer be cached, which can increase
+compilation time. Use this feature with caution!
+
 A `materialization` describes how a list of values should be converted to a
 list of values with specific types. An important distinction from a
 `conversion` is that a `materialization` can produce IR, whereas a `conversion`
@@ -287,29 +296,31 @@ Several of the available hooks are detailed below:
 ```c++
 class TypeConverter {
  public:
-  /// Register a conversion function. A conversion function defines how a given
-  /// source type should be converted. A conversion function must be convertible
-  /// to any of the following forms(where `T` is a class derived from `Type`:
-  ///   * Optional<Type>(T)
+  /// Register a conversion function. A conversion function must be convertible
+  /// to any of the following forms (where `T` is `Value` or a class derived
+  /// from `Type`, including `Type` itself):
+  ///
+  ///   * std::optional<Type>(T)
   ///     - This form represents a 1-1 type conversion. It should return nullptr
-  ///       or `std::nullopt` to signify failure. If `std::nullopt` is returned, the
-  ///       converter is allowed to try another conversion function to perform
-  ///       the conversion.
-  ///   * Optional<LogicalResult>(T, SmallVectorImpl<Type> &)
+  ///       or `std::nullopt` to signify failure. If `std::nullopt` is returned,
+  ///       the converter is allowed to try another conversion function to
+  ///       perform the conversion.
+  ///   * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &)
   ///     - This form represents a 1-N type conversion. It should return
-  ///       `failure` or `std::nullopt` to signify a failed conversion. If the new
-  ///       set of types is empty, the type is removed and any usages of the
+  ///       `failure` or `std::nullopt` to signify a failed conversion. If the
+  ///       new set of types is empty, the type is removed and any usages of the
   ///       existing value are expected to be removed during conversion. If
   ///       `std::nullopt` is returned, the converter is allowed to try another
   ///       conversion function to perform the conversion.
-  ///   * Optional<LogicalResult>(T, SmallVectorImpl<Type> &, ArrayRef<Type>)
-  ///     - This form represents a 1-N type conversion supporting recursive
-  ///       types. The first two arguments and the return value are the same as
-  ///       for the regular 1-N form. The third argument is contains is the
-  ///       "call stack" of the recursive conversion: it contains the list of
-  ///       types currently being converted, with the current type being the
-  ///       last one. If it is present more than once in the list, the
-  ///       conversion concerns a recursive type.
+  ///
+  /// Conversion functions that accept `Value` as the first argument are
+  /// context-aware. I.e., they can take into account IR when converting the
+  /// type of the given value. Context-unaware conversion functions accept
+  /// `Type` or a derived class as the first argument.
+  ///
+  /// Note: Context-unaware conversions are cached, but context-aware
+  /// conversions are not.
+  ///
   /// Note: When attempting to convert a type, e.g. via 'convertType', the
   ///       mostly recently added conversions will be invoked first.
   template <typename FnT,
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index e7d05c3ce1adf..07adbde3a5a60 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -18,6 +18,7 @@
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/StringMap.h"
 #include <type_traits>
+#include <variant>
 
 namespace mlir {
 
@@ -139,7 +140,8 @@ class TypeConverter {
   };
 
   /// Register a conversion function. A conversion function must be convertible
-  /// to any of the following forms (where `T` is a class derived from `Type`):
+  /// to any of the following forms (where `T` is `Value` or a class derived
+  /// from `Type`, including `Type` itself):
   ///
   ///   * std::optional<Type>(T)
   ///     - This form represents a 1-1 type conversion. It should return nullptr
@@ -154,6 +156,14 @@ class TypeConverter {
   ///       `std::nullopt` is returned, the converter is allowed to try another
   ///       conversion function to perform the conversion.
   ///
+  /// Conversion functions that accept `Value` as the first argument are
+  /// context-aware. I.e., they can take into account IR when converting the
+  /// type of the given value. Context-unaware conversion functions accept
+  /// `Type` or a derived class as the first argument.
+  ///
+  /// Note: Context-unaware conversions are cached, but context-aware
+  /// conversions are not.
+  ///
   /// Note: When attempting to convert a type, e.g. via 'convertType', the
   ///       mostly recently added conversions will be invoked first.
   template <typename FnT, typename T = typename llvm::function_traits<
@@ -241,15 +251,28 @@ class TypeConverter {
         wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
   }
 
-  /// Convert the given type. This function should return failure if no valid
+  /// Convert the given type. This function returns failure if no valid
   /// conversion exists, success otherwise. If the new set of types is empty,
   /// the type is removed and any usages of the existing value are expected to
   /// be removed during conversion.
+  ///
+  /// Note: This overload invokes only context-unaware type conversion
+  /// functions. Users should call the other overload if possible.
   LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) const;
 
+  /// Convert the type of the given value. This function returns failure if no
+  /// valid conversion exists, success otherwise. If the new set of types is
+  /// empty, the type is removed and any usages of the existing value are
+  /// expected to be removed during conversion.
+  ///
+  /// Note: This overload invokes both context-aware and context-unaware type
+  /// conversion functions.
+  LogicalResult convertType(Value v, SmallVectorImpl<Type> &results) const;
+
   /// This hook simplifies defining 1-1 type conversions. This function returns
   /// the type to convert to on success, and a null type on failure.
   Type convertType(Type t) const;
+  Type convertType(Value v) const;
 
   /// Attempts a 1-1 type conversion, expecting the result type to be
   /// `TargetType`. Returns the converted type cast to `TargetType` on success,
@@ -258,13 +281,23 @@ class TypeConverter {
   TargetType convertType(Type t) const {
     return dyn_cast_or_null<TargetType>(convertType(t));
   }
+  template <typename TargetType>
+  TargetType convertType(Value v) const {
+    return dyn_cast_or_null<TargetType>(convertType(v));
+  }
 
-  /// Convert the given set of types, filling 'results' as necessary. This
-  /// returns failure if the conversion of any of the types fails, success
+  /// Convert the given types, filling 'results' as necessary. This returns
+  /// "failure" if the conversion of any of the types fails, "success"
   /// otherwise.
   LogicalResult convertTypes(TypeRange types,
                              SmallVectorImpl<Type> &results) const;
 
+  /// Convert the types of the given values, filling 'results' as necessary.
+  /// This returns "failure" if the conversion of any of the types fails,
+  /// "success" otherwise.
+  LogicalResult convertTypes(ValueRange values,
+                             SmallVectorImpl<Type> &results) const;
+
   /// Return true if the given type is legal for this type converter, i.e. the
   /// type converts to itself.
   bool isLegal(Type type) const;
@@ -328,7 +361,7 @@ class TypeConverter {
   /// types is empty, the type is removed and any usages of the existing value
   /// are expected to be removed during conversion.
   using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
-      Type, SmallVectorImpl<Type> &)>;
+      std::variant<Type, Value>, SmallVectorImpl<Type> &)>;
 
   /// The signature of the callback used to materialize a source conversion.
   ///
@@ -348,13 +381,14 @@ class TypeConverter {
 
   /// Generate a wrapper for the given callback. This allows for accepting
   /// different callback forms, that all compose into a single version.
-  /// With callback of form: `std::optional<Type>(T)`
+  /// With callback of form: `std::optional<Type>(T)`, where `T` can be a
+  /// `Value` or a `Type` (or a class derived from `Type`).
   template <typename T, typename FnT>
   std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
-  wrapCallback(FnT &&callback) const {
+  wrapCallback(FnT &&callback) {
     return wrapCallback<T>([callback = std::forward<FnT>(callback)](
-                               T type, SmallVectorImpl<Type> &results) {
-      if (std::optional<Type> resultOpt = callback(type)) {
+                               T typeOrValue, SmallVectorImpl<Type> &results) {
+      if (std::optional<Type> resultOpt = callback(typeOrValue)) {
         bool wasSuccess = static_cast<bool>(*resultOpt);
         if (wasSuccess)
           results.push_back(*resultOpt);
@@ -364,20 +398,49 @@ class TypeConverter {
     });
   }
   /// With callback of form: `std::optional<LogicalResult>(
-  ///     T, SmallVectorImpl<Type> &, ArrayRef<Type>)`.
+  ///     T, SmallVectorImpl<Type> &)`, where `T` is a type.
   template <typename T, typename FnT>
-  std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
+  std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
+                       std::is_base_of_v<Type, T>,
                    ConversionCallbackFn>
   wrapCallback(FnT &&callback) const {
     return [callback = std::forward<FnT>(callback)](
-               Type type,
+               std::variant<Type, Value> type,
                SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
-      T derivedType = dyn_cast<T>(type);
+      T derivedType;
+      if (Type *t = std::get_if<Type>(&type)) {
+        derivedType = dyn_cast<T>(*t);
+      } else if (Value *v = std::get_if<Value>(&type)) {
+        derivedType = dyn_cast<T>(v->getType());
+      } else {
+        llvm_unreachable("unexpected variant");
+      }
       if (!derivedType)
         return std::nullopt;
       return callback(derivedType, results);
     };
   }
+  /// With callback of form: `std::optional<LogicalResult>(
+  ///     T, SmallVectorImpl<Type>)`, where `T` is a `Value`.
+  template <typename T, typename FnT>
+  std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
+                       std::is_same_v<T, Value>,
+                   ConversionCallbackFn>
+  wrapCallback(FnT &&callback) {
+    hasContextAwareTypeConversions = true;
+    return [callback = std::forward<FnT>(callback)](
+               std::variant<Type, Value> type,
+               SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
+      if (Type *t = std::get_if<Type>(&type)) {
+        // Context-aware type conversion was called with a type.
+        return std::nullopt;
+      } else if (Value *v = std::get_if<Value>(&type)) {
+        return callback(*v, results);
+      }
+      llvm_unreachable("unexpected variant");
+      return std::nullopt;
+    };
+  }
 
   /// Register a type conversion.
   void registerConversion(ConversionCallbackFn callback) {
@@ -504,6 +567,12 @@ class TypeConverter {
   mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
   /// A mutex used for cache access
   mutable llvm::sys::SmartRWMutex<true> cacheMutex;
+  /// Whether the type converter has context-aware type conversions. I.e.,
+  /// conversion rules that depend on the SSA value instead of just the type.
+  /// Type conversion caching is deactivated when there are context-aware
+  /// conversions because the type converter may return different results for
+  /// the same input type.
+  bool hasContextAwareTypeConversions = false;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 09326242eec2a..de4612fa0846a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -52,8 +52,8 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
     SmallVector<unsigned> offsets;
     offsets.push_back(0);
     // Do the type conversion and record the offsets.
-    for (Type type : op.getResultTypes()) {
-      if (failed(typeConverter->convertTypes(type, dstTypes)))
+    for (Value v : op.getResults()) {
+      if (failed(typeConverter->convertType(v, dstTypes)))
         return rewriter.notifyMatchFailure(op, "could not convert result type");
       offsets.push_back(dstTypes.size());
     }
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index bd11bbe58a3f6..2a1d154faeaf3 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1256,7 +1256,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
 
     // If there is no legal conversion, fail to match this pattern.
     SmallVector<Type, 1> legalTypes;
-    if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
+    if (failed(currentTypeConverter->convertType(operand, legalTypes))) {
       notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
         diag << "unable to convert type for " << valueDiagTag << " #"
              << it.index() << ", type was " << origType;
@@ -2899,6 +2899,28 @@ LogicalResult TypeConverter::convertType(Type t,
   return failure();
 }
 
+LogicalResult TypeConverter::convertType(Value v,
+                                         SmallVectorImpl<Type> &results) const {
+  assert(v && "expected non-null value");
+
+  // If this type converter does not have context-aware type conversions, call
+  // the type-based overload, which has caching.
+  if (!hasContextAwareTypeConversions) {
+    return convertType(v.getType(), results);
+  }
+
+  // Walk the added converters in reverse order to apply the most recently
+  // registered first.
+  for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
+    if (std::optional<LogicalResult> result = converter(v, results)) {
+      if (!succeeded(*result))
+        return failure();
+      return success();
+    }
+  }
+  return failure();
+}
+
 Type TypeConverter::convertType(Type t) const {
   // Use the multi-type result version to convert the type.
   SmallVector<Type, 1> results;
@@ -2909,6 +2931,16 @@ Type TypeConverter::convertType(Type t) const {
   return results.size() == 1 ? results.front() : nullptr;
 }
 
+Type TypeConverter::convertType(Value v) const {
+  // Use the multi-type result version to convert the type.
+  SmallVector<Type, 1> results;
+  if (failed(convertType(v, results)))
+    return nullptr;
+
+  // Check to ensure that only one type was produced.
+  return results.size() == 1 ? results.front() : nullptr;
+}
+
 LogicalResult
 TypeConverter::convertTypes(TypeRange types,
                             SmallVectorImpl<Type> &results) const {
@@ -2918,6 +2950,15 @@ TypeConverter::convertTypes(TypeRange types,
   return success();
 }
 
+LogicalResult
+TypeConverter::convertTypes(ValueRange values,
+                            SmallVectorImpl<Type> &results) const {
+  for (Value value : values)
+    if (failed(convertType(value, results)))
+      return failure();
+  return success();
+}
+
 bool TypeConverter::isLegal(Type type) const {
   return convertType(type) == type;
 }
@@ -3128,7 +3169,7 @@ mlir::convertOpResultTypes(Operation *op, ValueRange operands,
   newOp.addOperands(operands);
 
   SmallVector<Type> newResultTypes;
-  if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
+  if (failed(converter.convertTypes(op->getResults(), newResultTypes)))
     return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
 
   newOp.addTypes(newResultTypes);
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index db8bd0f6378d2..7b5e6e796a528 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -142,3 +142,21 @@ func.func @test_signature_conversion_no_converter() {
   }) : () -> ()
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @context_aware_conversion()
+func.func @context_aware_conversion() {
+  // Case 1: Convert i37 --> i38.
+  // CHECK: %[[cast0:.*]] = unrealized_conversion_cast %{{.*}} : i37 to i38
+  // CHECK: "test.legal_op_d"(%[[cast0]]) : (i38) -> ()
+  %0 = "test.context_op"() {increment = 1 : i64} : () -> (i37)
+  "test.replace_with_legal_op"(%0) : (i37) -> ()
+
+  // Case 2: Convert i37 --> i39.
+  // CHECK: %[[cast1:.*]] = unrealized_conversion_cast %{{.*}} : i37 to i39
+  // CHECK: "test.legal_op_d"(%[[cast1]]) : (i39) -> ()
+  %1 = "test.context_op"() {increment = 2 : i64} : () -> (i37)
+  "test.replace_with_legal_op"(%1) : (i37) -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index d073843484d81..bd85e6fd9ae7f 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1827,9 +1827,9 @@ struct TestReplaceWithLegalOp : public ConversionPattern {
       : ConversionPattern(converter, "test.replace_with_legal_op",
                           /*benefit=*/1, ctx) {}
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]);
+    rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0].front());
     return success();
   }
 };
@@ -1865,7 +1865,7 @@ struct TestTypeConversionDriver
       return nullptr;
     });
     converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
-      // Drop all integer types.
+      // Drop all other integer types.
       return success();
     });
     converter.addConversion(
@@ -1902,6 +1902,19 @@ struct TestTypeConversionDriver
           results.push_back(result);
           return success();
         });
+    converter.addConversion([](Value v) -> std::optional<Type> {
+      auto intType = dyn_cast<IntegerType>(v.getType());
+      if (!intType || intType.getWidth() != 37)
+        return std::nullopt;
+      Operation *op = v.getDefiningOp();
+      if (!op)
+        return std::nullopt;
+      auto incrementAttr = op->getAttrOfType<IntegerAttr>("increment");
+      if (!incrementAttr)
+        return std::nullopt;
+      return IntegerType::get(v.getContext(),
+                              intType.getWidth() + incrementAttr.getInt());
+    });
 
     /// Add the legal set of type materializations.
     converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
@@ -1922,9 +1935,19 @@ struct TestTypeConversionDriver
       // Otherwise, fail.
       return nullptr;
     });
+    // Materialize i37 to any desired type with unrealized_conversion_cast.
+    converter.addTargetMaterialization([](OpBuilder &builder, Type type,
+                                          ValueRange inputs,
+                        ...
[truncated]
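
The `Structural1ToNConversionPattern` hunk in the patch above converts each result value into zero or more types and records where each value's converted types start in one flat list. A standalone sketch of that bookkeeping follows. It is not MLIR code: `ToyType`, `ToyValue`, `OneToNRule`, and `convertValues` are hypothetical names, and the `copies` field is an invented stand-in for whatever IR context a real rule would inspect.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical stand-ins (not MLIR): a "type" is an integer bit width; a
// "value" carries its type plus a made-up "copies" property of its context.
using ToyType = int;
struct ToyValue {
  ToyType type;
  int copies;
};

// A 1-N rule appends zero or more result types and returns false on failure.
using OneToNRule =
    std::function<bool(const ToyValue &, std::vector<ToyType> &)>;

// Converts each value in order. All converted types are appended to
// `results`; `offsets[i]` and `offsets[i + 1]` delimit the types produced
// for value i. Returns false as soon as one conversion fails.
inline bool convertValues(const std::vector<ToyValue> &values,
                          const OneToNRule &rule,
                          std::vector<ToyType> &results,
                          std::vector<std::size_t> &offsets) {
  assert(rule && "expected a callable rule");
  offsets.push_back(results.size());
  for (const ToyValue &v : values) {
    if (!rule(v, results))
      return false;
    offsets.push_back(results.size());
  }
  return true;
}
```

An empty range between two consecutive offsets corresponds to a value whose type was dropped, which matches the documented meaning of an empty result set in a 1-N conversion.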

@llvmbot
Member

llvmbot commented May 18, 2025

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes


Patch is 20.60 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140434.diff

6 Files Affected:

  • (modified) mlir/docs/DialectConversion.md (+29-18)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+82-13)
  • (modified) mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp (+2-2)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+43-2)
  • (modified) mlir/test/Transforms/test-legalize-type-conversion.mlir (+18)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+26-3)
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index cf577eca5b9a6..61872d10670dc 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -235,6 +235,15 @@ target types. If the source type is converted to itself, we say it is a "legal"
 type. Type conversions are specified via the `addConversion` method described
 below.
 
+There are two kind of conversion functions: context-aware and context-unaware
+conversions. A context-unaware conversion function converts a `Type` into a
+`Type`. A context-aware conversion function converts a `Value` into a type. The
+latter allows users to customize type conversion rules based on the IR.
+
+Note: When there is at least one context-aware type conversion function, the
+result of type conversions can no longer be cached, which can increase
+compilation time. Use this feature with caution!
+
 A `materialization` describes how a list of values should be converted to a
 list of values with specific types. An important distinction from a
 `conversion` is that a `materialization` can produce IR, whereas a `conversion`
@@ -287,29 +296,31 @@ Several of the available hooks are detailed below:
 ```c++
 class TypeConverter {
  public:
-  /// Register a conversion function. A conversion function defines how a given
-  /// source type should be converted. A conversion function must be convertible
-  /// to any of the following forms(where `T` is a class derived from `Type`:
-  ///   * Optional<Type>(T)
+  /// Register a conversion function. A conversion function must be convertible
+  /// to any of the following forms (where `T` is `Value` or a class derived
+  /// from `Type`, including `Type` itself):
+  ///
+  ///   * std::optional<Type>(T)
   ///     - This form represents a 1-1 type conversion. It should return nullptr
-  ///       or `std::nullopt` to signify failure. If `std::nullopt` is returned, the
-  ///       converter is allowed to try another conversion function to perform
-  ///       the conversion.
-  ///   * Optional<LogicalResult>(T, SmallVectorImpl<Type> &)
+  ///       or `std::nullopt` to signify failure. If `std::nullopt` is returned,
+  ///       the converter is allowed to try another conversion function to
+  ///       perform the conversion.
+  ///   * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &)
   ///     - This form represents a 1-N type conversion. It should return
-  ///       `failure` or `std::nullopt` to signify a failed conversion. If the new
-  ///       set of types is empty, the type is removed and any usages of the
+  ///       `failure` or `std::nullopt` to signify a failed conversion. If the
+  ///       new set of types is empty, the type is removed and any usages of the
   ///       existing value are expected to be removed during conversion. If
   ///       `std::nullopt` is returned, the converter is allowed to try another
   ///       conversion function to perform the conversion.
-  ///   * Optional<LogicalResult>(T, SmallVectorImpl<Type> &, ArrayRef<Type>)
-  ///     - This form represents a 1-N type conversion supporting recursive
-  ///       types. The first two arguments and the return value are the same as
-  ///       for the regular 1-N form. The third argument is contains is the
-  ///       "call stack" of the recursive conversion: it contains the list of
-  ///       types currently being converted, with the current type being the
-  ///       last one. If it is present more than once in the list, the
-  ///       conversion concerns a recursive type.
+  ///
+  /// Conversion functions that accept `Value` as the first argument are
+  /// context-aware. I.e., they can take into account IR when converting the
+  /// type of the given value. Context-unaware conversion functions accept
+  /// `Type` or a derived class as the first argument.
+  ///
+  /// Note: Context-unaware conversions are cached, but context-aware
+  /// conversions are not.
+  ///
   /// Note: When attempting to convert a type, e.g. via 'convertType', the
   ///       mostly recently added conversions will be invoked first.
   template <typename FnT,
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index e7d05c3ce1adf..07adbde3a5a60 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -18,6 +18,7 @@
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/StringMap.h"
 #include <type_traits>
+#include <variant>
 
 namespace mlir {
 
@@ -139,7 +140,8 @@ class TypeConverter {
   };
 
   /// Register a conversion function. A conversion function must be convertible
-  /// to any of the following forms (where `T` is a class derived from `Type`):
+  /// to any of the following forms (where `T` is `Value` or a class derived
+  /// from `Type`, including `Type` itself):
   ///
   ///   * std::optional<Type>(T)
   ///     - This form represents a 1-1 type conversion. It should return nullptr
@@ -154,6 +156,14 @@ class TypeConverter {
   ///       `std::nullopt` is returned, the converter is allowed to try another
   ///       conversion function to perform the conversion.
   ///
+  /// Conversion functions that accept `Value` as the first argument are
+  /// context-aware. I.e., they can take into account IR when converting the
+  /// type of the given value. Context-unaware conversion functions accept
+  /// `Type` or a derived class as the first argument.
+  ///
+  /// Note: Context-unaware conversions are cached, but context-aware
+  /// conversions are not.
+  ///
   /// Note: When attempting to convert a type, e.g. via 'convertType', the
   ///       mostly recently added conversions will be invoked first.
   template <typename FnT, typename T = typename llvm::function_traits<
@@ -241,15 +251,28 @@ class TypeConverter {
         wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
   }
 
-  /// Convert the given type. This function should return failure if no valid
+  /// Convert the given type. This function returns failure if no valid
   /// conversion exists, success otherwise. If the new set of types is empty,
   /// the type is removed and any usages of the existing value are expected to
   /// be removed during conversion.
+  ///
+  /// Note: This overload invokes only context-unaware type conversion
+  /// functions. Users should call the other overload if possible.
   LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) const;
 
+  /// Convert the type of the given value. This function returns failure if no
+  /// valid conversion exists, success otherwise. If the new set of types is
+  /// empty, the type is removed and any usages of the existing value are
+  /// expected to be removed during conversion.
+  ///
+  /// Note: This overload invokes both context-aware and context-unaware type
+  /// conversion functions.
+  LogicalResult convertType(Value v, SmallVectorImpl<Type> &results) const;
+
   /// This hook simplifies defining 1-1 type conversions. This function returns
   /// the type to convert to on success, and a null type on failure.
   Type convertType(Type t) const;
+  Type convertType(Value v) const;
 
   /// Attempts a 1-1 type conversion, expecting the result type to be
   /// `TargetType`. Returns the converted type cast to `TargetType` on success,
@@ -258,13 +281,23 @@ class TypeConverter {
   TargetType convertType(Type t) const {
     return dyn_cast_or_null<TargetType>(convertType(t));
   }
+  template <typename TargetType>
+  TargetType convertType(Value v) const {
+    return dyn_cast_or_null<TargetType>(convertType(v));
+  }
 
-  /// Convert the given set of types, filling 'results' as necessary. This
-  /// returns failure if the conversion of any of the types fails, success
+  /// Convert the given types, filling 'results' as necessary. This returns
+  /// "failure" if the conversion of any of the types fails, "success"
   /// otherwise.
   LogicalResult convertTypes(TypeRange types,
                              SmallVectorImpl<Type> &results) const;
 
+  /// Convert the types of the given values, filling 'results' as necessary.
+  /// This returns "failure" if the conversion of any of the types fails,
+  /// "success" otherwise.
+  LogicalResult convertTypes(ValueRange values,
+                             SmallVectorImpl<Type> &results) const;
+
   /// Return true if the given type is legal for this type converter, i.e. the
   /// type converts to itself.
   bool isLegal(Type type) const;
@@ -328,7 +361,7 @@ class TypeConverter {
   /// types is empty, the type is removed and any usages of the existing value
   /// are expected to be removed during conversion.
   using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
-      Type, SmallVectorImpl<Type> &)>;
+      std::variant<Type, Value>, SmallVectorImpl<Type> &)>;
 
   /// The signature of the callback used to materialize a source conversion.
   ///
@@ -348,13 +381,14 @@ class TypeConverter {
 
   /// Generate a wrapper for the given callback. This allows for accepting
   /// different callback forms, that all compose into a single version.
-  /// With callback of form: `std::optional<Type>(T)`
+  /// With callback of form: `std::optional<Type>(T)`, where `T` can be a
+  /// `Value` or a `Type` (or a class derived from `Type`).
   template <typename T, typename FnT>
   std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
-  wrapCallback(FnT &&callback) const {
+  wrapCallback(FnT &&callback) {
     return wrapCallback<T>([callback = std::forward<FnT>(callback)](
-                               T type, SmallVectorImpl<Type> &results) {
-      if (std::optional<Type> resultOpt = callback(type)) {
+                               T typeOrValue, SmallVectorImpl<Type> &results) {
+      if (std::optional<Type> resultOpt = callback(typeOrValue)) {
         bool wasSuccess = static_cast<bool>(*resultOpt);
         if (wasSuccess)
           results.push_back(*resultOpt);
@@ -364,20 +398,49 @@ class TypeConverter {
     });
   }
   /// With callback of form: `std::optional<LogicalResult>(
-  ///     T, SmallVectorImpl<Type> &, ArrayRef<Type>)`.
+  ///     T, SmallVectorImpl<Type> &)`, where `T` is a type.
   template <typename T, typename FnT>
-  std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
+  std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
+                       std::is_base_of_v<Type, T>,
                    ConversionCallbackFn>
   wrapCallback(FnT &&callback) const {
     return [callback = std::forward<FnT>(callback)](
-               Type type,
+               std::variant<Type, Value> type,
                SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
-      T derivedType = dyn_cast<T>(type);
+      T derivedType;
+      if (Type *t = std::get_if<Type>(&type)) {
+        derivedType = dyn_cast<T>(*t);
+      } else if (Value *v = std::get_if<Value>(&type)) {
+        derivedType = dyn_cast<T>(v->getType());
+      } else {
+        llvm_unreachable("unexpected variant");
+      }
       if (!derivedType)
         return std::nullopt;
       return callback(derivedType, results);
     };
   }
+  /// With callback of form: `std::optional<LogicalResult>(
+  ///     T, SmallVectorImpl<Type>)`, where `T` is a `Value`.
+  template <typename T, typename FnT>
+  std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
+                       std::is_same_v<T, Value>,
+                   ConversionCallbackFn>
+  wrapCallback(FnT &&callback) {
+    hasContextAwareTypeConversions = true;
+    return [callback = std::forward<FnT>(callback)](
+               std::variant<Type, Value> type,
+               SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
+      if (Type *t = std::get_if<Type>(&type)) {
+        // Context-aware type conversion was called with a type.
+        return std::nullopt;
+      } else if (Value *v = std::get_if<Value>(&type)) {
+        return callback(*v, results);
+      }
+      llvm_unreachable("unexpected variant");
+      return std::nullopt;
+    };
+  }
 
   /// Register a type conversion.
   void registerConversion(ConversionCallbackFn callback) {
@@ -504,6 +567,12 @@ class TypeConverter {
   mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
   /// A mutex used for cache access
   mutable llvm::sys::SmartRWMutex<true> cacheMutex;
+  /// Whether the type converter has context-aware type conversions. I.e.,
+  /// conversion rules that depend on the SSA value instead of just the type.
+  /// Type conversion caching is deactivated when there are context-aware
+  /// conversions because the type converter may return different results for
+  /// the same input type.
+  bool hasContextAwareTypeConversions = false;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 09326242eec2a..de4612fa0846a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -52,8 +52,8 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
     SmallVector<unsigned> offsets;
     offsets.push_back(0);
     // Do the type conversion and record the offsets.
-    for (Type type : op.getResultTypes()) {
-      if (failed(typeConverter->convertTypes(type, dstTypes)))
+    for (Value v : op.getResults()) {
+      if (failed(typeConverter->convertType(v, dstTypes)))
         return rewriter.notifyMatchFailure(op, "could not convert result type");
       offsets.push_back(dstTypes.size());
     }
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index bd11bbe58a3f6..2a1d154faeaf3 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1256,7 +1256,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
 
     // If there is no legal conversion, fail to match this pattern.
     SmallVector<Type, 1> legalTypes;
-    if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
+    if (failed(currentTypeConverter->convertType(operand, legalTypes))) {
       notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
         diag << "unable to convert type for " << valueDiagTag << " #"
              << it.index() << ", type was " << origType;
@@ -2899,6 +2899,28 @@ LogicalResult TypeConverter::convertType(Type t,
   return failure();
 }
 
+LogicalResult TypeConverter::convertType(Value v,
+                                         SmallVectorImpl<Type> &results) const {
+  assert(v && "expected non-null value");
+
+  // If this type converter does not have context-aware type conversions, call
+  // the type-based overload, which has caching.
+  if (!hasContextAwareTypeConversions) {
+    return convertType(v.getType(), results);
+  }
+
+  // Walk the added converters in reverse order to apply the most recently
+  // registered first.
+  for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
+    if (std::optional<LogicalResult> result = converter(v, results)) {
+      if (!succeeded(*result))
+        return failure();
+      return success();
+    }
+  }
+  return failure();
+}
+
 Type TypeConverter::convertType(Type t) const {
   // Use the multi-type result version to convert the type.
   SmallVector<Type, 1> results;
@@ -2909,6 +2931,16 @@ Type TypeConverter::convertType(Type t) const {
   return results.size() == 1 ? results.front() : nullptr;
 }
 
+Type TypeConverter::convertType(Value v) const {
+  // Use the multi-type result version to convert the type.
+  SmallVector<Type, 1> results;
+  if (failed(convertType(v, results)))
+    return nullptr;
+
+  // Check to ensure that only one type was produced.
+  return results.size() == 1 ? results.front() : nullptr;
+}
+
 LogicalResult
 TypeConverter::convertTypes(TypeRange types,
                             SmallVectorImpl<Type> &results) const {
@@ -2918,6 +2950,15 @@ TypeConverter::convertTypes(TypeRange types,
   return success();
 }
 
+LogicalResult
+TypeConverter::convertTypes(ValueRange values,
+                            SmallVectorImpl<Type> &results) const {
+  for (Value value : values)
+    if (failed(convertType(value, results)))
+      return failure();
+  return success();
+}
+
 bool TypeConverter::isLegal(Type type) const {
   return convertType(type) == type;
 }
@@ -3128,7 +3169,7 @@ mlir::convertOpResultTypes(Operation *op, ValueRange operands,
   newOp.addOperands(operands);
 
   SmallVector<Type> newResultTypes;
-  if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
+  if (failed(converter.convertTypes(op->getResults(), newResultTypes)))
     return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
 
   newOp.addTypes(newResultTypes);
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index db8bd0f6378d2..7b5e6e796a528 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -142,3 +142,21 @@ func.func @test_signature_conversion_no_converter() {
   }) : () -> ()
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @context_aware_conversion()
+func.func @context_aware_conversion() {
+  // Case 1: Convert i37 --> i38.
+  // CHECK: %[[cast0:.*]] = unrealized_conversion_cast %{{.*}} : i37 to i38
+  // CHECK: "test.legal_op_d"(%[[cast0]]) : (i38) -> ()
+  %0 = "test.context_op"() {increment = 1 : i64} : () -> (i37)
+  "test.replace_with_legal_op"(%0) : (i37) -> ()
+
+  // Case 2: Convert i37 --> i39.
+  // CHECK: %[[cast1:.*]] = unrealized_conversion_cast %{{.*}} : i37 to i39
+  // CHECK: "test.legal_op_d"(%[[cast1]]) : (i39) -> ()
+  %1 = "test.context_op"() {increment = 2 : i64} : () -> (i37)
+  "test.replace_with_legal_op"(%1) : (i37) -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index d073843484d81..bd85e6fd9ae7f 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1827,9 +1827,9 @@ struct TestReplaceWithLegalOp : public ConversionPattern {
       : ConversionPattern(converter, "test.replace_with_legal_op",
                           /*benefit=*/1, ctx) {}
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]);
+    rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0].front());
     return success();
   }
 };
@@ -1865,7 +1865,7 @@ struct TestTypeConversionDriver
       return nullptr;
     });
     converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
-      // Drop all integer types.
+      // Drop all other integer types.
       return success();
     });
     converter.addConversion(
@@ -1902,6 +1902,19 @@ struct TestTypeConversionDriver
           results.push_back(result);
           return success();
         });
+    converter.addConversion([](Value v) -> std::optional<Type> {
+      auto intType = dyn_cast<IntegerType>(v.getType());
+      if (!intType || intType.getWidth() != 37)
+        return std::nullopt;
+      Operation *op = v.getDefiningOp();
+      if (!op)
+        return std::nullopt;
+      auto incrementAttr = op->getAttrOfType<IntegerAttr>("increment");
+      if (!incrementAttr)
+        return std::nullopt;
+      return IntegerType::get(v.getContext(),
+                              intType.getWidth() + incrementAttr.getInt());
+    });
 
     /// Add the legal set of type materializations.
     converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
@@ -1922,9 +1935,19 @@ struct TestTypeConversionDriver
       // Otherwise, fail.
       return nullptr;
     });
+    // Materialize i37 to any desired type with unrealized_conversion_cast.
+    converter.addTargetMaterialization([](OpBuilder &builder, Type type,
+                                          ValueRange inputs,
+                        ...
[truncated]

experiment

more
@matthias-springer matthias-springer force-pushed the users/matthias-springer/context_aware_type_conversions branch from d565695 to 727c0f7 Compare May 18, 2025 05:19
@j2kun
Contributor

j2kun commented May 18, 2025

This looks great! Do you have any idea or plan on how one would be able to do context-aware type conversion when no Value is available? (Such as a function signature that is a pure declaration)

@matthias-springer
Member Author

If the function has a body (non-external), the context-aware type converter can be called for the block arguments of the entry block. (The lowering pattern does not do this yet. There are a bunch of patterns that must be updated in follow-up commits.)

If the function does not have a body (external), this commit does not help. What would work: writing a separate conversion pattern for function ops that contains your custom type conversion logic instead of calling the type converter API.

Member

@zero9178 zero9178 left a comment


LGTM, but please wait for either @j2kun or @ZenithalHourlyRate to be happy as well.

Disabling all caching seems unfortunate to me but expected. Maybe in a future PR this could be more fine-grained, such that the algorithm is:

cache = true
foreach conversion in reverse(conversions):
   if is_context_aware(conversion):
      cache = false
   result = conversion(value)
   ... // as before

if cache:
   cacheMap[type(value)] = result

It should be safe to cache any type/value before the context-aware conversion was queried. We could then also document and recommend users to add context-aware conversions last.
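The pseudocode above can be sketched as a standalone toy model. This is not MLIR code: `Type`, `Value`, `ToyConverter`, `makeDemoConverter`, and the `rulesInvoked` counter are hypothetical stand-ins invented for illustration, and the sketch assumes that a result is safe to cache only if every rule queried so far was context-unaware.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Toy stand-ins for mlir::Type and mlir::Value; not the MLIR API.
using Type = std::string;
struct Value {
  Type type;
  int increment; // models the defining op's "increment" attribute
};

using RuleFn = std::function<std::optional<Type>(const Value &)>;

class ToyConverter {
public:
  void addConversion(bool contextAware, RuleFn fn) {
    assert(fn && "expected a callable rule");
    rules.push_back({contextAware, std::move(fn)});
    cache.clear(); // a new rule may change previously cached results
  }

  std::optional<Type> convert(const Value &v) {
    if (auto hit = cache.find(v.type); hit != cache.end())
      return hit->second;
    bool cacheable = true;
    // Most recently added rules are tried first.
    for (auto it = rules.rbegin(); it != rules.rend(); ++it) {
      if (it->contextAware)
        cacheable = false; // from here on the result may depend on the value
      ++rulesInvoked;
      if (std::optional<Type> result = it->fn(v)) {
        if (cacheable)
          cache[v.type] = *result;
        return result;
      }
    }
    return std::nullopt;
  }

  int rulesInvoked = 0; // instrumentation for the demo only

private:
  struct Rule {
    bool contextAware;
    RuleFn fn;
  };
  std::vector<Rule> rules;
  std::map<Type, Type> cache;
};

ToyConverter makeDemoConverter() {
  ToyConverter converter;
  // Added first, so queried last: i37 grows by the "increment" context.
  converter.addConversion(/*contextAware=*/true,
                          [](const Value &v) -> std::optional<Type> {
                            if (v.type != "i37")
                              return std::nullopt;
                            return "i" + std::to_string(37 + v.increment);
                          });
  // Added last, so queried first: index always becomes i64.
  converter.addConversion(/*contextAware=*/false,
                          [](const Value &v) -> std::optional<Type> {
                            if (v.type != "index")
                              return std::nullopt;
                            return Type("i64");
                          });
  return converter;
}
```

In this toy, converting two `index` values runs the rules only once, while each `i37` value is converted from scratch because the context-aware rule is reached before a result is produced.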

matthias-springer and others added 2 commits May 19, 2025 09:43
Co-authored-by: Markus Böck <markus.boeck02@gmail.com>
@matthias-springer
Member Author

matthias-springer commented May 19, 2025

It should be safe to cache any type/value before the context-aware conversion was queried. We could then also document and recommend users to add context-aware conversions last.

You probably mean "to add context-aware conversions first"? The conversions are applied in reverse order. Adding context-aware conversions first won't be possible if, let's say, you're extending the LLVM type converter.

When this becomes a performance problem, I would start by adding a Type template parameter to context-aware type conversions. Similar to what we have for context-unaware type conversions. Then we can keep track of the types that have context-aware conversions. Caching would be deactivated only for those types. I suspect that there won't be much overlap between types that require context-aware and types that require context-unaware conversions.
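A minimal standalone sketch of that per-type idea follows. It is a toy model, not MLIR code: `Type`, `Value`, `PerTypeCachingConverter`, `makePerTypeDemo`, and the `rulesInvoked` counter are hypothetical names, and keying each rule on a single source type is an assumed simplification of the proposed `Type` template parameter.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <optional>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Toy stand-ins for mlir::Type and mlir::Value; not the MLIR API.
using Type = std::string;
struct Value {
  Type type;
  int increment; // models the defining op's "increment" attribute
};

using RuleFn = std::function<std::optional<Type>(const Value &)>;

class PerTypeCachingConverter {
public:
  // Registers a rule for values of type `source`. Registering a
  // context-aware rule disables caching for that source type only.
  void addConversion(Type source, bool contextAware, RuleFn fn) {
    assert(fn && "expected a callable rule");
    if (contextAware)
      contextAwareTypes.insert(source);
    rules.push_back({std::move(source), std::move(fn)});
    cache.clear(); // a new rule may change previously cached results
  }

  std::optional<Type> convert(const Value &v) {
    const bool cacheable = contextAwareTypes.count(v.type) == 0;
    if (cacheable) {
      if (auto hit = cache.find(v.type); hit != cache.end())
        return hit->second;
    }
    // Most recently added rules are tried first.
    for (auto it = rules.rbegin(); it != rules.rend(); ++it) {
      if (it->source != v.type)
        continue; // rule is registered for a different type
      ++rulesInvoked;
      if (std::optional<Type> result = it->fn(v)) {
        if (cacheable)
          cache[v.type] = *result;
        return result;
      }
    }
    return std::nullopt;
  }

  int rulesInvoked = 0; // instrumentation for the demo only

private:
  struct Rule {
    Type source;
    RuleFn fn;
  };
  std::vector<Rule> rules;
  std::set<Type> contextAwareTypes;
  std::map<Type, Type> cache;
};

PerTypeCachingConverter makePerTypeDemo() {
  PerTypeCachingConverter converter;
  converter.addConversion("index", /*contextAware=*/false,
                          [](const Value &) -> std::optional<Type> {
                            return Type("i64");
                          });
  converter.addConversion("i37", /*contextAware=*/true,
                          [](const Value &v) -> std::optional<Type> {
                            return "i" + std::to_string(37 + v.increment);
                          });
  return converter;
}
```

Unlike a single global flag, the registration order does not matter here: `index` stays cached even though the context-aware `i37` rule was added last and is therefore queried first.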

@ZenithalHourlyRate
Member

Happy to see this! I think this API is general enough. I would like to list the scenarios that downstream projects have encountered, where a Type <-> Value correspondence either exists or does not, so we can discuss them with more care.

  1. Value that has DefiningOp (handled in this PR)
  2. Value that is BlockArgument (find parent Block/Region/Op, looking "up" in the IR hierarchy)
  3. Block Return Type (should find Block Terminator, looking "down" in the hierarchy; or should it find parent Block?)
  4. function signature without entry block (user should take care of it)
  5. function signature with entry block (looking "down" in the hierarchy into its entry block, for both argument and return type)

The context for a Type is defined either up or down in the hierarchy, which I think is not so clear and may be worth some clarification. Also, typically the context is actually an Attribute, which now may be arbitrarily discarded (some discussion in #127772).

I am a little bit concerned about the safety of accessing the context for a Value during dialect conversion. The context-finding behavior, which ought to be safe (theoretically a Value should only find the original IR during conversion), is actually unsafe. The complexity stems from Block. For Operations we keep both copies during conversion so a Value should find the original Op, but Blocks are not so well maintained: things inside a block may not be able to access the original Block, depending on the conversion order.

  1. As originally mentioned in google/heir#1527 ("Add new context-aware dialect conversion framework"), a BlockArgument Value could have no parent op after block signature conversion because the block is unlinked, resulting in a segfault when trying to access it. Type/Value legality checks are also affected.
  2. The order of conversion now matters because the context is defined both up/down the hierarchy. Whether the parent op (like a function signature) or the child (like a block signature) is converted first will affect whether a Value is able to find its context. This kind of bug is quite hard for users to debug.
  3. Such unlinked behavior could be resolved by looking up the internal mapping, but currently we are unable to do so.

@matthias-springer
Member Author

  • Block Return Type (should find Block Terminator, looking "down" in the hierarchy; or should it find parent Block?)

What is a block return type?

Also, typically the context is actually an Attribute, which now may be arbitrarily discarded (some discussion in #127772).

You may be interested in this presentation from the last MLIR workshop.

I am a little bit concerned about the safety of accessing the context for Value during dialect conversion. The context finding behavior, ought to be safe (theoretically a Value should only find the original IR during conversion),

We have a similar situation in conversion patterns: What kind of IR traversals are safe? The problem is that you are seeing a mixture of old and new IR. I'd say that analyzing the defining op of an SSA value is fine. If that is not enough context, you're getting on thin ice.

Btw, as part of the One-Shot Dialect Conversion refactoring, I plan to immediately materialize all IR changes. I.e., you would no longer see old IR. But there is still a long way to go for this refactoring...

For Operations we keep both copies during conversion so a Value should find the original Op, but Blocks are not so well maintained: things inside a block may not be able to access the original Block, depending on the conversion order.

With a One-Shot Dialect Conversion, we would no longer maintain the original ops. There would also be no unlinked/old blocks anymore.

The order of conversion now matters because the context is defined both up/down the hierarchy.

The traversal order is well-defined for a dialect conversion. (In contrast to a greedy pattern rewrite.)

Whether the parent op (like a function signature) or the child (like a block signature) is converted first will affect whether a Value is able to find its context.

Is that because once you converted an operation, it can no longer serve as "context"?

@ZenithalHourlyRate
Member

ZenithalHourlyRate commented May 19, 2025

Block Return Type (should find Block Terminator, looking "down" in the hierarchy; or should it find parent Block?)

What is a block return type?

Sorry, I was thinking about the function return type. My old code used to traverse down to the entry block terminator to get its context. A refactor added a pre-processing step that lifts that up to the defining op (with an ugly annotation on it).

I'd say that analyzing the defining op of an SSA value is fine. If that is not enough context, you're getting on thin ice.

With that pre-processing, I would agree with it, but in the current state we still need to be careful about BlockArgument. Could we add a code comment stating that, in a context-aware conversion hook, the only safe IR traversal is to the value's defining op / parent block, and that no other access is guaranteed to be safe?
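To make the restriction concrete, here is a toy sketch of a context lookup that only inspects the defining op or the parent op of the owning block, and gives up when that context is missing. `Op`, `Block`, `Value`, and `lookupWidth` are hypothetical stand-ins, not MLIR classes. Treating a missing parent as "no result" instead of dereferencing it is an assumption about how a hook could be written defensively, and it does not by itself make such a traversal safe in the real conversion driver.

```cpp
#include <cassert>
#include <optional>

// Toy stand-ins for IR entities; not the MLIR API.
struct Op {
  std::optional<int> width; // models a discardable "width" attribute
};
struct Block {
  Op *parentOp = nullptr; // nullptr models a block that has been unlinked
};
struct Value {
  Op *definingOp = nullptr;    // set for op results
  Block *ownerBlock = nullptr; // set for block arguments
};

// Returns the "width" context for `v`, or std::nullopt if no context is
// reachable. Only the defining op or the owning block's parent op is read.
std::optional<int> lookupWidth(const Value &v) {
  assert(!(v.definingOp && v.ownerBlock) &&
         "a value is either an op result or a block argument");
  const Op *contextOp = nullptr;
  if (v.definingOp)
    contextOp = v.definingOp;
  else if (v.ownerBlock)
    contextOp = v.ownerBlock->parentOp; // may be null if the block is unlinked
  if (!contextOp || !contextOp->width)
    return std::nullopt; // no context: let another rule try, or fail
  return contextOp->width;
}
```

In a real hook the analogous step would be to return `std::nullopt` from the conversion function, which, per the `addConversion` documentation in the diff above, lets the converter try another conversion function.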

Whether the parent op (like a function signature) or the child (like a block signature) is converted first will affect whether a Value is able to find its context.

Is that because once you converted an operation, it can no longer serve as "context"?

I would be curious how a context would be defined in a One-Shot Dialect Conversion if the original IR is not available.

@j2kun
Contributor

j2kun commented May 19, 2025

I think I have the same concerns as Hongren, mainly that when a region-bearing op is converted today, the block is unlinked and "analyzing the defining op of an SSA value" crashes when the value is a block argument of the old block. To get around this I had to ensure the block arguments were remapped properly, and since we were using attributes attached to the defining op, we had to ensure all our patterns copied over the context attribute (i.e., when converting the region-bearing op in question).

Perhaps you could include a context-aware conversion test that uses a region-bearing op? Maybe something like this which uses the op's attached attribute to define the width of the new type (your "increment" one is fine too).

%c0 = arith.constant {width = 64 : index} 0 : index
%0 = affine.for %arg1 = 0 to 3 iter_args(%sum = %c0) {width = 64 : index} -> index {
  %1 = arith.addi %sum, %arg1 {width = 64 : index} : index
  affine.yield %1 : index
}

In particular, if the pattern that processes arith.addi requires a precondition ensured by the pattern that processes affine.for (e.g., the attribute must be copied over), then this would be useful information to document. However, it sounds like this will not be a problem, as you mentioned removing this unlinked block behavior. So this would be a temporary warning in some sense.

@matthias-springer
Member Author

when a region-bearing op is converted today, the block is unlinked and "analyzing the defining op of an SSA value" crashes when the value is a block argument of the old block. To get around this I had to ensure the block arguments were remapped properly,

So at the point of time when the type converter is called, the block is already detached? I thought this happens after type conversion and after the replacement block has been added... How did you work around the issue exactly?

In particular, if the pattern that processes arith.addi requires a precondition ensured by the pattern that processes affine.for (e.g., the attribute must be copied over), then this would be useful information to document.

Document where? This sounds to me not like a problem of the type converter, but the usual problem of discardable attributes getting dropped.

we had to ensure all our patterns copied over the context attribute

Have you thought about turning your context attribute into a non-discardable attribute? Or wrapping SSA values in an assume op instead of attaching discardable attributes?

@ZenithalHourlyRate
Member

we had to ensure all our patterns copied over the context attribute

Have you thought about turning your context attribute into a non-discardable attribute? Or wrapping SSA values in an assume op instead of attaching discardable attributes?

Our downstream project heavily uses the arith/tensor dialects and attaches context attributes to their ops, so we are unable to convert them into non-discardable attributes. The assume op approach would prevent the canonicalizer/CSE from applying to arith/tensor ops, which we also rely on.

when a region-bearing op is converted today, the block is unlinked and "analyzing the defining op of an SSA value" crashes when the value is a block argument of the old block. To get around this I had to ensure the block arguments were remapped properly,

So at the point of time when the type converter is called, the block is already detached? I thought this happens after type conversion and after the replacement block has been added... How did you work around the issue exactly?

Technically this happens at the legality check stage, see workaround here

Labels
mlir:core MLIR Core Infrastructure mlir:scf mlir

6 participants