Skip to content

Commit 2e69f4f

Browse files
[mlir][vector] Fix illegal vector.transfer + tensor.insert/extract_slice folding
vector.transfer operations do not have rank-reducing semantics. Bail on illegal rank-reduction: we need to check that the rank-reduced dims are exactly the leading dims. I.e. the following is illegal: ``` %0 = vector.transfer_write %v, %t[0,0], %cst : vector<2x4xf32>, tensor<2x4xf32> %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] : tensor<2x4xf32> into tensor<2x1x4xf32> ``` Cannot fold into: ``` %0 = vector.transfer_write %v, %t[0,0,0], %cst : vector<2x4xf32>, tensor<2x1x4xf32> ``` For this, check the trailing `vectorRank` dims of the insert_slice result tensor match the trailing dims of the inferred result tensor. Differential Revision: https://reviews.llvm.org/D116409
1 parent 84b285d commit 2e69f4f

File tree

2 files changed

+84
-1
lines changed

2 files changed

+84
-1
lines changed

mlir/lib/Dialect/Vector/VectorOps.cpp

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "mlir/IR/BlockAndValueMapping.h"
2525
#include "mlir/IR/Builders.h"
2626
#include "mlir/IR/BuiltinOps.h"
27+
#include "mlir/IR/BuiltinTypes.h"
2728
#include "mlir/IR/DialectImplementation.h"
2829
#include "mlir/IR/OpImplementation.h"
2930
#include "mlir/IR/PatternMatch.h"
@@ -2783,8 +2784,35 @@ struct FoldExtractSliceIntoTransferRead
27832784
if (!extractOp.hasUnitStride())
27842785
return failure();
27852786

2787+
// Bail on illegal rank-reduction: we need to check that the rank-reduced
2788+
// dims are exactly the leading dims. I.e. the following is illegal:
2789+
// ```
2790+
// %0 = tensor.extract_slice %t[0,0,0][2,1,4][1,1,1] :
2791+
// tensor<2x1x4xf32> to tensor<2x4xf32>
2792+
// %1 = vector.transfer_read %0[0,0], %cst :
2793+
// tensor<2x4xf32>, vector<2x4xf32>
2794+
// ```
2795+
//
2796+
// Cannot fold into:
2797+
// ```
2798+
// %0 = vector.transfer_read %t[0,0,0], %cst :
2799+
// tensor<2x1x4xf32>, vector<2x4xf32>
2800+
// ```
2801+
// For this, check the trailing `vectorRank` dims of the extract_slice
2802+
// result tensor match the trailing dims of the inferred result tensor.
27862803
int64_t rankReduced =
27872804
extractOp.getSourceType().getRank() - extractOp.getType().getRank();
2805+
int64_t vectorRank = xferOp.getVectorType().getRank();
2806+
RankedTensorType inferredDestTensorType =
2807+
tensor::ExtractSliceOp::inferResultType(
2808+
extractOp.getSourceType(), extractOp.getMixedOffsets(),
2809+
extractOp.getMixedSizes(), extractOp.getMixedStrides());
2810+
auto actualDestTensorShape = extractOp.getType().getShape();
2811+
if (rankReduced > 0 &&
2812+
actualDestTensorShape.take_back(vectorRank) !=
2813+
inferredDestTensorType.getShape().take_back(vectorRank))
2814+
return failure();
2815+
27882816
SmallVector<Value> newIndices;
27892817
// In case this is a rank-reducing ExtractSliceOp, copy rank-reduced
27902818
// indices first.
@@ -3168,14 +3196,43 @@ struct FoldInsertSliceIntoTransferWrite
31683196
if (xferOp.mask())
31693197
return failure();
31703198
// Fold only if the TransferWriteOp completely overwrites the `source` with
3171-
// a vector. I.e., the result of the TransferWriteOp is a new tensor who's
3199+
// a vector. I.e., the result of the TransferWriteOp is a new tensor whose
31723200
// content is the data of the vector.
31733201
if (!llvm::equal(xferOp.getVectorType().getShape(),
31743202
xferOp.getShapedType().getShape()))
31753203
return failure();
31763204
if (!xferOp.permutation_map().isIdentity())
31773205
return failure();
31783206

