|
24 | 24 | #include "mlir/IR/BlockAndValueMapping.h"
|
25 | 25 | #include "mlir/IR/Builders.h"
|
26 | 26 | #include "mlir/IR/BuiltinOps.h"
|
| 27 | +#include "mlir/IR/BuiltinTypes.h" |
27 | 28 | #include "mlir/IR/DialectImplementation.h"
|
28 | 29 | #include "mlir/IR/OpImplementation.h"
|
29 | 30 | #include "mlir/IR/PatternMatch.h"
|
@@ -2783,8 +2784,35 @@ struct FoldExtractSliceIntoTransferRead
|
2783 | 2784 | if (!extractOp.hasUnitStride())
|
2784 | 2785 | return failure();
|
2785 | 2786 |
|
| 2787 | + // Bail on illegal rank-reduction: we need to check that the rank-reduced |
| 2788 | + // dims are exactly the leading dims. I.e. the following is illegal: |
| 2789 | + // ``` |
| 2790 | + // %0 = tensor.extract_slice %t[0,0,0][2,1,4][1,1,1] : |
| 2791 | + // tensor<2x1x4xf32> to tensor<2x4xf32> |
| 2792 | + // %1 = vector.transfer_read %0[0,0], %cst : |
| 2793 | + // tensor<2x4xf32>, vector<2x4xf32> |
| 2794 | + // ``` |
| 2795 | + // |
| 2796 | + // Cannot fold into: |
| 2797 | + // ``` |
| 2798 | + // %0 = vector.transfer_read %t[0,0,0], %cst : |
| 2799 | + // tensor<2x1x4xf32>, vector<2x4xf32> |
| 2800 | + // ``` |
| 2801 | + // For this, check that the trailing `vectorRank` dims of the extract_slice |
| 2802 | + // result tensor match the trailing dims of the inferred result tensor. |
2786 | 2803 | int64_t rankReduced =
|
2787 | 2804 | extractOp.getSourceType().getRank() - extractOp.getType().getRank();
|
| 2805 | + int64_t vectorRank = xferOp.getVectorType().getRank(); |
| 2806 | + RankedTensorType inferredDestTensorType = |
| 2807 | + tensor::ExtractSliceOp::inferResultType( |
| 2808 | + extractOp.getSourceType(), extractOp.getMixedOffsets(), |
| 2809 | + extractOp.getMixedSizes(), extractOp.getMixedStrides()); |
| 2810 | + auto actualDestTensorShape = extractOp.getType().getShape(); |
| 2811 | + if (rankReduced > 0 && |
| 2812 | + actualDestTensorShape.take_back(vectorRank) != |
| 2813 | + inferredDestTensorType.getShape().take_back(vectorRank)) |
| 2814 | + return failure(); |
| 2815 | + |
2788 | 2816 | SmallVector<Value> newIndices;
|
2789 | 2817 | // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced
|
2790 | 2818 | // indices first.
|
@@ -3168,14 +3196,43 @@ struct FoldInsertSliceIntoTransferWrite
|
3168 | 3196 | if (xferOp.mask())
|
3169 | 3197 | return failure();
|
3170 | 3198 | // Fold only if the TransferWriteOp completely overwrites the `source` with
|
3171 |
| - // a vector. I.e., the result of the TransferWriteOp is a new tensor who's |
| 3199 | + // a vector. I.e., the result of the TransferWriteOp is a new tensor whose |
3172 | 3200 | // content is the data of the vector.
|
3173 | 3201 | if (!llvm::equal(xferOp.getVectorType().getShape(),
|
3174 | 3202 | xferOp.getShapedType().getShape()))
|
3175 | 3203 | return failure();
|
3176 | 3204 | if (!xferOp.permutation_map().isIdentity())
|
3177 | 3205 | return failure();
|
3178 | 3206 |
|
| 3207 | + // Bail on illegal rank-reduction: we need to check that the rank-reduced |
| 3208 | + // dims are exactly the leading dims. I.e. the following is illegal: |
| 3209 | + // ``` |
| 3210 | + // %0 = vector.transfer_write %v, %t[0,0] : |
| 3211 | + // vector<2x4xf32>, tensor<2x4xf32> |
| 3212 | + // %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] : |
| 3213 | + // tensor<2x4xf32> into tensor<2x1x4xf32> |
| 3214 | + // ``` |
| 3215 | + // |
| 3216 | + // Cannot fold into: |
| 3217 | + // ``` |
| 3218 | + // %0 = vector.transfer_write %v, %tt[0,0,0] : |
| 3219 | + // vector<2x4xf32>, tensor<2x1x4xf32> |
| 3220 | + // ``` |
| 3221 | + // For this, check that the trailing `vectorRank` dims of the insert_slice |
| 3222 | + // source tensor match the trailing dims of the inferred source tensor. |
| 3223 | + int64_t rankReduced = |
| 3224 | + insertOp.getType().getRank() - insertOp.getSourceType().getRank(); |
| 3225 | + int64_t vectorRank = xferOp.getVectorType().getRank(); |
| 3226 | + RankedTensorType inferredSourceTensorType = |
| 3227 | + tensor::ExtractSliceOp::inferResultType( |
| 3228 | + insertOp.getType(), insertOp.getMixedOffsets(), |
| 3229 | + insertOp.getMixedSizes(), insertOp.getMixedStrides()); |
| 3230 | + auto actualSourceTensorShape = insertOp.getSourceType().getShape(); |
| 3231 | + if (rankReduced > 0 && |
| 3232 | + actualSourceTensorShape.take_back(vectorRank) != |
| 3233 | + inferredSourceTensorType.getShape().take_back(vectorRank)) |
| 3234 | + return failure(); |
| 3235 | + |
3179 | 3236 | SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
|
3180 | 3237 | rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
|
3181 | 3238 | SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
|
|
0 commit comments