-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][Transforms] Dialect conversion: Context-aware type conversions #140434
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir][Transforms] Dialect conversion: Context-aware type conversions #140434
Conversation
@llvm/pr-subscribers-mlir-scf @llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) ChangesThis commit adds support for context-aware type conversions: type conversion rules that can return different types depending on the IR. There is no change for existing (context-unaware) type conversion rules:
There is now an additional overload to register context-aware type conversion rules:
For performance reasons, the type converter caches the result of type conversions. This is no longer possible when there context-aware type conversions because each conversion could compute a different type depending on the context. There is no performance degradation when there are only context-unaware type conversions. Note: This commit just adds context-aware type conversions to the dialect conversion framework. There are many existing patterns that still call Patch is 20.60 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140434.diff 6 Files Affected:
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index cf577eca5b9a6..61872d10670dc 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -235,6 +235,15 @@ target types. If the source type is converted to itself, we say it is a "legal"
type. Type conversions are specified via the `addConversion` method described
below.
+There are two kind of conversion functions: context-aware and context-unaware
+conversions. A context-unaware conversion function converts a `Type` into a
+`Type`. A context-aware conversion function converts a `Value` into a type. The
+latter allows users to customize type conversion rules based on the IR.
+
+Note: When there is at least one context-aware type conversion function, the
+result of type conversions can no longer be cached, which can increase
+compilation time. Use this feature with caution!
+
A `materialization` describes how a list of values should be converted to a
list of values with specific types. An important distinction from a
`conversion` is that a `materialization` can produce IR, whereas a `conversion`
@@ -287,29 +296,31 @@ Several of the available hooks are detailed below:
```c++
class TypeConverter {
public:
- /// Register a conversion function. A conversion function defines how a given
- /// source type should be converted. A conversion function must be convertible
- /// to any of the following forms(where `T` is a class derived from `Type`:
- /// * Optional<Type>(T)
+ /// Register a conversion function. A conversion function must be convertible
+ /// to any of the following forms (where `T` is `Value` or a class derived
+ /// from `Type`, including `Type` itself):
+ ///
+ /// * std::optional<Type>(T)
/// - This form represents a 1-1 type conversion. It should return nullptr
- /// or `std::nullopt` to signify failure. If `std::nullopt` is returned, the
- /// converter is allowed to try another conversion function to perform
- /// the conversion.
- /// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &)
+ /// or `std::nullopt` to signify failure. If `std::nullopt` is returned,
+ /// the converter is allowed to try another conversion function to
+ /// perform the conversion.
+ /// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &)
/// - This form represents a 1-N type conversion. It should return
- /// `failure` or `std::nullopt` to signify a failed conversion. If the new
- /// set of types is empty, the type is removed and any usages of the
+ /// `failure` or `std::nullopt` to signify a failed conversion. If the
+ /// new set of types is empty, the type is removed and any usages of the
/// existing value are expected to be removed during conversion. If
/// `std::nullopt` is returned, the converter is allowed to try another
/// conversion function to perform the conversion.
- /// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &, ArrayRef<Type>)
- /// - This form represents a 1-N type conversion supporting recursive
- /// types. The first two arguments and the return value are the same as
- /// for the regular 1-N form. The third argument is contains is the
- /// "call stack" of the recursive conversion: it contains the list of
- /// types currently being converted, with the current type being the
- /// last one. If it is present more than once in the list, the
- /// conversion concerns a recursive type.
+ ///
+ /// Conversion functions that accept `Value` as the first argument are
+ /// context-aware. I.e., they can take into account IR when converting the
+ /// type of the given value. Context-unaware conversion functions accept
+ /// `Type` or a derived class as the first argument.
+ ///
+ /// Note: Context-unaware conversions are cached, but context-aware
+ /// conversions are not.
+ ///
/// Note: When attempting to convert a type, e.g. via 'convertType', the
/// mostly recently added conversions will be invoked first.
template <typename FnT,
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index e7d05c3ce1adf..07adbde3a5a60 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -18,6 +18,7 @@
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/StringMap.h"
#include <type_traits>
+#include <variant>
namespace mlir {
@@ -139,7 +140,8 @@ class TypeConverter {
};
/// Register a conversion function. A conversion function must be convertible
- /// to any of the following forms (where `T` is a class derived from `Type`):
+ /// to any of the following forms (where `T` is `Value` or a class derived
+ /// from `Type`, including `Type` itself):
///
/// * std::optional<Type>(T)
/// - This form represents a 1-1 type conversion. It should return nullptr
@@ -154,6 +156,14 @@ class TypeConverter {
/// `std::nullopt` is returned, the converter is allowed to try another
/// conversion function to perform the conversion.
///
+ /// Conversion functions that accept `Value` as the first argument are
+ /// context-aware. I.e., they can take into account IR when converting the
+ /// type of the given value. Context-unaware conversion functions accept
+ /// `Type` or a derived class as the first argument.
+ ///
+ /// Note: Context-unaware conversions are cached, but context-aware
+ /// conversions are not.
+ ///
/// Note: When attempting to convert a type, e.g. via 'convertType', the
/// mostly recently added conversions will be invoked first.
template <typename FnT, typename T = typename llvm::function_traits<
@@ -241,15 +251,28 @@ class TypeConverter {
wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
}
- /// Convert the given type. This function should return failure if no valid
+ /// Convert the given type. This function returns failure if no valid
/// conversion exists, success otherwise. If the new set of types is empty,
/// the type is removed and any usages of the existing value are expected to
/// be removed during conversion.
+ ///
+ /// Note: This overload invokes only context-unaware type conversion
+ /// functions. Users should call the other overload if possible.
LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) const;
+ /// Convert the type of the given value. This function returns failure if no
+ /// valid conversion exists, success otherwise. If the new set of types is
+ /// empty, the type is removed and any usages of the existing value are
+ /// expected to be removed during conversion.
+ ///
+ /// Note: This overload invokes both context-aware and context-unaware type
+ /// conversion functions.
+ LogicalResult convertType(Value v, SmallVectorImpl<Type> &results) const;
+
/// This hook simplifies defining 1-1 type conversions. This function returns
/// the type to convert to on success, and a null type on failure.
Type convertType(Type t) const;
+ Type convertType(Value v) const;
/// Attempts a 1-1 type conversion, expecting the result type to be
/// `TargetType`. Returns the converted type cast to `TargetType` on success,
@@ -258,13 +281,23 @@ class TypeConverter {
TargetType convertType(Type t) const {
return dyn_cast_or_null<TargetType>(convertType(t));
}
+ template <typename TargetType>
+ TargetType convertType(Value v) const {
+ return dyn_cast_or_null<TargetType>(convertType(v));
+ }
- /// Convert the given set of types, filling 'results' as necessary. This
- /// returns failure if the conversion of any of the types fails, success
+ /// Convert the given types, filling 'results' as necessary. This returns
+ /// "failure" if the conversion of any of the types fails, "success"
/// otherwise.
LogicalResult convertTypes(TypeRange types,
SmallVectorImpl<Type> &results) const;
+ /// Convert the types of the given values, filling 'results' as necessary.
+ /// This returns "failure" if the conversion of any of the types fails,
+ /// "success" otherwise.
+ LogicalResult convertTypes(ValueRange values,
+ SmallVectorImpl<Type> &results) const;
+
/// Return true if the given type is legal for this type converter, i.e. the
/// type converts to itself.
bool isLegal(Type type) const;
@@ -328,7 +361,7 @@ class TypeConverter {
/// types is empty, the type is removed and any usages of the existing value
/// are expected to be removed during conversion.
using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
- Type, SmallVectorImpl<Type> &)>;
+ std::variant<Type, Value>, SmallVectorImpl<Type> &)>;
/// The signature of the callback used to materialize a source conversion.
///
@@ -348,13 +381,14 @@ class TypeConverter {
/// Generate a wrapper for the given callback. This allows for accepting
/// different callback forms, that all compose into a single version.
- /// With callback of form: `std::optional<Type>(T)`
+ /// With callback of form: `std::optional<Type>(T)`, where `T` can be a
+ /// `Value` or a `Type` (or a class derived from `Type`).
template <typename T, typename FnT>
std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
- wrapCallback(FnT &&callback) const {
+ wrapCallback(FnT &&callback) {
return wrapCallback<T>([callback = std::forward<FnT>(callback)](
- T type, SmallVectorImpl<Type> &results) {
- if (std::optional<Type> resultOpt = callback(type)) {
+ T typeOrValue, SmallVectorImpl<Type> &results) {
+ if (std::optional<Type> resultOpt = callback(typeOrValue)) {
bool wasSuccess = static_cast<bool>(*resultOpt);
if (wasSuccess)
results.push_back(*resultOpt);
@@ -364,20 +398,49 @@ class TypeConverter {
});
}
/// With callback of form: `std::optional<LogicalResult>(
- /// T, SmallVectorImpl<Type> &, ArrayRef<Type>)`.
+ /// T, SmallVectorImpl<Type> &)`, where `T` is a type.
template <typename T, typename FnT>
- std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
+ std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
+ std::is_base_of_v<Type, T>,
ConversionCallbackFn>
wrapCallback(FnT &&callback) const {
return [callback = std::forward<FnT>(callback)](
- Type type,
+ std::variant<Type, Value> type,
SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
- T derivedType = dyn_cast<T>(type);
+ T derivedType;
+ if (Type *t = std::get_if<Type>(&type)) {
+ derivedType = dyn_cast<T>(*t);
+ } else if (Value *v = std::get_if<Value>(&type)) {
+ derivedType = dyn_cast<T>(v->getType());
+ } else {
+ llvm_unreachable("unexpected variant");
+ }
if (!derivedType)
return std::nullopt;
return callback(derivedType, results);
};
}
+ /// With callback of form: `std::optional<LogicalResult>(
+ /// T, SmallVectorImpl<Type>)`, where `T` is a `Value`.
+ template <typename T, typename FnT>
+ std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
+ std::is_same_v<T, Value>,
+ ConversionCallbackFn>
+ wrapCallback(FnT &&callback) {
+ hasContextAwareTypeConversions = true;
+ return [callback = std::forward<FnT>(callback)](
+ std::variant<Type, Value> type,
+ SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
+ if (Type *t = std::get_if<Type>(&type)) {
+ // Context-aware type conversion was called with a type.
+ return std::nullopt;
+ } else if (Value *v = std::get_if<Value>(&type)) {
+ return callback(*v, results);
+ }
+ llvm_unreachable("unexpected variant");
+ return std::nullopt;
+ };
+ }
/// Register a type conversion.
void registerConversion(ConversionCallbackFn callback) {
@@ -504,6 +567,12 @@ class TypeConverter {
mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
/// A mutex used for cache access
mutable llvm::sys::SmartRWMutex<true> cacheMutex;
+ /// Whether the type converter has context-aware type conversions. I.e.,
+ /// conversion rules that depend on the SSA value instead of just the type.
+ /// Type conversion caching is deactivated when there are context-aware
+ /// conversions because the type converter may return different results for
+ /// the same input type.
+ bool hasContextAwareTypeConversions = false;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 09326242eec2a..de4612fa0846a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -52,8 +52,8 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
SmallVector<unsigned> offsets;
offsets.push_back(0);
// Do the type conversion and record the offsets.
- for (Type type : op.getResultTypes()) {
- if (failed(typeConverter->convertTypes(type, dstTypes)))
+ for (Value v : op.getResults()) {
+ if (failed(typeConverter->convertType(v, dstTypes)))
return rewriter.notifyMatchFailure(op, "could not convert result type");
offsets.push_back(dstTypes.size());
}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index bd11bbe58a3f6..2a1d154faeaf3 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1256,7 +1256,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
// If there is no legal conversion, fail to match this pattern.
SmallVector<Type, 1> legalTypes;
- if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
+ if (failed(currentTypeConverter->convertType(operand, legalTypes))) {
notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
diag << "unable to convert type for " << valueDiagTag << " #"
<< it.index() << ", type was " << origType;
@@ -2899,6 +2899,28 @@ LogicalResult TypeConverter::convertType(Type t,
return failure();
}
+LogicalResult TypeConverter::convertType(Value v,
+ SmallVectorImpl<Type> &results) const {
+ assert(v && "expected non-null value");
+
+ // If this type converter does not have context-aware type conversions, call
+ // the type-based overload, which has caching.
+ if (!hasContextAwareTypeConversions) {
+ return convertType(v.getType(), results);
+ }
+
+ // Walk the added converters in reverse order to apply the most recently
+ // registered first.
+ for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
+ if (std::optional<LogicalResult> result = converter(v, results)) {
+ if (!succeeded(*result))
+ return failure();
+ return success();
+ }
+ }
+ return failure();
+}
+
Type TypeConverter::convertType(Type t) const {
// Use the multi-type result version to convert the type.
SmallVector<Type, 1> results;
@@ -2909,6 +2931,16 @@ Type TypeConverter::convertType(Type t) const {
return results.size() == 1 ? results.front() : nullptr;
}
+Type TypeConverter::convertType(Value v) const {
+ // Use the multi-type result version to convert the type.
+ SmallVector<Type, 1> results;
+ if (failed(convertType(v, results)))
+ return nullptr;
+
+ // Check to ensure that only one type was produced.
+ return results.size() == 1 ? results.front() : nullptr;
+}
+
LogicalResult
TypeConverter::convertTypes(TypeRange types,
SmallVectorImpl<Type> &results) const {
@@ -2918,6 +2950,15 @@ TypeConverter::convertTypes(TypeRange types,
return success();
}
+LogicalResult
+TypeConverter::convertTypes(ValueRange values,
+ SmallVectorImpl<Type> &results) const {
+ for (Value value : values)
+ if (failed(convertType(value, results)))
+ return failure();
+ return success();
+}
+
bool TypeConverter::isLegal(Type type) const {
return convertType(type) == type;
}
@@ -3128,7 +3169,7 @@ mlir::convertOpResultTypes(Operation *op, ValueRange operands,
newOp.addOperands(operands);
SmallVector<Type> newResultTypes;
- if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
+ if (failed(converter.convertTypes(op->getResults(), newResultTypes)))
return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
newOp.addTypes(newResultTypes);
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index db8bd0f6378d2..7b5e6e796a528 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -142,3 +142,21 @@ func.func @test_signature_conversion_no_converter() {
}) : () -> ()
return
}
+
+// -----
+
+// CHECK-LABEL: func @context_aware_conversion()
+func.func @context_aware_conversion() {
+ // Case 1: Convert i37 --> i38.
+ // CHECK: %[[cast0:.*]] = unrealized_conversion_cast %{{.*}} : i37 to i38
+ // CHECK: "test.legal_op_d"(%[[cast0]]) : (i38) -> ()
+ %0 = "test.context_op"() {increment = 1 : i64} : () -> (i37)
+ "test.replace_with_legal_op"(%0) : (i37) -> ()
+
+ // Case 2: Convert i37 --> i39.
+ // CHECK: %[[cast1:.*]] = unrealized_conversion_cast %{{.*}} : i37 to i39
+ // CHECK: "test.legal_op_d"(%[[cast1]]) : (i39) -> ()
+ %1 = "test.context_op"() {increment = 2 : i64} : () -> (i37)
+ "test.replace_with_legal_op"(%1) : (i37) -> ()
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index d073843484d81..bd85e6fd9ae7f 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1827,9 +1827,9 @@ struct TestReplaceWithLegalOp : public ConversionPattern {
: ConversionPattern(converter, "test.replace_with_legal_op",
/*benefit=*/1, ctx) {}
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
- rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]);
+ rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0].front());
return success();
}
};
@@ -1865,7 +1865,7 @@ struct TestTypeConversionDriver
return nullptr;
});
converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
- // Drop all integer types.
+ // Drop all other integer types.
return success();
});
converter.addConversion(
@@ -1902,6 +1902,19 @@ struct TestTypeConversionDriver
results.push_back(result);
return success();
});
+ converter.addConversion([](Value v) -> std::optional<Type> {
+ auto intType = dyn_cast<IntegerType>(v.getType());
+ if (!intType || intType.getWidth() != 37)
+ return std::nullopt;
+ Operation *op = v.getDefiningOp();
+ if (!op)
+ return std::nullopt;
+ auto incrementAttr = op->getAttrOfType<IntegerAttr>("increment");
+ if (!incrementAttr)
+ return std::nullopt;
+ return IntegerType::get(v.getContext(),
+ intType.getWidth() + incrementAttr.getInt());
+ });
/// Add the legal set of type materializations.
converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
@@ -1922,9 +1935,19 @@ struct TestTypeConversionDriver
// Otherwise, fail.
return nullptr;
});
+ // Materialize i37 to any desired type with unrealized_conversion_cast.
+ converter.addTargetMaterialization([](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ ...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesThis commit adds support for context-aware type conversions: type conversion rules that can return different types depending on the IR. There is no change for existing (context-unaware) type conversion rules:
There is now an additional overload to register context-aware type conversion rules:
For performance reasons, the type converter caches the result of type conversions. This is no longer possible when there context-aware type conversions because each conversion could compute a different type depending on the context. There is no performance degradation when there are only context-unaware type conversions. Note: This commit just adds context-aware type conversions to the dialect conversion framework. There are many existing patterns that still call Patch is 20.60 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140434.diff 6 Files Affected:
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index cf577eca5b9a6..61872d10670dc 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -235,6 +235,15 @@ target types. If the source type is converted to itself, we say it is a "legal"
type. Type conversions are specified via the `addConversion` method described
below.
+There are two kind of conversion functions: context-aware and context-unaware
+conversions. A context-unaware conversion function converts a `Type` into a
+`Type`. A context-aware conversion function converts a `Value` into a type. The
+latter allows users to customize type conversion rules based on the IR.
+
+Note: When there is at least one context-aware type conversion function, the
+result of type conversions can no longer be cached, which can increase
+compilation time. Use this feature with caution!
+
A `materialization` describes how a list of values should be converted to a
list of values with specific types. An important distinction from a
`conversion` is that a `materialization` can produce IR, whereas a `conversion`
@@ -287,29 +296,31 @@ Several of the available hooks are detailed below:
```c++
class TypeConverter {
public:
- /// Register a conversion function. A conversion function defines how a given
- /// source type should be converted. A conversion function must be convertible
- /// to any of the following forms(where `T` is a class derived from `Type`:
- /// * Optional<Type>(T)
+ /// Register a conversion function. A conversion function must be convertible
+ /// to any of the following forms (where `T` is `Value` or a class derived
+ /// from `Type`, including `Type` itself):
+ ///
+ /// * std::optional<Type>(T)
/// - This form represents a 1-1 type conversion. It should return nullptr
- /// or `std::nullopt` to signify failure. If `std::nullopt` is returned, the
- /// converter is allowed to try another conversion function to perform
- /// the conversion.
- /// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &)
+ /// or `std::nullopt` to signify failure. If `std::nullopt` is returned,
+ /// the converter is allowed to try another conversion function to
+ /// perform the conversion.
+ /// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &)
/// - This form represents a 1-N type conversion. It should return
- /// `failure` or `std::nullopt` to signify a failed conversion. If the new
- /// set of types is empty, the type is removed and any usages of the
+ /// `failure` or `std::nullopt` to signify a failed conversion. If the
+ /// new set of types is empty, the type is removed and any usages of the
/// existing value are expected to be removed during conversion. If
/// `std::nullopt` is returned, the converter is allowed to try another
/// conversion function to perform the conversion.
- /// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &, ArrayRef<Type>)
- /// - This form represents a 1-N type conversion supporting recursive
- /// types. The first two arguments and the return value are the same as
- /// for the regular 1-N form. The third argument is contains is the
- /// "call stack" of the recursive conversion: it contains the list of
- /// types currently being converted, with the current type being the
- /// last one. If it is present more than once in the list, the
- /// conversion concerns a recursive type.
+ ///
+ /// Conversion functions that accept `Value` as the first argument are
+ /// context-aware. I.e., they can take into account IR when converting the
+ /// type of the given value. Context-unaware conversion functions accept
+ /// `Type` or a derived class as the first argument.
+ ///
+ /// Note: Context-unaware conversions are cached, but context-aware
+ /// conversions are not.
+ ///
/// Note: When attempting to convert a type, e.g. via 'convertType', the
/// mostly recently added conversions will be invoked first.
template <typename FnT,
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index e7d05c3ce1adf..07adbde3a5a60 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -18,6 +18,7 @@
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/StringMap.h"
#include <type_traits>
+#include <variant>
namespace mlir {
@@ -139,7 +140,8 @@ class TypeConverter {
};
/// Register a conversion function. A conversion function must be convertible
- /// to any of the following forms (where `T` is a class derived from `Type`):
+ /// to any of the following forms (where `T` is `Value` or a class derived
+ /// from `Type`, including `Type` itself):
///
/// * std::optional<Type>(T)
/// - This form represents a 1-1 type conversion. It should return nullptr
@@ -154,6 +156,14 @@ class TypeConverter {
/// `std::nullopt` is returned, the converter is allowed to try another
/// conversion function to perform the conversion.
///
+ /// Conversion functions that accept `Value` as the first argument are
+ /// context-aware. I.e., they can take into account IR when converting the
+ /// type of the given value. Context-unaware conversion functions accept
+ /// `Type` or a derived class as the first argument.
+ ///
+ /// Note: Context-unaware conversions are cached, but context-aware
+ /// conversions are not.
+ ///
/// Note: When attempting to convert a type, e.g. via 'convertType', the
/// mostly recently added conversions will be invoked first.
template <typename FnT, typename T = typename llvm::function_traits<
@@ -241,15 +251,28 @@ class TypeConverter {
wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
}
- /// Convert the given type. This function should return failure if no valid
+ /// Convert the given type. This function returns failure if no valid
/// conversion exists, success otherwise. If the new set of types is empty,
/// the type is removed and any usages of the existing value are expected to
/// be removed during conversion.
+ ///
+ /// Note: This overload invokes only context-unaware type conversion
+ /// functions. Users should call the other overload if possible.
LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) const;
+ /// Convert the type of the given value. This function returns failure if no
+ /// valid conversion exists, success otherwise. If the new set of types is
+ /// empty, the type is removed and any usages of the existing value are
+ /// expected to be removed during conversion.
+ ///
+ /// Note: This overload invokes both context-aware and context-unaware type
+ /// conversion functions.
+ LogicalResult convertType(Value v, SmallVectorImpl<Type> &results) const;
+
/// This hook simplifies defining 1-1 type conversions. This function returns
/// the type to convert to on success, and a null type on failure.
Type convertType(Type t) const;
+ Type convertType(Value v) const;
/// Attempts a 1-1 type conversion, expecting the result type to be
/// `TargetType`. Returns the converted type cast to `TargetType` on success,
@@ -258,13 +281,23 @@ class TypeConverter {
TargetType convertType(Type t) const {
return dyn_cast_or_null<TargetType>(convertType(t));
}
+ template <typename TargetType>
+ TargetType convertType(Value v) const {
+ return dyn_cast_or_null<TargetType>(convertType(v));
+ }
- /// Convert the given set of types, filling 'results' as necessary. This
- /// returns failure if the conversion of any of the types fails, success
+ /// Convert the given types, filling 'results' as necessary. This returns
+ /// "failure" if the conversion of any of the types fails, "success"
/// otherwise.
LogicalResult convertTypes(TypeRange types,
SmallVectorImpl<Type> &results) const;
+ /// Convert the types of the given values, filling 'results' as necessary.
+ /// This returns "failure" if the conversion of any of the types fails,
+ /// "success" otherwise.
+ LogicalResult convertTypes(ValueRange values,
+ SmallVectorImpl<Type> &results) const;
+
/// Return true if the given type is legal for this type converter, i.e. the
/// type converts to itself.
bool isLegal(Type type) const;
@@ -328,7 +361,7 @@ class TypeConverter {
/// types is empty, the type is removed and any usages of the existing value
/// are expected to be removed during conversion.
using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
- Type, SmallVectorImpl<Type> &)>;
+ std::variant<Type, Value>, SmallVectorImpl<Type> &)>;
/// The signature of the callback used to materialize a source conversion.
///
@@ -348,13 +381,14 @@ class TypeConverter {
/// Generate a wrapper for the given callback. This allows for accepting
/// different callback forms, that all compose into a single version.
- /// With callback of form: `std::optional<Type>(T)`
+ /// With callback of form: `std::optional<Type>(T)`, where `T` can be a
+ /// `Value` or a `Type` (or a class derived from `Type`).
template <typename T, typename FnT>
std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
- wrapCallback(FnT &&callback) const {
+ wrapCallback(FnT &&callback) {
return wrapCallback<T>([callback = std::forward<FnT>(callback)](
- T type, SmallVectorImpl<Type> &results) {
- if (std::optional<Type> resultOpt = callback(type)) {
+ T typeOrValue, SmallVectorImpl<Type> &results) {
+ if (std::optional<Type> resultOpt = callback(typeOrValue)) {
bool wasSuccess = static_cast<bool>(*resultOpt);
if (wasSuccess)
results.push_back(*resultOpt);
@@ -364,20 +398,49 @@ class TypeConverter {
});
}
/// With callback of form: `std::optional<LogicalResult>(
- /// T, SmallVectorImpl<Type> &, ArrayRef<Type>)`.
+ /// T, SmallVectorImpl<Type> &)`, where `T` is a type.
template <typename T, typename FnT>
- std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
+ std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
+ std::is_base_of_v<Type, T>,
ConversionCallbackFn>
wrapCallback(FnT &&callback) const {
return [callback = std::forward<FnT>(callback)](
- Type type,
+ std::variant<Type, Value> type,
SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
- T derivedType = dyn_cast<T>(type);
+ T derivedType;
+ if (Type *t = std::get_if<Type>(&type)) {
+ derivedType = dyn_cast<T>(*t);
+ } else if (Value *v = std::get_if<Value>(&type)) {
+ derivedType = dyn_cast<T>(v->getType());
+ } else {
+ llvm_unreachable("unexpected variant");
+ }
if (!derivedType)
return std::nullopt;
return callback(derivedType, results);
};
}
+ /// With callback of form: `std::optional<LogicalResult>(
+ /// T, SmallVectorImpl<Type>)`, where `T` is a `Value`.
+ template <typename T, typename FnT>
+ std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
+ std::is_same_v<T, Value>,
+ ConversionCallbackFn>
+ wrapCallback(FnT &&callback) {
+ hasContextAwareTypeConversions = true;
+ return [callback = std::forward<FnT>(callback)](
+ std::variant<Type, Value> type,
+ SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
+ if (Type *t = std::get_if<Type>(&type)) {
+ // Context-aware type conversion was called with a type.
+ return std::nullopt;
+ } else if (Value *v = std::get_if<Value>(&type)) {
+ return callback(*v, results);
+ }
+ llvm_unreachable("unexpected variant");
+ return std::nullopt;
+ };
+ }
/// Register a type conversion.
void registerConversion(ConversionCallbackFn callback) {
@@ -504,6 +567,12 @@ class TypeConverter {
mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
/// A mutex used for cache access
mutable llvm::sys::SmartRWMutex<true> cacheMutex;
+ /// Whether the type converter has context-aware type conversions. I.e.,
+ /// conversion rules that depend on the SSA value instead of just the type.
+ /// Type conversion caching is deactivated when there are context-aware
+ /// conversions because the type converter may return different results for
+ /// the same input type.
+ bool hasContextAwareTypeConversions = false;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 09326242eec2a..de4612fa0846a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -52,8 +52,8 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
SmallVector<unsigned> offsets;
offsets.push_back(0);
// Do the type conversion and record the offsets.
- for (Type type : op.getResultTypes()) {
- if (failed(typeConverter->convertTypes(type, dstTypes)))
+ for (Value v : op.getResults()) {
+ if (failed(typeConverter->convertType(v, dstTypes)))
return rewriter.notifyMatchFailure(op, "could not convert result type");
offsets.push_back(dstTypes.size());
}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index bd11bbe58a3f6..2a1d154faeaf3 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1256,7 +1256,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
// If there is no legal conversion, fail to match this pattern.
SmallVector<Type, 1> legalTypes;
- if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
+ if (failed(currentTypeConverter->convertType(operand, legalTypes))) {
notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
diag << "unable to convert type for " << valueDiagTag << " #"
<< it.index() << ", type was " << origType;
@@ -2899,6 +2899,28 @@ LogicalResult TypeConverter::convertType(Type t,
return failure();
}
+LogicalResult TypeConverter::convertType(Value v,
+ SmallVectorImpl<Type> &results) const {
+ assert(v && "expected non-null value");
+
+ // If this type converter does not have context-aware type conversions, call
+ // the type-based overload, which has caching.
+ if (!hasContextAwareTypeConversions) {
+ return convertType(v.getType(), results);
+ }
+
+ // Walk the added converters in reverse order to apply the most recently
+ // registered first.
+ for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
+ if (std::optional<LogicalResult> result = converter(v, results)) {
+ if (!succeeded(*result))
+ return failure();
+ return success();
+ }
+ }
+ return failure();
+}
+
Type TypeConverter::convertType(Type t) const {
// Use the multi-type result version to convert the type.
SmallVector<Type, 1> results;
@@ -2909,6 +2931,16 @@ Type TypeConverter::convertType(Type t) const {
return results.size() == 1 ? results.front() : nullptr;
}
+Type TypeConverter::convertType(Value v) const {
+ // Use the multi-type result version to convert the type.
+ SmallVector<Type, 1> results;
+ if (failed(convertType(v, results)))
+ return nullptr;
+
+ // Check to ensure that only one type was produced.
+ return results.size() == 1 ? results.front() : nullptr;
+}
+
LogicalResult
TypeConverter::convertTypes(TypeRange types,
SmallVectorImpl<Type> &results) const {
@@ -2918,6 +2950,15 @@ TypeConverter::convertTypes(TypeRange types,
return success();
}
+LogicalResult
+TypeConverter::convertTypes(ValueRange values,
+ SmallVectorImpl<Type> &results) const {
+ for (Value value : values)
+ if (failed(convertType(value, results)))
+ return failure();
+ return success();
+}
+
bool TypeConverter::isLegal(Type type) const {
return convertType(type) == type;
}
@@ -3128,7 +3169,7 @@ mlir::convertOpResultTypes(Operation *op, ValueRange operands,
newOp.addOperands(operands);
SmallVector<Type> newResultTypes;
- if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
+ if (failed(converter.convertTypes(op->getResults(), newResultTypes)))
return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
newOp.addTypes(newResultTypes);
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index db8bd0f6378d2..7b5e6e796a528 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -142,3 +142,21 @@ func.func @test_signature_conversion_no_converter() {
}) : () -> ()
return
}
+
+// -----
+
+// CHECK-LABEL: func @context_aware_conversion()
+func.func @context_aware_conversion() {
+ // Case 1: Convert i37 --> i38.
+ // CHECK: %[[cast0:.*]] = unrealized_conversion_cast %{{.*}} : i37 to i38
+ // CHECK: "test.legal_op_d"(%[[cast0]]) : (i38) -> ()
+ %0 = "test.context_op"() {increment = 1 : i64} : () -> (i37)
+ "test.replace_with_legal_op"(%0) : (i37) -> ()
+
+ // Case 2: Convert i37 --> i39.
+ // CHECK: %[[cast1:.*]] = unrealized_conversion_cast %{{.*}} : i37 to i39
+ // CHECK: "test.legal_op_d"(%[[cast1]]) : (i39) -> ()
+ %1 = "test.context_op"() {increment = 2 : i64} : () -> (i37)
+ "test.replace_with_legal_op"(%1) : (i37) -> ()
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index d073843484d81..bd85e6fd9ae7f 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1827,9 +1827,9 @@ struct TestReplaceWithLegalOp : public ConversionPattern {
: ConversionPattern(converter, "test.replace_with_legal_op",
/*benefit=*/1, ctx) {}
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
- rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]);
+ rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0].front());
return success();
}
};
@@ -1865,7 +1865,7 @@ struct TestTypeConversionDriver
return nullptr;
});
converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
- // Drop all integer types.
+ // Drop all other integer types.
return success();
});
converter.addConversion(
@@ -1902,6 +1902,19 @@ struct TestTypeConversionDriver
results.push_back(result);
return success();
});
+ converter.addConversion([](Value v) -> std::optional<Type> {
+ auto intType = dyn_cast<IntegerType>(v.getType());
+ if (!intType || intType.getWidth() != 37)
+ return std::nullopt;
+ Operation *op = v.getDefiningOp();
+ if (!op)
+ return std::nullopt;
+ auto incrementAttr = op->getAttrOfType<IntegerAttr>("increment");
+ if (!incrementAttr)
+ return std::nullopt;
+ return IntegerType::get(v.getContext(),
+ intType.getWidth() + incrementAttr.getInt());
+ });
/// Add the legal set of type materializations.
converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
@@ -1922,9 +1935,19 @@ struct TestTypeConversionDriver
// Otherwise, fail.
return nullptr;
});
+ // Materialize i37 to any desired type with unrealized_conversion_cast.
+ converter.addTargetMaterialization([](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ ...
[truncated]
|
d565695
to
727c0f7
Compare
This looks great! Do you have any idea or plan on how one would be able to do context-aware type conversion when no |
If the function has a body (non-external), the context-aware type converter can be called for the block arguments of the entry block. (The lowering pattern does not do this yet. There are a bunch of patterns that must be updated in follow-up commits.) If the function does not have a body (external), this commit does not help. What would work: writing a separate conversion pattern for function ops that contains your custom type conversion logic instead of calling the type converter API. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, but please wait for either @j2kun or @ZenithalHourlyRate to be happy as well.
Disabling all caching seems unfortunate to me but expected. Maybe in a future PR this could be more fine grained such that the algorithm is:
cache = true
foreach conversion in reverse(conversions):
if is_context_aware(conversion):
cache = false
result = conversion(value)
... // as before
if cache:
cacheMap[type(value)] = result
It should be safe to cache any type/value before the type-aware conversion was queried. We could then also document and recommend users to add context-aware conversions last.
Co-authored-by: Markus Böck <markus.boeck02@gmail.com>
You probably mean "to add context-aware conversions first"? The conversions are applied in reverse order. Adding context-aware conversion first won't be possible if, let's say you're extending the LLVM type converter. When this becomes a performance problem, I would start by adding a |
Happy to see this! I think this API is general enough. I would like to list scenarios downstream project has encountered where
The context for a Type is defined either up or down in the hierarchy, which I think is not so clear and may worth some clarification. Also, typically the context is actually an I am a little bit concerned about the safety of accessing the context for
|
What is a block return type?
You may be interested in this presentation from the last MLIR workshop.
We have a similar situation in conversion patterns: What kind of IR traversals are safe? The problem is that you are seeing a mixture of old and new IR. I'd say that analyzing the defining op of an SSA value is fine. If that is not enough context, you're getting on thin ice. Btw, as part of the One-Shot Dialect Conversion refactoring, I plan to immediately materialize all IR changes. I.e., you would no longer see old IR. But there is still a long way to go for this refactoring...
With a One-Shot Dialect Conversion, we would no longer maintain the original ops. There would also be no unlinked/old blocks anymore.
The traversal order is well-defined for a dialect conversion. (In contrast to a greedy pattern rewrite.)
Is that because once you converted an operation, it can no longer serve as "context"? |
Sorry I was thinking about function return type and my old code used to traverse down to the entry block terminator to get its context. A refactor added a pre-processing to lift that up to the defining op (with ugly annotation to it).
With that pre-processing, I would agree with it, but by the current state we still need to be careful about BlockArgument. We could add a comment in code that in context-aware conversion hook the only safe IR tarversing is its defining op / parent block but no other access is guaranteed to be safe?
I would be curious in one-shot dialect conversion how would a context be defined, if the original IR is not available. |
I think I have the same concerns as Hongren, mainly that when a region-bearing op is converted today, the block is unlinked and "analyzing the defining op of an SSA value" crashes when the value is a block argument of the old block. To get around this I had to ensure the block arguments were remapped properly, and since we were using attributes attached to the defining op, we had to ensure all our patterns copied over the context attribute (i.e., when converting the region-bearing op in question). Perhaps you could include a context-aware conversion test that uses a region-bearing op? Maybe something like this which uses the op's attached attribute to define the width of the new type (your "increment" one is fine too). %c0 = arith.constant {width = 64 : index} 0 : index
%0 = affine.for %arg1 = 0 to 3 iter_args(%sum = %c0) {width = 64 : index} -> index {
%1 = arith.addi %sum, %arg1 {width = 64 : index} : index
affine.yield %1 : index
} In particular, if the pattern that processes |
So at the point of time when the type converter is called, the block is already detached? I thought this happens after type conversion and after the replacement block has been added... How did you work around the issue exactly?
Document where? This sounds to me not like a problem of the type converter, but the usual problem of discardable attributes getting dropped.
Have you thought about turning your context attribute into a non-discardable attribute? Or wrapping SSA values in an |
Our downstream project heavily use
Technically this happens at the legality check stage, see workaround here |
This commit adds support for context-aware type conversions: type conversion rules that can return different types depending on the IR.
There is no change for existing (context-unaware) type conversion rules:
There is now an additional overload to register context-aware type conversion rules:
For performance reasons, the type converter caches the result of type conversions. This is no longer possible when there context-aware type conversions because each conversion could compute a different type depending on the context. There is no performance degradation when there are only context-unaware type conversions.
Note: This commit just adds context-aware type conversions to the dialect conversion framework. There are many existing patterns that still call
converter.convertType(someValue.getType())
. These should be gradually updated in subsequent commits to callconverter.convertType(someValue)
.