18
18
#include " llvm/ADT/MapVector.h"
19
19
#include " llvm/ADT/StringMap.h"
20
20
#include < type_traits>
21
+ #include < variant>
21
22
22
23
namespace mlir {
23
24
@@ -139,7 +140,8 @@ class TypeConverter {
139
140
};
140
141
141
142
// / Register a conversion function. A conversion function must be convertible
142
- // / to any of the following forms (where `T` is a class derived from `Type`):
143
+ // / to any of the following forms (where `T` is `Value` or a class derived
144
+ // / from `Type`, including `Type` itself):
143
145
// /
144
146
// / * std::optional<Type>(T)
145
147
// / - This form represents a 1-1 type conversion. It should return nullptr
@@ -154,6 +156,14 @@ class TypeConverter {
154
156
// / `std::nullopt` is returned, the converter is allowed to try another
155
157
// / conversion function to perform the conversion.
156
158
// /
159
+ // / Conversion functions that accept `Value` as the first argument are
160
+ // / context-aware. I.e., they can take into account IR when converting the
161
+ // / type of the given value. Context-unaware conversion functions accept
162
+ // / `Type` or a derived class as the first argument.
163
+ // /
164
+ // / Note: Context-unaware conversions are cached, but context-aware
165
+ // / conversions are not.
166
+ // /
157
167
// / Note: When attempting to convert a type, e.g. via 'convertType', the
158
168
// / mostly recently added conversions will be invoked first.
159
169
template <typename FnT, typename T = typename llvm::function_traits<
@@ -241,15 +251,28 @@ class TypeConverter {
241
251
wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
242
252
}
243
253
244
- // / Convert the given type. This function should return failure if no valid
254
+ // / Convert the given type. This function returns failure if no valid
245
255
// / conversion exists, success otherwise. If the new set of types is empty,
246
256
// / the type is removed and any usages of the existing value are expected to
247
257
// / be removed during conversion.
258
+ // /
259
+ // / Note: This overload invokes only context-unaware type conversion
260
+ // / functions. Users should call the other overload if possible.
248
261
LogicalResult convertType (Type t, SmallVectorImpl<Type> &results) const ;
249
262
263
+ // / Convert the type of the given value. This function returns failure if no
264
+ // / valid conversion exists, success otherwise. If the new set of types is
265
+ // / empty, the type is removed and any usages of the existing value are
266
+ // / expected to be removed during conversion.
267
+ // /
268
+ // / Note: This overload invokes both context-aware and context-unaware type
269
+ // / conversion functions.
270
+ LogicalResult convertType (Value v, SmallVectorImpl<Type> &results) const ;
271
+
250
272
// / This hook simplifies defining 1-1 type conversions. This function returns
251
273
// / the type to convert to on success, and a null type on failure.
252
274
Type convertType (Type t) const ;
275
+ Type convertType (Value v) const ;
253
276
254
277
// / Attempts a 1-1 type conversion, expecting the result type to be
255
278
// / `TargetType`. Returns the converted type cast to `TargetType` on success,
@@ -258,13 +281,23 @@ class TypeConverter {
258
281
TargetType convertType (Type t) const {
259
282
return dyn_cast_or_null<TargetType>(convertType (t));
260
283
}
284
+ template <typename TargetType>
285
+ TargetType convertType (Value v) const {
286
+ return dyn_cast_or_null<TargetType>(convertType (v));
287
+ }
261
288
262
- // / Convert the given set of types, filling 'results' as necessary. This
263
- // / returns failure if the conversion of any of the types fails, success
289
+ // / Convert the given types, filling 'results' as necessary. This returns
290
+ // / " failure" if the conversion of any of the types fails, " success"
264
291
// / otherwise.
265
292
LogicalResult convertTypes (TypeRange types,
266
293
SmallVectorImpl<Type> &results) const ;
267
294
295
+ // / Convert the types of the given values, filling 'results' as necessary.
296
+ // / This returns "failure" if the conversion of any of the types fails,
297
+ // / "success" otherwise.
298
+ LogicalResult convertTypes (ValueRange values,
299
+ SmallVectorImpl<Type> &results) const ;
300
+
268
301
// / Return true if the given type is legal for this type converter, i.e. the
269
302
// / type converts to itself.
270
303
bool isLegal (Type type) const ;
@@ -328,7 +361,7 @@ class TypeConverter {
328
361
// / types is empty, the type is removed and any usages of the existing value
329
362
// / are expected to be removed during conversion.
330
363
using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
331
- Type, SmallVectorImpl<Type> &)>;
364
+ std::variant< Type, Value> , SmallVectorImpl<Type> &)>;
332
365
333
366
// / The signature of the callback used to materialize a source conversion.
334
367
// /
@@ -348,13 +381,14 @@ class TypeConverter {
348
381
349
382
// / Generate a wrapper for the given callback. This allows for accepting
350
383
// / different callback forms, that all compose into a single version.
351
- // / With callback of form: `std::optional<Type>(T)`
384
+ // / With callback of form: `std::optional<Type>(T)`, where `T` can be a
385
+ // / `Value` or a `Type` (or a class derived from `Type`).
352
386
template <typename T, typename FnT>
353
387
std::enable_if_t <std::is_invocable_v<FnT, T>, ConversionCallbackFn>
354
- wrapCallback (FnT &&callback) const {
388
+ wrapCallback (FnT &&callback) {
355
389
return wrapCallback<T>([callback = std::forward<FnT>(callback)](
356
- T type , SmallVectorImpl<Type> &results) {
357
- if (std::optional<Type> resultOpt = callback (type )) {
390
+ T typeOrValue , SmallVectorImpl<Type> &results) {
391
+ if (std::optional<Type> resultOpt = callback (typeOrValue )) {
358
392
bool wasSuccess = static_cast <bool >(*resultOpt);
359
393
if (wasSuccess)
360
394
results.push_back (*resultOpt);
@@ -364,20 +398,49 @@ class TypeConverter {
364
398
});
365
399
}
366
400
// / With callback of form: `std::optional<LogicalResult>(
367
- // / T, SmallVectorImpl<Type> &, ArrayRef<Type>)` .
401
+ // / T, SmallVectorImpl<Type> &)`, where `T` is a type .
368
402
template <typename T, typename FnT>
369
- std::enable_if_t <std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
403
+ std::enable_if_t <std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
404
+ std::is_base_of_v<Type, T>,
370
405
ConversionCallbackFn>
371
406
wrapCallback (FnT &&callback) const {
372
407
return [callback = std::forward<FnT>(callback)](
373
- Type type,
408
+ std::variant< Type, Value> type,
374
409
SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
375
- T derivedType = dyn_cast<T>(type);
410
+ T derivedType;
411
+ if (Type *t = std::get_if<Type>(&type)) {
412
+ derivedType = dyn_cast<T>(*t);
413
+ } else if (Value *v = std::get_if<Value>(&type)) {
414
+ derivedType = dyn_cast<T>(v->getType ());
415
+ } else {
416
+ llvm_unreachable (" unexpected variant" );
417
+ }
376
418
if (!derivedType)
377
419
return std::nullopt;
378
420
return callback (derivedType, results);
379
421
};
380
422
}
423
+ // / With callback of form: `std::optional<LogicalResult>(
424
+ // / T, SmallVectorImpl<Type>)`, where `T` is a `Value`.
425
+ template <typename T, typename FnT>
426
+ std::enable_if_t <std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
427
+ std::is_same_v<T, Value>,
428
+ ConversionCallbackFn>
429
+ wrapCallback (FnT &&callback) {
430
+ hasContextAwareTypeConversions = true ;
431
+ return [callback = std::forward<FnT>(callback)](
432
+ std::variant<Type, Value> type,
433
+ SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
434
+ if (Type *t = std::get_if<Type>(&type)) {
435
+ // Context-aware type conversion was called with a type.
436
+ return std::nullopt;
437
+ } else if (Value *v = std::get_if<Value>(&type)) {
438
+ return callback (*v, results);
439
+ }
440
+ llvm_unreachable (" unexpected variant" );
441
+ return std::nullopt;
442
+ };
443
+ }
381
444
382
445
// / Register a type conversion.
383
446
void registerConversion (ConversionCallbackFn callback) {
@@ -504,6 +567,12 @@ class TypeConverter {
504
567
mutable DenseMap<Type, SmallVector<Type, 2 >> cachedMultiConversions;
505
568
// / A mutex used for cache access
506
569
mutable llvm::sys::SmartRWMutex<true > cacheMutex;
570
+ // / Whether the type converter has context-aware type conversions. I.e.,
571
+ // / conversion rules that depend on the SSA value instead of just the type.
572
+ // / Type conversion caching is deactivated when there are context-aware
573
+ // / conversions because the type converter may return different results for
574
+ // / the same input type.
575
+ bool hasContextAwareTypeConversions = false ;
507
576
};
508
577
509
578
// ===----------------------------------------------------------------------===//
0 commit comments