Skip to content

Commit e01ff82

Browse files
authored
[mlir][vector] Fix scalability issues in drop innermost unit dims transfer patterns (#92402)
Previously, these rewrites would drop scalable dimensions and treated `[1]` (scalable one dim) as a unit dimension. This patch propagates scalable dimensions and ensures `[1]` is not treated as a unit dimension.
1 parent 8aa6511 commit e01ff82

File tree

2 files changed

+65
-4
lines changed

2 files changed

+65
-4
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1237,6 +1237,10 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
12371237
if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
12381238
return failure();
12391239

1240+
auto isUnitDim = [](VectorType type, int dim) {
1241+
return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
1242+
};
1243+
12401244
// According to vector.transfer_read/write semantics, the vector can be a
12411245
// slice. Thus, we have to offset the check index with `rankDiff` in
12421246
// `srcStrides` and source dim sizes.
@@ -1247,8 +1251,7 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
12471251
// It can be folded only if they are 1 and the stride is 1.
12481252
int dim = vectorType.getRank() - i - 1;
12491253
if (srcStrides[dim + rankDiff] != 1 ||
1250-
srcType.getDimSize(dim + rankDiff) != 1 ||
1251-
vectorType.getDimSize(dim) != 1)
1254+
srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
12521255
break;
12531256
result++;
12541257
}
@@ -1292,7 +1295,8 @@ class DropInnerMostUnitDimsTransferRead
12921295

12931296
auto resultTargetVecType =
12941297
VectorType::get(targetType.getShape().drop_back(dimsToDrop),
1295-
targetType.getElementType());
1298+
targetType.getElementType(),
1299+
targetType.getScalableDims().drop_back(dimsToDrop));
12961300

12971301
auto loc = readOp.getLoc();
12981302
SmallVector<OpFoldResult> sizes =
@@ -1378,7 +1382,8 @@ class DropInnerMostUnitDimsTransferWrite
13781382

13791383
auto resultTargetVecType =
13801384
VectorType::get(targetType.getShape().drop_back(dimsToDrop),
1381-
targetType.getElementType());
1385+
targetType.getElementType(),
1386+
targetType.getScalableDims().drop_back(dimsToDrop));
13821387

13831388
Location loc = writeOp.getLoc();
13841389
SmallVector<OpFoldResult> sizes =

mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,3 +174,59 @@ func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], o
174174
// The inner most unit dims can not be dropped if the strides are not ones.
175175
// CHECK: func.func @non_unit_strides
176176
// CHECK-NOT: memref.subview
177+
178+
// -----
179+
180+
func.func @leading_scalable_dimension_transfer_read(%dest : memref<24x1xf32>) -> vector<[4]x1xf32> {
181+
%c0 = arith.constant 0 : index
182+
%pad = arith.constant 0.0 : f32
183+
%0 = vector.transfer_read %dest[%c0, %c0], %pad {in_bounds = [true, true]} : memref<24x1xf32>, vector<[4]x1xf32>
184+
return %0 : vector<[4]x1xf32>
185+
}
186+
// CHECK: func.func @leading_scalable_dimension_transfer_read
187+
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
188+
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0] [24, 1] [1, 1] : memref<24x1xf32> to memref<24xf32, strided<[1]>>
189+
// CHECK: %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]]{{.*}} {in_bounds = [true]} : memref<24xf32, strided<[1]>>, vector<[4]xf32>
190+
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<[4]xf32> to vector<[4]x1xf32>
191+
// CHECK: return %[[CAST]]
192+
193+
// -----
194+
195+
// Negative test: [1] (scalable 1) is _not_ a unit dimension.
196+
func.func @trailing_scalable_one_dim_transfer_read(%dest : memref<24x1xf32>) -> vector<4x[1]xf32> {
197+
%c0 = arith.constant 0 : index
198+
%pad = arith.constant 0.0 : f32
199+
%0 = vector.transfer_read %dest[%c0, %c0], %pad {in_bounds = [true, true]} : memref<24x1xf32>, vector<4x[1]xf32>
200+
return %0 : vector<4x[1]xf32>
201+
}
202+
// CHECK: func.func @trailing_scalable_one_dim_transfer_read
203+
// CHECK-NOT: vector.shape_cast
204+
// CHECK: vector.transfer_read {{.*}} : memref<24x1xf32>, vector<4x[1]xf32>
205+
// CHECK-NOT: vector.shape_cast
206+
207+
// -----
208+
209+
func.func @leading_scalable_dimension_transfer_write(%dest : memref<24x1xf32>, %vec: vector<[4]x1xf32>) {
210+
%c0 = arith.constant 0 : index
211+
vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[4]x1xf32>, memref<24x1xf32>
212+
return
213+
}
214+
// CHECK: func.func @leading_scalable_dimension_transfer_write
215+
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
216+
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
217+
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0] [24, 1] [1, 1] : memref<24x1xf32> to memref<24xf32, strided<[1]>>
218+
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<[4]x1xf32> to vector<[4]xf32>
219+
// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]{{.*}} {in_bounds = [true]} : vector<[4]xf32>, memref<24xf32, strided<[1]>>
220+
221+
// -----
222+
223+
// Negative test: [1] (scalable 1) is _not_ a unit dimension.
224+
func.func @trailing_scalable_one_dim_transfer_write(%dest : memref<24x1xf32>, %vec: vector<4x[1]xf32>, %index: index) {
225+
%c0 = arith.constant 0 : index
226+
vector.transfer_write %vec, %dest[%index, %c0] {in_bounds = [true, true]} : vector<4x[1]xf32>, memref<24x1xf32>
227+
return
228+
}
229+
// CHECK: func.func @trailing_scalable_one_dim_transfer_write
230+
// CHECK-NOT: vector.shape_cast
231+
// CHECK: vector.transfer_write {{.*}} : vector<4x[1]xf32>, memref<24x1xf32>
232+
// CHECK-NOT: vector.shape_cast

0 commit comments

Comments
 (0)