[mlir][vector] Fix scalability issues in drop innermost unit dims transfer patterns #92402
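@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector
Author: Benjamin Maxwell (MacDue)
Changes
Previously, this rewrite would drop scalable dimensions and treat `[1]` (scalable one dim) as a unit dimension. This patch propagates scalable dimensions and ensures `[1]` is not treated as a unit dimension.
Full diff: https://github.com/llvm/llvm-project/pull/92402.diff
2 Files Affected: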
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 69c497264fd1e..720e638a74b55 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1237,6 +1237,10 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
return failure();
+ auto isUnitDim = [](VectorType type, int dim) {
+ return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
+ };
+
// According to vector.transfer_read/write semantics, the vector can be a
// slice. Thus, we have to offset the check index with `rankDiff` in
// `srcStrides` and source dim sizes.
@@ -1247,8 +1251,7 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
// It can be folded only if they are 1 and the stride is 1.
int dim = vectorType.getRank() - i - 1;
if (srcStrides[dim + rankDiff] != 1 ||
- srcType.getDimSize(dim + rankDiff) != 1 ||
- vectorType.getDimSize(dim) != 1)
+ srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
break;
result++;
}
@@ -1292,7 +1295,8 @@ class DropInnerMostUnitDimsTransferRead
auto resultTargetVecType =
VectorType::get(targetType.getShape().drop_back(dimsToDrop),
- targetType.getElementType());
+ targetType.getElementType(),
+ targetType.getScalableDims().drop_back(dimsToDrop));
auto loc = readOp.getLoc();
SmallVector<OpFoldResult> sizes =
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index 477755b66c020..ddfae5590e4c4 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -174,3 +174,33 @@ func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], o
// The inner most unit dims can not be dropped if the strides are not ones.
// CHECK: func.func @non_unit_strides
// CHECK-NOT: memref.subview
+
+// -----
+
+func.func @leading_scalable_dimension_transfer_read(%dest : memref<24x1xf32>, %index: index) -> vector<[4]x1xf32> {
+ %c0 = arith.constant 0 : index
+ %cst_0 = arith.constant 0.000000e+00 : f32
+ %4 = vector.transfer_read %dest[%index, %c0], %cst_0 {in_bounds = [true, true]} : memref<24x1xf32>, vector<[4]x1xf32>
+ return %4 : vector<[4]x1xf32>
+}
+// CHECK: func.func @leading_scalable_dimension_transfer_read
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0] [24, 1] [1, 1] : memref<24x1xf32> to memref<24xf32, strided<[1]>>
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]][%[[IDX]]], %{{.*}} {in_bounds = [true]} : memref<24xf32, strided<[1]>>, vector<[4]xf32>
+// CHECK: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<[4]xf32> to vector<[4]x1xf32>
+// CHECK: return %[[CAST]]
+
+// -----
+
+// Negative test: [1] (scalable 1) is _not_ a unit dimension.
+func.func @trailing_scalable_one_dim_transfer_read(%dest : memref<24x1xf32>, %index: index) -> vector<4x[1]xf32> {
+ %c0 = arith.constant 0 : index
+ %cst_0 = arith.constant 0.000000e+00 : f32
+ %4 = vector.transfer_read %dest[%index, %c0], %cst_0 {in_bounds = [true, true]} : memref<24x1xf32>, vector<4x[1]xf32>
+ return %4 : vector<4x[1]xf32>
+}
+// CHECK: func.func @trailing_scalable_one_dim_transfer_read
+// CHECK-NOT: vector.shape_cast
+// CHECK: vector.transfer_read {{.*}} : memref<24x1xf32>, vector<4x[1]xf32>
+// CHECK-NOT: vector.shape_cast
Left a couple of nits. I also noticed that DropInnerMostUnitDimsTransferWrite is dropping flags? I'd fix that in this patch as well, given how similar the two patterns are. Accepting anyway, as it's not a blocker.
… unit dims
Previously, this rewrite would drop scalable dimensions and treat `[1]` (scalable one dim) as a unit dimension. This patch propagates scalable dimensions and ensures `[1]` is not treated as a unit dimension.
2035225 to 26255c1
We should also test peeling here, but that is currently waiting on llvm/llvm-project#92402 (to fix one of the small cases). Signed-off-by: Benjamin Maxwell <benjamin.maxwell@arm.com>
Previously, these rewrites would drop scalable dimensions and treat `[1]` (scalable one dim) as a unit dimension. This patch propagates scalable dimensions and ensures `[1]` is not treated as a unit dimension.
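To illustrate the two rules described above, here is a minimal standalone sketch. It does not use the MLIR API: `VecType`, `countInnerUnitDims` and `dropInnerDims` are hypothetical stand-ins for `VectorType` and the logic in the patterns, and the memref stride and size checks that the real helper also performs are left out.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for mlir::VectorType: a shape plus one scalable
// flag per dimension (vector<[4]x1xf32> -> shape {4, 1}, scalable {true, false}).
struct VecType {
  std::vector<int64_t> shape;
  std::vector<bool> scalable;
};

// A dimension is a unit dim only if its size is 1 and it is not scalable.
// [1] holds 1 x vscale elements, so its extent is unknown at compile time.
bool isUnitDim(const VecType &type, std::size_t dim) {
  return type.shape[dim] == 1 && !type.scalable[dim];
}

// Count trailing unit dims, stopping at the first non-unit dim.
// Assumption: source strides and dim sizes are not modelled here.
std::size_t countInnerUnitDims(const VecType &type) {
  std::size_t result = 0;
  for (std::size_t i = type.shape.size(); i > 0; --i) {
    if (!isUnitDim(type, i - 1))
      break;
    ++result;
  }
  return result;
}

// Drop the trailing dims from the shape and from the scalable flags
// together, so the remaining dims keep their scalability.
VecType dropInnerDims(const VecType &type, std::size_t dimsToDrop) {
  auto n = static_cast<std::ptrdiff_t>(dimsToDrop);
  VecType out;
  out.shape.assign(type.shape.begin(), type.shape.end() - n);
  out.scalable.assign(type.scalable.begin(), type.scalable.end() - n);
  return out;
}
```

Under this model, `vector<[4]x1xf32>` has one foldable inner dim and collapses to `vector<[4]xf32>`, while `vector<4x[1]xf32>` has none, which matches the two tests added in the diff.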