| 
 | 1 | +//===-- OneToNTypeConversion.h - Utils for 1:N type conversion --*- C++ -*-===//  | 
 | 2 | +//  | 
 | 3 | +// Licensed under the Apache License v2.0 with LLVM Exceptions.  | 
 | 4 | +// See https://llvm.org/LICENSE.txt for license information.  | 
 | 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception  | 
 | 6 | +//  | 
 | 7 | +//===----------------------------------------------------------------------===//  | 
 | 8 | +//  | 
 | 9 | +// Note: The 1:N dialect conversion is deprecated and will be removed soon.  | 
 | 10 | +// 1:N support has been added to the regular dialect conversion driver.  | 
 | 11 | +//  | 
 | 12 | +// This file provides utils for implementing (poor-man's) dialect conversion  | 
 | 13 | +// passes with 1:N type conversions.  | 
 | 14 | +//  | 
 | 15 | +// The main function, `applyPartialOneToNConversion`, first applies a set of  | 
 | 16 | +// `RewritePattern`s, which produce unrealized casts to convert the operands and  | 
 | 17 | +// results from and to the source types, and then replaces all newly added  | 
 | 18 | +// unrealized casts by user-provided materializations. For this to work, the  | 
 | 19 | +// main function requires a special `TypeConverter`, a special  | 
 | 20 | +// `PatternRewriter`, and special RewritePattern`s, which extend their  | 
 | 21 | +// respective base classes for 1:N type converions.  | 
 | 22 | +//  | 
 | 23 | +// Note that this is much more simple-minded than the "real" dialect conversion,  | 
 | 24 | +// which checks for legality before applying patterns and does probably many  | 
 | 25 | +// other additional things. Ideally, some of the extensions here could be  | 
 | 26 | +// integrated there.  | 
 | 27 | +//  | 
 | 28 | +//===----------------------------------------------------------------------===//  | 
 | 29 | + | 
 | 30 | +#ifndef MLIR_TRANSFORMS_ONETONTYPECONVERSION_H  | 
 | 31 | +#define MLIR_TRANSFORMS_ONETONTYPECONVERSION_H  | 
 | 32 | + | 
 | 33 | +#include "mlir/IR/PatternMatch.h"  | 
 | 34 | +#include "mlir/Transforms/DialectConversion.h"  | 
 | 35 | +#include "llvm/ADT/SmallVector.h"  | 
 | 36 | + | 
 | 37 | +namespace mlir {  | 
 | 38 | + | 
 | 39 | +/// Stores a 1:N mapping of types and provides several useful accessors. This  | 
 | 40 | +/// class extends `SignatureConversion`, which already supports 1:N type  | 
 | 41 | +/// mappings but lacks some accessors into the mapping as well as access to the  | 
 | 42 | +/// original types.  | 
 | 43 | +class OneToNTypeMapping : public TypeConverter::SignatureConversion {  | 
 | 44 | +public:  | 
 | 45 | +  OneToNTypeMapping(TypeRange originalTypes)  | 
 | 46 | +      : TypeConverter::SignatureConversion(originalTypes.size()),  | 
 | 47 | +        originalTypes(originalTypes) {}  | 
 | 48 | + | 
 | 49 | +  using TypeConverter::SignatureConversion::getConvertedTypes;  | 
 | 50 | + | 
 | 51 | +  /// Returns the list of types that corresponds to the original type at the  | 
 | 52 | +  /// given index.  | 
 | 53 | +  TypeRange getConvertedTypes(unsigned originalTypeNo) const;  | 
 | 54 | + | 
 | 55 | +  /// Returns the list of original types.  | 
 | 56 | +  TypeRange getOriginalTypes() const { return originalTypes; }  | 
 | 57 | + | 
 | 58 | +  /// Returns the slice of converted values that corresponds the original value  | 
 | 59 | +  /// at the given index.  | 
 | 60 | +  ValueRange getConvertedValues(ValueRange convertedValues,  | 
 | 61 | +                                unsigned originalValueNo) const;  | 
 | 62 | + | 
 | 63 | +  /// Fills the given result vector with as many copies of the location of the  | 
 | 64 | +  /// original value as the number of values it is converted to.  | 
 | 65 | +  void convertLocation(Value originalValue, unsigned originalValueNo,  | 
 | 66 | +                       llvm::SmallVectorImpl<Location> &result) const;  | 
 | 67 | + | 
 | 68 | +  /// Fills the given result vector with as many copies of the lociation of each  | 
 | 69 | +  /// original value as the number of values they are respectively converted to.  | 
 | 70 | +  void convertLocations(ValueRange originalValues,  | 
 | 71 | +                        llvm::SmallVectorImpl<Location> &result) const;  | 
 | 72 | + | 
 | 73 | +  /// Returns true iff at least one type conversion maps an input type to a type  | 
 | 74 | +  /// that is different from itself.  | 
 | 75 | +  bool hasNonIdentityConversion() const;  | 
 | 76 | + | 
 | 77 | +private:  | 
 | 78 | +  llvm::SmallVector<Type> originalTypes;  | 
 | 79 | +};  | 
 | 80 | + | 
 | 81 | +/// Extends the basic `RewritePattern` class with a type converter member and  | 
 | 82 | +/// some accessors to it. This is useful for patterns that are not  | 
 | 83 | +/// `ConversionPattern`s but still require access to a type converter.  | 
 | 84 | +class RewritePatternWithConverter : public mlir::RewritePattern {  | 
 | 85 | +public:  | 
 | 86 | +  /// Construct a conversion pattern with the given converter, and forward the  | 
 | 87 | +  /// remaining arguments to RewritePattern.  | 
 | 88 | +  template <typename... Args>  | 
 | 89 | +  RewritePatternWithConverter(const TypeConverter &typeConverter,  | 
 | 90 | +                              Args &&...args)  | 
 | 91 | +      : RewritePattern(std::forward<Args>(args)...),  | 
 | 92 | +        typeConverter(&typeConverter) {}  | 
 | 93 | + | 
 | 94 | +  /// Return the type converter held by this pattern, or nullptr if the pattern  | 
 | 95 | +  /// does not require type conversion.  | 
 | 96 | +  const TypeConverter *getTypeConverter() const { return typeConverter; }  | 
 | 97 | + | 
 | 98 | +  template <typename ConverterTy>  | 
 | 99 | +  std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,  | 
 | 100 | +                   const ConverterTy *>  | 
 | 101 | +  getTypeConverter() const {  | 
 | 102 | +    return static_cast<const ConverterTy *>(typeConverter);  | 
 | 103 | +  }  | 
 | 104 | + | 
 | 105 | +protected:  | 
 | 106 | +  /// A type converter for use by this pattern.  | 
 | 107 | +  const TypeConverter *const typeConverter;  | 
 | 108 | +};  | 
 | 109 | + | 
 | 110 | +/// Specialization of `PatternRewriter` that `OneToNConversionPattern`s use. The  | 
 | 111 | +/// class provides additional rewrite methods that are specific to 1:N type  | 
 | 112 | +/// conversions.  | 
 | 113 | +class OneToNPatternRewriter : public PatternRewriter {  | 
 | 114 | +public:  | 
 | 115 | +  OneToNPatternRewriter(MLIRContext *context,  | 
 | 116 | +                        OpBuilder::Listener *listener = nullptr)  | 
 | 117 | +      : PatternRewriter(context, listener) {}  | 
 | 118 | + | 
 | 119 | +  /// Replaces the results of the operation with the specified list of values  | 
 | 120 | +  /// mapped back to the original types as specified in the provided type  | 
 | 121 | +  /// mapping. That type mapping must match the replaced op (i.e., the original  | 
 | 122 | +  /// types must be the same as the result types of the op) and the new values  | 
 | 123 | +  /// (i.e., the converted types must be the same as the types of the new  | 
 | 124 | +  /// values).  | 
 | 125 | +  /// FIXME: The 1:N dialect conversion is deprecated and will be removed soon.  | 
 | 126 | +  /// Use replaceOpWithMultiple() instead.  | 
 | 127 | +  void replaceOp(Operation *op, ValueRange newValues,  | 
 | 128 | +                 const OneToNTypeMapping &resultMapping);  | 
 | 129 | +  using PatternRewriter::replaceOp;  | 
 | 130 | + | 
 | 131 | +  /// Applies the given argument conversion to the given block. This consists of  | 
 | 132 | +  /// replacing each original argument with N arguments as specified in the  | 
 | 133 | +  /// argument conversion and inserting unrealized casts from the converted  | 
 | 134 | +  /// values to the original types, which are then used in lieu of the original  | 
 | 135 | +  /// ones. (Eventually, `applyPartialOneToNConversion` replaces these casts  | 
 | 136 | +  /// with a user-provided argument materialization if necessary.) This is  | 
 | 137 | +  /// similar to `ArgConverter::applySignatureConversion` but (1) handles 1:N  | 
 | 138 | +  /// type conversion properly and probably (2) doesn't handle many other edge  | 
 | 139 | +  /// cases.  | 
 | 140 | +  Block *applySignatureConversion(Block *block,  | 
 | 141 | +                                  OneToNTypeMapping &argumentConversion);  | 
 | 142 | +};  | 
 | 143 | + | 
 | 144 | +/// Base class for patterns with 1:N type conversions. Derived classes have to  | 
 | 145 | +/// overwrite the `matchAndRewrite` overlaod that provides additional  | 
 | 146 | +/// information for 1:N type conversions.  | 
 | 147 | +class OneToNConversionPattern : public RewritePatternWithConverter {  | 
 | 148 | +public:  | 
 | 149 | +  using RewritePatternWithConverter::RewritePatternWithConverter;  | 
 | 150 | + | 
 | 151 | +  /// This function has to be implemented by derived classes and is called from  | 
 | 152 | +  /// the usual overloads. Like in "normal" `DialectConversion`, the function is  | 
 | 153 | +  /// provided with the converted operands (which thus have target types). Since  | 
 | 154 | +  /// 1:N conversions are supported, there is usually no 1:1 relationship  | 
 | 155 | +  /// between the original and the converted operands. Instead, the provided  | 
 | 156 | +  /// `operandMapping` can be used to access the converted operands that  | 
 | 157 | +  /// correspond to a particular original operand. Similarly, `resultMapping`  | 
 | 158 | +  /// is provided to help with assembling the result values, which may have 1:N  | 
 | 159 | +  /// correspondences as well. In that case, the original op should be replaced  | 
 | 160 | +  /// with the overload of `replaceOp` that takes the provided `resultMapping`  | 
 | 161 | +  /// in order to deal with the mapping of converted result values to their  | 
 | 162 | +  /// usages in the original types correctly.  | 
 | 163 | +  virtual LogicalResult matchAndRewrite(Operation *op,  | 
 | 164 | +                                        OneToNPatternRewriter &rewriter,  | 
 | 165 | +                                        const OneToNTypeMapping &operandMapping,  | 
 | 166 | +                                        const OneToNTypeMapping &resultMapping,  | 
 | 167 | +                                        ValueRange convertedOperands) const = 0;  | 
 | 168 | + | 
 | 169 | +  LogicalResult matchAndRewrite(Operation *op,  | 
 | 170 | +                                PatternRewriter &rewriter) const final;  | 
 | 171 | +};  | 
 | 172 | + | 
 | 173 | +/// This class is a wrapper around `OneToNConversionPattern` for matching  | 
 | 174 | +/// against instances of a particular op class.  | 
 | 175 | +template <typename SourceOp>  | 
 | 176 | +class OneToNOpConversionPattern : public OneToNConversionPattern {  | 
 | 177 | +public:  | 
 | 178 | +  OneToNOpConversionPattern(const TypeConverter &typeConverter,  | 
 | 179 | +                            MLIRContext *context, PatternBenefit benefit = 1,  | 
 | 180 | +                            ArrayRef<StringRef> generatedNames = {})  | 
 | 181 | +      : OneToNConversionPattern(typeConverter, SourceOp::getOperationName(),  | 
 | 182 | +                                benefit, context, generatedNames) {}  | 
 | 183 | +  /// Generic adaptor around the root op of this pattern using the converted  | 
 | 184 | +  /// operands. Importantly, each operand is represented as a *range* of values,  | 
 | 185 | +  /// namely the N values each original operand gets converted to. Concretely,  | 
 | 186 | +  /// this makes the result type of the accessor functions of the adaptor class  | 
 | 187 | +  /// be a `ValueRange`.  | 
 | 188 | +  class OpAdaptor  | 
 | 189 | +      : public SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> {  | 
 | 190 | +  public:  | 
 | 191 | +    using RangeT = ArrayRef<ValueRange>;  | 
 | 192 | +    using BaseT = typename SourceOp::template GenericAdaptor<RangeT>;  | 
 | 193 | +    using Properties = typename SourceOp::template InferredProperties<SourceOp>;  | 
 | 194 | + | 
 | 195 | +    OpAdaptor(const OneToNTypeMapping *operandMapping,  | 
 | 196 | +              const OneToNTypeMapping *resultMapping,  | 
 | 197 | +              const ValueRange *convertedOperands, RangeT values, SourceOp op)  | 
 | 198 | +        : BaseT(values, op), operandMapping(operandMapping),  | 
 | 199 | +          resultMapping(resultMapping), convertedOperands(convertedOperands) {}  | 
 | 200 | + | 
 | 201 | +    /// Get the type mapping of the original operands to the converted operands.  | 
 | 202 | +    const OneToNTypeMapping &getOperandMapping() const {  | 
 | 203 | +      return *operandMapping;  | 
 | 204 | +    }  | 
 | 205 | + | 
 | 206 | +    /// Get the type mapping of the original results to the converted results.  | 
 | 207 | +    const OneToNTypeMapping &getResultMapping() const { return *resultMapping; }  | 
 | 208 | + | 
 | 209 | +    /// Get a flat range of all converted operands. Unlike `getOperands`, which  | 
 | 210 | +    /// returns an `ArrayRef` with one `ValueRange` for each original operand,  | 
 | 211 | +    /// this function returns a `ValueRange` that contains all converted  | 
 | 212 | +    /// operands irrespectively of which operand they originated from.  | 
 | 213 | +    ValueRange getFlatOperands() const { return *convertedOperands; }  | 
 | 214 | + | 
 | 215 | +  private:  | 
 | 216 | +    const OneToNTypeMapping *operandMapping;  | 
 | 217 | +    const OneToNTypeMapping *resultMapping;  | 
 | 218 | +    const ValueRange *convertedOperands;  | 
 | 219 | +  };  | 
 | 220 | + | 
 | 221 | +  using OneToNConversionPattern::matchAndRewrite;  | 
 | 222 | + | 
 | 223 | +  /// Overload that derived classes have to override for their op type.  | 
 | 224 | +  virtual LogicalResult  | 
 | 225 | +  matchAndRewrite(SourceOp op, OpAdaptor adaptor,  | 
 | 226 | +                  OneToNPatternRewriter &rewriter) const = 0;  | 
 | 227 | + | 
 | 228 | +  LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter,  | 
 | 229 | +                                const OneToNTypeMapping &operandMapping,  | 
 | 230 | +                                const OneToNTypeMapping &resultMapping,  | 
 | 231 | +                                ValueRange convertedOperands) const final {  | 
 | 232 | +    // Wrap converted operands and type mappings into an adaptor.  | 
 | 233 | +    SmallVector<ValueRange> valueRanges;  | 
 | 234 | +    for (int64_t i = 0; i < op->getNumOperands(); i++) {  | 
 | 235 | +      auto values = operandMapping.getConvertedValues(convertedOperands, i);  | 
 | 236 | +      valueRanges.push_back(values);  | 
 | 237 | +    }  | 
 | 238 | +    OpAdaptor adaptor(&operandMapping, &resultMapping, &convertedOperands,  | 
 | 239 | +                      valueRanges, cast<SourceOp>(op));  | 
 | 240 | + | 
 | 241 | +    // Call overload implemented by the derived class.  | 
 | 242 | +    return matchAndRewrite(cast<SourceOp>(op), adaptor, rewriter);  | 
 | 243 | +  }  | 
 | 244 | +};  | 
 | 245 | + | 
 | 246 | +/// Applies the given set of patterns recursively on the given op and adds user  | 
 | 247 | +/// materializations where necessary. The patterns are expected to be  | 
 | 248 | +/// `OneToNConversionPattern`, which help converting the types of the operands  | 
 | 249 | +/// and results of the matched ops. The provided type converter is used to  | 
 | 250 | +/// convert the operands of matched ops from their original types to operands  | 
 | 251 | +/// with different types. Unlike in `DialectConversion`, this supports 1:N type  | 
 | 252 | +/// conversions. Those conversions at the "boundary" of the pattern application,  | 
 | 253 | +/// where converted results are not consumed by replaced ops that expect the  | 
 | 254 | +/// converted operands or vice versa, the function inserts user materializations  | 
 | 255 | +/// from the type converter. Also unlike `DialectConversion`, there are no legal  | 
 | 256 | +/// or illegal types; the function simply applies the given patterns and does  | 
 | 257 | +/// not fail if some ops or types remain unconverted (i.e., the conversion is  | 
 | 258 | +/// only "partial").  | 
 | 259 | +/// FIXME: The 1:N dialect conversion is deprecated and will be removed soon.  | 
 | 260 | +/// 1:N support has been added to the regular dialect conversion driver.  | 
 | 261 | +/// Use applyPartialConversion() instead.  | 
 | 262 | +LogicalResult  | 
 | 263 | +applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,  | 
 | 264 | +                             const FrozenRewritePatternSet &patterns);  | 
 | 265 | + | 
 | 266 | +/// Add a pattern to the given pattern list to convert the signature of a  | 
 | 267 | +/// FunctionOpInterface op with the given type converter. This only supports  | 
 | 268 | +/// ops which use FunctionType to represent their type. This is intended to be  | 
 | 269 | +/// used with the 1:N dialect conversion.  | 
 | 270 | +/// FIXME: The 1:N dialect conversion is deprecated and will be removed soon.  | 
 | 271 | +/// 1:N support has been added to the regular dialect conversion driver.  | 
 | 272 | +/// Use populateFunctionOpInterfaceTypeConversionPattern() instead.  | 
 | 273 | +void populateOneToNFunctionOpInterfaceTypeConversionPattern(  | 
 | 274 | +    StringRef functionLikeOpName, const TypeConverter &converter,  | 
 | 275 | +    RewritePatternSet &patterns);  | 
 | 276 | +template <typename FuncOpT>  | 
 | 277 | +void populateOneToNFunctionOpInterfaceTypeConversionPattern(  | 
 | 278 | +    const TypeConverter &converter, RewritePatternSet &patterns) {  | 
 | 279 | +  populateOneToNFunctionOpInterfaceTypeConversionPattern(  | 
 | 280 | +      FuncOpT::getOperationName(), converter, patterns);  | 
 | 281 | +}  | 
 | 282 | + | 
 | 283 | +} // namespace mlir  | 
 | 284 | + | 
 | 285 | +#endif // MLIR_TRANSFORMS_ONETONTYPECONVERSION_H  | 
0 commit comments