Skip to content

Commit d565695

Browse files
prototype
experiment more
1 parent 1bc0043 commit d565695

File tree

6 files changed

+200
-38
lines changed

6 files changed

+200
-38
lines changed

mlir/docs/DialectConversion.md

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,15 @@ target types. If the source type is converted to itself, we say it is a "legal"
235235
type. Type conversions are specified via the `addConversion` method described
236236
below.
237237
238+
There are two kind of conversion functions: context-aware and context-unaware
239+
conversions. A context-unaware conversion function converts a `Type` into a
240+
`Type`. A context-aware conversion function converts a `Value` into a type. The
241+
latter allows users to customize type conversion rules based on the IR.
242+
243+
Note: When there is at least one context-aware type conversion function, the
244+
result of type conversions can no longer be cached, which can increase
245+
compilation time. Use this feature with caution!
246+
238247
A `materialization` describes how a list of values should be converted to a
239248
list of values with specific types. An important distinction from a
240249
`conversion` is that a `materialization` can produce IR, whereas a `conversion`
@@ -287,29 +296,31 @@ Several of the available hooks are detailed below:
287296
```c++
288297
class TypeConverter {
289298
public:
290-
/// Register a conversion function. A conversion function defines how a given
291-
/// source type should be converted. A conversion function must be convertible
292-
/// to any of the following forms(where `T` is a class derived from `Type`:
293-
/// * Optional<Type>(T)
299+
/// Register a conversion function. A conversion function must be convertible
300+
/// to any of the following forms (where `T` is `Value` or a class derived
301+
/// from `Type`, including `Type` itself):
302+
///
303+
/// * std::optional<Type>(T)
294304
/// - This form represents a 1-1 type conversion. It should return nullptr
295-
/// or `std::nullopt` to signify failure. If `std::nullopt` is returned, the
296-
/// converter is allowed to try another conversion function to perform
297-
/// the conversion.
298-
/// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &)
305+
/// or `std::nullopt` to signify failure. If `std::nullopt` is returned,
306+
/// the converter is allowed to try another conversion function to
307+
/// perform the conversion.
308+
/// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &)
299309
/// - This form represents a 1-N type conversion. It should return
300-
/// `failure` or `std::nullopt` to signify a failed conversion. If the new
301-
/// set of types is empty, the type is removed and any usages of the
310+
/// `failure` or `std::nullopt` to signify a failed conversion. If the
311+
/// new set of types is empty, the type is removed and any usages of the
302312
/// existing value are expected to be removed during conversion. If
303313
/// `std::nullopt` is returned, the converter is allowed to try another
304314
/// conversion function to perform the conversion.
305-
/// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &, ArrayRef<Type>)
306-
/// - This form represents a 1-N type conversion supporting recursive
307-
/// types. The first two arguments and the return value are the same as
308-
/// for the regular 1-N form. The third argument is contains is the
309-
/// "call stack" of the recursive conversion: it contains the list of
310-
/// types currently being converted, with the current type being the
311-
/// last one. If it is present more than once in the list, the
312-
/// conversion concerns a recursive type.
315+
///
316+
/// Conversion functions that accept `Value` as the first argument are
317+
/// context-aware. I.e., they can take into account IR when converting the
318+
/// type of the given value. Context-unaware conversion functions accept
319+
/// `Type` or a derived class as the first argument.
320+
///
321+
/// Note: Context-unaware conversions are cached, but context-aware
322+
/// conversions are not.
323+
///
313324
/// Note: When attempting to convert a type, e.g. via 'convertType', the
314325
/// mostly recently added conversions will be invoked first.
315326
template <typename FnT,

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 82 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "llvm/ADT/MapVector.h"
1919
#include "llvm/ADT/StringMap.h"
2020
#include <type_traits>
21+
#include <variant>
2122

2223
namespace mlir {
2324

@@ -139,7 +140,8 @@ class TypeConverter {
139140
};
140141

141142
/// Register a conversion function. A conversion function must be convertible
142-
/// to any of the following forms (where `T` is a class derived from `Type`):
143+
/// to any of the following forms (where `T` is `Value` or a class derived
144+
/// from `Type`, including `Type` itself):
143145
///
144146
/// * std::optional<Type>(T)
145147
/// - This form represents a 1-1 type conversion. It should return nullptr
@@ -154,6 +156,14 @@ class TypeConverter {
154156
/// `std::nullopt` is returned, the converter is allowed to try another
155157
/// conversion function to perform the conversion.
156158
///
159+
/// Conversion functions that accept `Value` as the first argument are
160+
/// context-aware. I.e., they can take into account IR when converting the
161+
/// type of the given value. Context-unaware conversion functions accept
162+
/// `Type` or a derived class as the first argument.
163+
///
164+
/// Note: Context-unaware conversions are cached, but context-aware
165+
/// conversions are not.
166+
///
157167
/// Note: When attempting to convert a type, e.g. via 'convertType', the
158168
/// mostly recently added conversions will be invoked first.
159169
template <typename FnT, typename T = typename llvm::function_traits<
@@ -241,15 +251,28 @@ class TypeConverter {
241251
wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
242252
}
243253

244-
/// Convert the given type. This function should return failure if no valid
254+
/// Convert the given type. This function returns failure if no valid
245255
/// conversion exists, success otherwise. If the new set of types is empty,
246256
/// the type is removed and any usages of the existing value are expected to
247257
/// be removed during conversion.
258+
///
259+
/// Note: This overload invokes only context-unaware type conversion
260+
/// functions. Users should call the other overload if possible.
248261
LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) const;
249262

263+
/// Convert the type of the given value. This function returns failure if no
264+
/// valid conversion exists, success otherwise. If the new set of types is
265+
/// empty, the type is removed and any usages of the existing value are
266+
/// expected to be removed during conversion.
267+
///
268+
/// Note: This overload invokes both context-aware and context-unaware type
269+
/// conversion functions.
270+
LogicalResult convertType(Value v, SmallVectorImpl<Type> &results) const;
271+
250272
/// This hook simplifies defining 1-1 type conversions. This function returns
251273
/// the type to convert to on success, and a null type on failure.
252274
Type convertType(Type t) const;
275+
Type convertType(Value v) const;
253276

254277
/// Attempts a 1-1 type conversion, expecting the result type to be
255278
/// `TargetType`. Returns the converted type cast to `TargetType` on success,
@@ -258,13 +281,23 @@ class TypeConverter {
258281
TargetType convertType(Type t) const {
259282
return dyn_cast_or_null<TargetType>(convertType(t));
260283
}
284+
template <typename TargetType>
285+
TargetType convertType(Value v) const {
286+
return dyn_cast_or_null<TargetType>(convertType(v));
287+
}
261288

262-
/// Convert the given set of types, filling 'results' as necessary. This
263-
/// returns failure if the conversion of any of the types fails, success
289+
/// Convert the given types, filling 'results' as necessary. This returns
290+
/// "failure" if the conversion of any of the types fails, "success"
264291
/// otherwise.
265292
LogicalResult convertTypes(TypeRange types,
266293
SmallVectorImpl<Type> &results) const;
267294

295+
/// Convert the types of the given values, filling 'results' as necessary.
296+
/// This returns "failure" if the conversion of any of the types fails,
297+
/// "success" otherwise.
298+
LogicalResult convertTypes(ValueRange values,
299+
SmallVectorImpl<Type> &results) const;
300+
268301
/// Return true if the given type is legal for this type converter, i.e. the
269302
/// type converts to itself.
270303
bool isLegal(Type type) const;
@@ -328,7 +361,7 @@ class TypeConverter {
328361
/// types is empty, the type is removed and any usages of the existing value
329362
/// are expected to be removed during conversion.
330363
using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
331-
Type, SmallVectorImpl<Type> &)>;
364+
std::variant<Type, Value>, SmallVectorImpl<Type> &)>;
332365

333366
/// The signature of the callback used to materialize a source conversion.
334367
///
@@ -348,13 +381,14 @@ class TypeConverter {
348381

349382
/// Generate a wrapper for the given callback. This allows for accepting
350383
/// different callback forms, that all compose into a single version.
351-
/// With callback of form: `std::optional<Type>(T)`
384+
/// With callback of form: `std::optional<Type>(T)`, where `T` can be a
385+
/// `Value` or a `Type` (or a class derived from `Type`).
352386
template <typename T, typename FnT>
353387
std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
354-
wrapCallback(FnT &&callback) const {
388+
wrapCallback(FnT &&callback) {
355389
return wrapCallback<T>([callback = std::forward<FnT>(callback)](
356-
T type, SmallVectorImpl<Type> &results) {
357-
if (std::optional<Type> resultOpt = callback(type)) {
390+
T typeOrValue, SmallVectorImpl<Type> &results) {
391+
if (std::optional<Type> resultOpt = callback(typeOrValue)) {
358392
bool wasSuccess = static_cast<bool>(*resultOpt);
359393
if (wasSuccess)
360394
results.push_back(*resultOpt);
@@ -364,20 +398,49 @@ class TypeConverter {
364398
});
365399
}
366400
/// With callback of form: `std::optional<LogicalResult>(
367-
/// T, SmallVectorImpl<Type> &, ArrayRef<Type>)`.
401+
/// T, SmallVectorImpl<Type> &)`, where `T` is a type.
368402
template <typename T, typename FnT>
369-
std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
403+
std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
404+
std::is_base_of_v<Type, T>,
370405
ConversionCallbackFn>
371406
wrapCallback(FnT &&callback) const {
372407
return [callback = std::forward<FnT>(callback)](
373-
Type type,
408+
std::variant<Type, Value> type,
374409
SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
375-
T derivedType = dyn_cast<T>(type);
410+
T derivedType;
411+
if (Type *t = std::get_if<Type>(&type)) {
412+
derivedType = dyn_cast<T>(*t);
413+
} else if (Value *v = std::get_if<Value>(&type)) {
414+
derivedType = dyn_cast<T>(v->getType());
415+
} else {
416+
llvm_unreachable("unexpected variant");
417+
}
376418
if (!derivedType)
377419
return std::nullopt;
378420
return callback(derivedType, results);
379421
};
380422
}
423+
/// With callback of form: `std::optional<LogicalResult>(
424+
/// T, SmallVectorImpl<Type>)`, where `T` is a `Value`.
425+
template <typename T, typename FnT>
426+
std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
427+
std::is_same_v<T, Value>,
428+
ConversionCallbackFn>
429+
wrapCallback(FnT &&callback) {
430+
hasContextAwareTypeConversions = true;
431+
return [callback = std::forward<FnT>(callback)](
432+
std::variant<Type, Value> type,
433+
SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
434+
if (Type *t = std::get_if<Type>(&type)) {
435+
// Context-aware type conversion was called with a type.
436+
return std::nullopt;
437+
} else if (Value *v = std::get_if<Value>(&type)) {
438+
return callback(*v, results);
439+
}
440+
llvm_unreachable("unexpected variant");
441+
return std::nullopt;
442+
};
443+
}
381444

382445
/// Register a type conversion.
383446
void registerConversion(ConversionCallbackFn callback) {
@@ -504,6 +567,12 @@ class TypeConverter {
504567
mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
505568
/// A mutex used for cache access
506569
mutable llvm::sys::SmartRWMutex<true> cacheMutex;
570+
/// Whether the type converter has context-aware type conversions. I.e.,
571+
/// conversion rules that depend on the SSA value instead of just the type.
572+
/// Type conversion caching is deactivated when there are context-aware
573+
/// conversions because the type converter may return different results for
574+
/// the same input type.
575+
bool hasContextAwareTypeConversions = false;
507576
};
508577

509578
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
5252
SmallVector<unsigned> offsets;
5353
offsets.push_back(0);
5454
// Do the type conversion and record the offsets.
55-
for (Type type : op.getResultTypes()) {
56-
if (failed(typeConverter->convertTypes(type, dstTypes)))
55+
for (Value v : op.getResults()) {
56+
if (failed(typeConverter->convertType(v, dstTypes)))
5757
return rewriter.notifyMatchFailure(op, "could not convert result type");
5858
offsets.push_back(dstTypes.size());
5959
}

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,7 +1256,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
12561256

12571257
// If there is no legal conversion, fail to match this pattern.
12581258
SmallVector<Type, 1> legalTypes;
1259-
if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
1259+
if (failed(currentTypeConverter->convertType(operand, legalTypes))) {
12601260
notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
12611261
diag << "unable to convert type for " << valueDiagTag << " #"
12621262
<< it.index() << ", type was " << origType;
@@ -2899,6 +2899,28 @@ LogicalResult TypeConverter::convertType(Type t,
28992899
return failure();
29002900
}
29012901

2902+
LogicalResult TypeConverter::convertType(Value v,
2903+
SmallVectorImpl<Type> &results) const {
2904+
assert(v && "expected non-null value");
2905+
2906+
// If this type converter does not have context-aware type conversions, call
2907+
// the type-based overload, which has caching.
2908+
if (!hasContextAwareTypeConversions) {
2909+
return convertType(v.getType(), results);
2910+
}
2911+
2912+
// Walk the added converters in reverse order to apply the most recently
2913+
// registered first.
2914+
for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2915+
if (std::optional<LogicalResult> result = converter(v, results)) {
2916+
if (!succeeded(*result))
2917+
return failure();
2918+
return success();
2919+
}
2920+
}
2921+
return failure();
2922+
}
2923+
29022924
Type TypeConverter::convertType(Type t) const {
29032925
// Use the multi-type result version to convert the type.
29042926
SmallVector<Type, 1> results;
@@ -2909,6 +2931,16 @@ Type TypeConverter::convertType(Type t) const {
29092931
return results.size() == 1 ? results.front() : nullptr;
29102932
}
29112933

2934+
Type TypeConverter::convertType(Value v) const {
2935+
// Use the multi-type result version to convert the type.
2936+
SmallVector<Type, 1> results;
2937+
if (failed(convertType(v, results)))
2938+
return nullptr;
2939+
2940+
// Check to ensure that only one type was produced.
2941+
return results.size() == 1 ? results.front() : nullptr;
2942+
}
2943+
29122944
LogicalResult
29132945
TypeConverter::convertTypes(TypeRange types,
29142946
SmallVectorImpl<Type> &results) const {
@@ -2918,6 +2950,15 @@ TypeConverter::convertTypes(TypeRange types,
29182950
return success();
29192951
}
29202952

2953+
LogicalResult
2954+
TypeConverter::convertTypes(ValueRange values,
2955+
SmallVectorImpl<Type> &results) const {
2956+
for (Value value : values)
2957+
if (failed(convertType(value, results)))
2958+
return failure();
2959+
return success();
2960+
}
2961+
29212962
bool TypeConverter::isLegal(Type type) const {
29222963
return convertType(type) == type;
29232964
}
@@ -3128,7 +3169,7 @@ mlir::convertOpResultTypes(Operation *op, ValueRange operands,
31283169
newOp.addOperands(operands);
31293170

31303171
SmallVector<Type> newResultTypes;
3131-
if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
3172+
if (failed(converter.convertTypes(op->getResults(), newResultTypes)))
31323173
return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
31333174

31343175
newOp.addTypes(newResultTypes);

mlir/test/Transforms/test-legalize-type-conversion.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,21 @@ func.func @test_signature_conversion_no_converter() {
142142
}) : () -> ()
143143
return
144144
}
145+
146+
// -----
147+
148+
// CHECK-LABEL: func @context_aware_conversion()
149+
func.func @context_aware_conversion() {
150+
// Case 1: Convert i37 --> i38.
151+
// CHECK: %[[cast0:.*]] = unrealized_conversion_cast %{{.*}} : i37 to i38
152+
// CHECK: "test.legal_op_d"(%[[cast0]]) : (i38) -> ()
153+
%0 = "test.context_op"() {increment = 1 : i64} : () -> (i37)
154+
"test.replace_with_legal_op"(%0) : (i37) -> ()
155+
156+
// Case 2: Convert i37 --> i39.
157+
// CHECK: %[[cast1:.*]] = unrealized_conversion_cast %{{.*}} : i37 to i39
158+
// CHECK: "test.legal_op_d"(%[[cast1]]) : (i39) -> ()
159+
%1 = "test.context_op"() {increment = 2 : i64} : () -> (i37)
160+
"test.replace_with_legal_op"(%1) : (i37) -> ()
161+
return
162+
}

0 commit comments

Comments
 (0)