Skip to content

Commit 3a27b5b

Browse files
committed
[mlir][vector] Restrict DropInnerMostUnitDimsTransferRead
Restrict `DropInnerMostUnitDimsTransferRead` so that it fails when one of the indices to be dropped could be != 0, e.g. ``` func.func @negative_example(%A: memref<16x1xf32>, %i:index, %j:index) -> (vector<8x1xf32>) { %f0 = arith.constant 0.0 : f32 %1 = vector.transfer_read %A[%i, %j], %f0 : memref<16x1xf32>, vector<8x1xf32> return %1 : vector<8x1xf32> } ``` This is an edge case that could represent an out-of-bounds access, though that will depend on the actual value of `%j`. NOTE: This PR is limited to tests for `vector.transfer_read`. Depends on: llvm#94490, llvm#94604
1 parent f246e3a commit 3a27b5b

File tree

2 files changed

+32
-7
lines changed

2 files changed

+32
-7
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1293,6 +1293,21 @@ class DropInnerMostUnitDimsTransferRead
12931293
if (dimsToDrop == 0)
12941294
return failure();
12951295

1296+
// Make sure that the indixes to be dropped are equal 0.
1297+
// TODO: Deal with cases when the indices are not 0.
1298+
auto isZeroIdx = [](Value idx) {
1299+
Attribute attr;
1300+
APInt value;
1301+
if (!matchPattern(idx, m_Constant(&attr)))
1302+
return false;
1303+
if (matchPattern(attr, m_ConstantInt(&value)))
1304+
if (!value.isZero())
1305+
return false;
1306+
return true;
1307+
};
1308+
if (!llvm::all_of(readOp.getIndices().take_back(dimsToDrop), isZeroIdx))
1309+
return failure();
1310+
12961311
auto resultTargetVecType =
12971312
VectorType::get(targetType.getShape().drop_back(dimsToDrop),
12981313
targetType.getElementType(),

mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -111,31 +111,41 @@ func.func @contiguous_inner_most_outer_dim_dyn_scalable_inner_dim(%a: index, %b:
111111

112112
// -----
113113

114-
func.func @contiguous_inner_most_dim_non_zero_idxs(%A: memref<16x1xf32>, %i:index, %j:index) -> (vector<8x1xf32>) {
114+
func.func @contiguous_inner_most_dim_non_zero_idx(%A: memref<16x1xf32>, %i:index) -> (vector<8x1xf32>) {
115115
%c0 = arith.constant 0 : index
116116
%f0 = arith.constant 0.0 : f32
117-
%1 = vector.transfer_read %A[%i, %j], %f0 : memref<16x1xf32>, vector<8x1xf32>
117+
%1 = vector.transfer_read %A[%i, %c0], %f0 : memref<16x1xf32>, vector<8x1xf32>
118118
return %1 : vector<8x1xf32>
119119
}
120-
// CHECK: func @contiguous_inner_most_dim_non_zero_idxs(%[[SRC:.+]]: memref<16x1xf32>, %[[I:.+]]: index, %[[J:.+]]: index) -> vector<8x1xf32>
120+
// CHECK: func @contiguous_inner_most_dim_non_zero_idx(%[[SRC:.+]]: memref<16x1xf32>, %[[I:.+]]: index) -> vector<8x1xf32>
121121
// CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]]
122122
// CHECK-SAME: memref<16x1xf32> to memref<16xf32, strided<[1]>>
123123
// CHECK: %[[V:.+]] = vector.transfer_read %[[SRC_0]]
124124
// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[V]] : vector<8xf32> to vector<8x1xf32>
125125
// CHECK: return %[[RESULT]]
126126

127+
// The index to be dropped is != 0 - this is currently not supported.
128+
func.func @negative_contiguous_inner_most_dim_non_zero_idxs(%A: memref<16x1xf32>, %i:index) -> (vector<8x1xf32>) {
129+
%f0 = arith.constant 0.0 : f32
130+
%1 = vector.transfer_read %A[%i, %i], %f0 : memref<16x1xf32>, vector<8x1xf32>
131+
return %1 : vector<8x1xf32>
132+
}
133+
// CHECK-LABEL: func @negative_contiguous_inner_most_dim_non_zero_idxs
134+
// CHECK-NOT: memref.subview
135+
// CHECK: vector.transfer_read
136+
127137
// Same as the top example within this split, but with the outer vector
128138
// dim scalable. Note that this example only makes sense when "8 = [8]" (i.e.
129139
// vscale = 1). This is assumed (implicitly) via the `in_bounds` attribute.
130140

131-
func.func @contiguous_inner_most_dim_non_zero_idxs_scalable_inner_dim(%A: memref<16x1xf32>, %i:index, %j:index) -> (vector<[8]x1xf32>) {
141+
func.func @contiguous_inner_most_dim_non_zero_idx_scalable_inner_dim(%A: memref<16x1xf32>, %i:index) -> (vector<[8]x1xf32>) {
132142
%c0 = arith.constant 0 : index
133143
%f0 = arith.constant 0.0 : f32
134-
%1 = vector.transfer_read %A[%i, %j], %f0 : memref<16x1xf32>, vector<[8]x1xf32>
144+
%1 = vector.transfer_read %A[%i, %c0], %f0 : memref<16x1xf32>, vector<[8]x1xf32>
135145
return %1 : vector<[8]x1xf32>
136146
}
137-
// CHECK-LABEL: func @contiguous_inner_most_dim_non_zero_idxs_scalable_inner_dim(
138-
// CHECK-SAME: %[[SRC:.+]]: memref<16x1xf32>, %[[I:.+]]: index, %[[J:.+]]: index) -> vector<[8]x1xf32>
147+
// CHECK-LABEL: func @contiguous_inner_most_dim_non_zero_idx_scalable_inner_dim(
148+
// CHECK-SAME: %[[SRC:.+]]: memref<16x1xf32>, %[[I:.+]]: index) -> vector<[8]x1xf32>
139149
// CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]]
140150
// CHECK-SAME: memref<16x1xf32> to memref<16xf32, strided<[1]>>
141151
// CHECK: %[[V:.+]] = vector.transfer_read %[[SRC_0]]

0 commit comments

Comments
 (0)