Skip to content

[mlir][Transforms] Dialect conversion: Context-aware type conversions #140434

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 29 additions & 18 deletions mlir/docs/DialectConversion.md
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,15 @@ target types. If the source type is converted to itself, we say it is a "legal"
type. Type conversions are specified via the `addConversion` method described
below.

There are two kind of conversion functions: context-aware and context-unaware
conversions. A context-unaware conversion function converts a `Type` into a
`Type`. A context-aware conversion function converts a `Value` into a type. The
latter allows users to customize type conversion rules based on the IR.

Note: When there is at least one context-aware type conversion function, the
result of type conversions can no longer be cached, which can increase
compilation time. Use this feature with caution!

A `materialization` describes how a list of values should be converted to a
list of values with specific types. An important distinction from a
`conversion` is that a `materialization` can produce IR, whereas a `conversion`
Expand Down Expand Up @@ -287,29 +296,31 @@ Several of the available hooks are detailed below:
```c++
class TypeConverter {
public:
/// Register a conversion function. A conversion function defines how a given
/// source type should be converted. A conversion function must be convertible
/// to any of the following forms(where `T` is a class derived from `Type`:
/// * Optional<Type>(T)
/// Register a conversion function. A conversion function must be convertible
/// to any of the following forms (where `T` is `Value` or a class derived
/// from `Type`, including `Type` itself):
///
/// * std::optional<Type>(T)
/// - This form represents a 1-1 type conversion. It should return nullptr
/// or `std::nullopt` to signify failure. If `std::nullopt` is returned, the
/// converter is allowed to try another conversion function to perform
/// the conversion.
/// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &)
/// or `std::nullopt` to signify failure. If `std::nullopt` is returned,
/// the converter is allowed to try another conversion function to
/// perform the conversion.
/// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &)
/// - This form represents a 1-N type conversion. It should return
/// `failure` or `std::nullopt` to signify a failed conversion. If the new
/// set of types is empty, the type is removed and any usages of the
/// `failure` or `std::nullopt` to signify a failed conversion. If the
/// new set of types is empty, the type is removed and any usages of the
/// existing value are expected to be removed during conversion. If
/// `std::nullopt` is returned, the converter is allowed to try another
/// conversion function to perform the conversion.
/// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &, ArrayRef<Type>)
/// - This form represents a 1-N type conversion supporting recursive
/// types. The first two arguments and the return value are the same as
/// for the regular 1-N form. The third argument is contains is the
/// "call stack" of the recursive conversion: it contains the list of
/// types currently being converted, with the current type being the
/// last one. If it is present more than once in the list, the
/// conversion concerns a recursive type.
///
/// Conversion functions that accept `Value` as the first argument are
/// context-aware. I.e., they can take into account IR when converting the
/// type of the given value. Context-unaware conversion functions accept
/// `Type` or a derived class as the first argument.
///
/// Note: Context-unaware conversions are cached, but context-aware
/// conversions are not.
///
/// Note: When attempting to convert a type, e.g. via 'convertType', the
/// mostly recently added conversions will be invoked first.
template <typename FnT,
Expand Down
94 changes: 81 additions & 13 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ class TypeConverter {
};

/// Register a conversion function. A conversion function must be convertible
/// to any of the following forms (where `T` is a class derived from `Type`):
/// to any of the following forms (where `T` is `Value` or a class derived
/// from `Type`, including `Type` itself):
///
/// * std::optional<Type>(T)
/// - This form represents a 1-1 type conversion. It should return nullptr
Expand All @@ -154,6 +155,14 @@ class TypeConverter {
/// `std::nullopt` is returned, the converter is allowed to try another
/// conversion function to perform the conversion.
///
/// Conversion functions that accept `Value` as the first argument are
/// context-aware. I.e., they can take into account IR when converting the
/// type of the given value. Context-unaware conversion functions accept
/// `Type` or a derived class as the first argument.
///
/// Note: Context-unaware conversions are cached, but context-aware
/// conversions are not.
///
/// Note: When attempting to convert a type, e.g. via 'convertType', the
/// mostly recently added conversions will be invoked first.
template <typename FnT, typename T = typename llvm::function_traits<
Expand Down Expand Up @@ -241,15 +250,28 @@ class TypeConverter {
wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
}

/// Convert the given type. This function should return failure if no valid
/// Convert the given type. This function returns failure if no valid
/// conversion exists, success otherwise. If the new set of types is empty,
/// the type is removed and any usages of the existing value are expected to
/// be removed during conversion.
///
/// Note: This overload invokes only context-unaware type conversion
/// functions. Users should call the other overload if possible.
LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) const;

/// Convert the type of the given value. This function returns failure if no
/// valid conversion exists, success otherwise. If the new set of types is
/// empty, the type is removed and any usages of the existing value are
/// expected to be removed during conversion.
///
/// Note: This overload invokes both context-aware and context-unaware type
/// conversion functions.
LogicalResult convertType(Value v, SmallVectorImpl<Type> &results) const;

/// This hook simplifies defining 1-1 type conversions. This function returns
/// the type to convert to on success, and a null type on failure.
Type convertType(Type t) const;
Type convertType(Value v) const;

/// Attempts a 1-1 type conversion, expecting the result type to be
/// `TargetType`. Returns the converted type cast to `TargetType` on success,
Expand All @@ -258,13 +280,23 @@ class TypeConverter {
TargetType convertType(Type t) const {
return dyn_cast_or_null<TargetType>(convertType(t));
}
template <typename TargetType>
TargetType convertType(Value v) const {
return dyn_cast_or_null<TargetType>(convertType(v));
}

/// Convert the given set of types, filling 'results' as necessary. This
/// returns failure if the conversion of any of the types fails, success
/// Convert the given types, filling 'results' as necessary. This returns
/// "failure" if the conversion of any of the types fails, "success"
/// otherwise.
LogicalResult convertTypes(TypeRange types,
SmallVectorImpl<Type> &results) const;

/// Convert the types of the given values, filling 'results' as necessary.
/// This returns "failure" if the conversion of any of the types fails,
/// "success" otherwise.
LogicalResult convertTypes(ValueRange values,
SmallVectorImpl<Type> &results) const;

/// Return true if the given type is legal for this type converter, i.e. the
/// type converts to itself.
bool isLegal(Type type) const;
Expand Down Expand Up @@ -328,7 +360,7 @@ class TypeConverter {
/// types is empty, the type is removed and any usages of the existing value
/// are expected to be removed during conversion.
using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
Type, SmallVectorImpl<Type> &)>;
PointerUnion<Type, Value>, SmallVectorImpl<Type> &)>;

/// The signature of the callback used to materialize a source conversion.
///
Expand All @@ -348,13 +380,14 @@ class TypeConverter {

/// Generate a wrapper for the given callback. This allows for accepting
/// different callback forms, that all compose into a single version.
/// With callback of form: `std::optional<Type>(T)`
/// With callback of form: `std::optional<Type>(T)`, where `T` can be a
/// `Value` or a `Type` (or a class derived from `Type`).
template <typename T, typename FnT>
std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
wrapCallback(FnT &&callback) const {
wrapCallback(FnT &&callback) {
return wrapCallback<T>([callback = std::forward<FnT>(callback)](
T type, SmallVectorImpl<Type> &results) {
if (std::optional<Type> resultOpt = callback(type)) {
T typeOrValue, SmallVectorImpl<Type> &results) {
if (std::optional<Type> resultOpt = callback(typeOrValue)) {
bool wasSuccess = static_cast<bool>(*resultOpt);
if (wasSuccess)
results.push_back(*resultOpt);
Expand All @@ -364,20 +397,49 @@ class TypeConverter {
});
}
/// With callback of form: `std::optional<LogicalResult>(
/// T, SmallVectorImpl<Type> &, ArrayRef<Type>)`.
/// T, SmallVectorImpl<Type> &)`, where `T` is a type.
template <typename T, typename FnT>
std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
std::is_base_of_v<Type, T>,
ConversionCallbackFn>
wrapCallback(FnT &&callback) const {
return [callback = std::forward<FnT>(callback)](
Type type,
PointerUnion<Type, Value> typeOrValue,
SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
T derivedType = dyn_cast<T>(type);
T derivedType;
if (Type t = dyn_cast<Type>(typeOrValue)) {
derivedType = dyn_cast<T>(t);
} else if (Value v = dyn_cast<Value>(typeOrValue)) {
derivedType = dyn_cast<T>(v.getType());
} else {
llvm_unreachable("unexpected variant");
}
if (!derivedType)
return std::nullopt;
return callback(derivedType, results);
};
}
/// With callback of form: `std::optional<LogicalResult>(
/// T, SmallVectorImpl<Type>)`, where `T` is a `Value`.
template <typename T, typename FnT>
std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
std::is_same_v<T, Value>,
ConversionCallbackFn>
wrapCallback(FnT &&callback) {
hasContextAwareTypeConversions = true;
return [callback = std::forward<FnT>(callback)](
PointerUnion<Type, Value> typeOrValue,
SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
if (Type t = dyn_cast<Type>(typeOrValue)) {
// Context-aware type conversion was called with a type.
return std::nullopt;
} else if (Value v = dyn_cast<Value>(typeOrValue)) {
return callback(v, results);
}
llvm_unreachable("unexpected variant");
return std::nullopt;
};
}

/// Register a type conversion.
void registerConversion(ConversionCallbackFn callback) {
Expand Down Expand Up @@ -504,6 +566,12 @@ class TypeConverter {
mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
/// A mutex used for cache access
mutable llvm::sys::SmartRWMutex<true> cacheMutex;
/// Whether the type converter has context-aware type conversions. I.e.,
/// conversion rules that depend on the SSA value instead of just the type.
/// Type conversion caching is deactivated when there are context-aware
/// conversions because the type converter may return different results for
/// the same input type.
bool hasContextAwareTypeConversions = false;
};

//===----------------------------------------------------------------------===//
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
SmallVector<unsigned> offsets;
offsets.push_back(0);
// Do the type conversion and record the offsets.
for (Type type : op.getResultTypes()) {
if (failed(typeConverter->convertTypes(type, dstTypes)))
for (Value v : op.getResults()) {
if (failed(typeConverter->convertType(v, dstTypes)))
return rewriter.notifyMatchFailure(op, "could not convert result type");
offsets.push_back(dstTypes.size());
}
Expand Down
44 changes: 42 additions & 2 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1256,7 +1256,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(

// If there is no legal conversion, fail to match this pattern.
SmallVector<Type, 1> legalTypes;
if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
if (failed(currentTypeConverter->convertType(operand, legalTypes))) {
notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
diag << "unable to convert type for " << valueDiagTag << " #"
<< it.index() << ", type was " << origType;
Expand Down Expand Up @@ -2899,6 +2899,27 @@ LogicalResult TypeConverter::convertType(Type t,
return failure();
}

LogicalResult TypeConverter::convertType(Value v,
SmallVectorImpl<Type> &results) const {
assert(v && "expected non-null value");

// If this type converter does not have context-aware type conversions, call
// the type-based overload, which has caching.
if (!hasContextAwareTypeConversions)
return convertType(v.getType(), results);

// Walk the added converters in reverse order to apply the most recently
// registered first.
for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
if (std::optional<LogicalResult> result = converter(v, results)) {
if (!succeeded(*result))
return failure();
return success();
}
}
return failure();
}

Type TypeConverter::convertType(Type t) const {
// Use the multi-type result version to convert the type.
SmallVector<Type, 1> results;
Expand All @@ -2909,6 +2930,16 @@ Type TypeConverter::convertType(Type t) const {
return results.size() == 1 ? results.front() : nullptr;
}

Type TypeConverter::convertType(Value v) const {
// Use the multi-type result version to convert the type.
SmallVector<Type, 1> results;
if (failed(convertType(v, results)))
return nullptr;

// Check to ensure that only one type was produced.
return results.size() == 1 ? results.front() : nullptr;
}

LogicalResult
TypeConverter::convertTypes(TypeRange types,
SmallVectorImpl<Type> &results) const {
Expand All @@ -2918,6 +2949,15 @@ TypeConverter::convertTypes(TypeRange types,
return success();
}

LogicalResult
TypeConverter::convertTypes(ValueRange values,
SmallVectorImpl<Type> &results) const {
for (Value value : values)
if (failed(convertType(value, results)))
return failure();
return success();
}

bool TypeConverter::isLegal(Type type) const {
return convertType(type) == type;
}
Expand Down Expand Up @@ -3128,7 +3168,7 @@ mlir::convertOpResultTypes(Operation *op, ValueRange operands,
newOp.addOperands(operands);

SmallVector<Type> newResultTypes;
if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
if (failed(converter.convertTypes(op->getResults(), newResultTypes)))
return rewriter.notifyMatchFailure(loc, "couldn't convert return types");

newOp.addTypes(newResultTypes);
Expand Down
18 changes: 18 additions & 0 deletions mlir/test/Transforms/test-legalize-type-conversion.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,21 @@ func.func @test_signature_conversion_no_converter() {
}) : () -> ()
return
}

// -----

// CHECK-LABEL: func @context_aware_conversion()
func.func @context_aware_conversion() {
// Case 1: Convert i37 --> i38.
// CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i37 to i38
// CHECK: "test.legal_op_d"(%[[cast0]]) : (i38) -> ()
%0 = "test.context_op"() {increment = 1 : i64} : () -> (i37)
"test.replace_with_legal_op"(%0) : (i37) -> ()

// Case 2: Convert i37 --> i39.
// CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i37 to i39
// CHECK: "test.legal_op_d"(%[[cast1]]) : (i39) -> ()
%1 = "test.context_op"() {increment = 2 : i64} : () -> (i37)
"test.replace_with_legal_op"(%1) : (i37) -> ()
return
}
Loading