3207+
// Bail on illegal rank-reduction: we need to check that the rank-reduced
3208+
// dims are exactly the leading dims. I.e. the following is illegal:
3209+
// ```
3210+
// %0 = vector.transfer_write %v, %t[0,0], %cst :
3211+
// vector<2x4xf32>, tensor<2x4xf32>
3212+
// %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] :
3213+
// tensor<2x4xf32> into tensor<2x1x4xf32>
3214+
// ```
3215+
//
3216+
// Cannot fold into:
3217+
// ```
3218+
// %0 = vector.transfer_write %v, %t[0,0,0], %cst :
3219+
// vector<2x4xf32>, tensor<2x1x4xf32>
3220+
// ```
3221+
// For this, check the trailing `vectorRank` dims of the insert_slice result
3222+
// tensor match the trailing dims of the inferred result tensor.
3223+
int64_t rankReduced =
3224+
insertOp.getType().getRank() - insertOp.getSourceType().getRank();
3225+
int64_t vectorRank = xferOp.getVectorType().getRank();
3226+
RankedTensorType inferredSourceTensorType =
3227+
tensor::ExtractSliceOp::inferResultType(
3228+
insertOp.getType(), insertOp.getMixedOffsets(),
3229+
insertOp.getMixedSizes(), insertOp.getMixedStrides());
3230+
auto actualSourceTensorShape = insertOp.getSourceType().getShape();
3231+
if (rankReduced > 0 &&
3232+
actualSourceTensorShape.take_back(vectorRank) !=
3233+
inferredSourceTensorType.getShape().take_back(vectorRank))
3234+
return failure();
3235+
31793236
SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
31803237
rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
31813238
SmallVector<bool> inBounds(xferOp.getTransferRank(), true);

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -995,6 +995,20 @@ func @transfer_read_of_extract_slice_rank_reducing(%t : tensor<?x?x?xf32>, %s1 :
995995

996996
// -----
997997

998+
// CHECK-LABEL: func @transfer_read_of_extract_slice_illegal_rank_reducing(
999+
// CHECK: extract_slice
1000+
// CHECK: vector.transfer_read
1001+
func @transfer_read_of_extract_slice_illegal_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
1002+
%c3 = arith.constant 3 : index
1003+
%c4 = arith.constant 4 : index
1004+
%cst = arith.constant 0.0 : f32
1005+
%0 = tensor.extract_slice %t[5, %s1, 6] [%s2, 1, 12] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x12xf32>
1006+
%1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<?x12xf32>, vector<5x6xf32>
1007+
return %1 : vector<5x6xf32>
1008+
}
1009+
1010+
// -----
1011+
9981012
// CHECK-LABEL: func @insert_slice_of_transfer_write(
9991013
// CHECK-SAME: %[[t1:.*]]: tensor<?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
10001014
// CHECK: %[[c3:.*]] = arith.constant 3 : index
@@ -1009,6 +1023,18 @@ func @insert_slice_of_transfer_write(%t1 : tensor<?x12xf32>, %v : vector<5x6xf32
10091023

10101024
// -----
10111025

1026+
// CHECK-LABEL: func @insert_slice_of_transfer_write_illegal_rank_extending(
1027+
// CHECK: vector.transfer_write
1028+
// CHECK: insert_slice
1029+
func @insert_slice_of_transfer_write_illegal_rank_extending(%t1 : tensor<?x?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x12xf32> {
1030+
%c0 = arith.constant 0 : index
1031+
%0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
1032+
%1 = tensor.insert_slice %0 into %t1[4, 3, %s] [5, 1, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
1033+
return %1 : tensor<?x?x12xf32>
1034+
}
1035+
1036+
// -----
1037+
10121038
// CHECK-LABEL: func @insert_slice_of_transfer_write_rank_extending(
10131039
// CHECK-SAME: %[[t1:.*]]: tensor<?x?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
10141040
// CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index

0 commit comments

Comments
 (0)