
[mlir][vector] Fix scalability issues in drop innermost unit dims transfer patterns #92402

Merged
merged 3 commits into from
May 17, 2024

Conversation

MacDue
Member

@MacDue MacDue commented May 16, 2024

Previously, these rewrites dropped scalable dimensions and treated [1] (a scalable dimension of size one) as a unit dimension. This patch propagates scalable dimensions and ensures that [1] is not treated as a unit dimension.
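
To illustrate the unit-dimension rule described above, here is a minimal standalone sketch. It does not use MLIR: `VecShape` and `countInnerUnitDims` are hypothetical stand-ins for `VectorType` and the vector-side part of the check, and the memref stride and size conditions that the real helper also tests are deliberately left out.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a vector type: a shape plus one scalable flag
// per dimension. For example, vector<[4]x1xf32> is {{4, 1}, {true, false}}.
struct VecShape {
  std::vector<int64_t> sizes;
  std::vector<bool> scalable;
};

// A dimension is a unit dimension only if its size is 1 and it is fixed-size.
// A scalable [1] holds 1 x vscale elements at runtime, so it does not qualify.
bool isUnitDim(const VecShape &type, std::size_t dim) {
  assert(type.sizes.size() == type.scalable.size());
  return type.sizes[dim] == 1 && !type.scalable[dim];
}

// Counts trailing unit dimensions, scanning from the innermost dimension
// outwards and stopping at the first dimension that is not a unit dimension.
std::size_t countInnerUnitDims(const VecShape &type) {
  std::size_t count = 0;
  for (std::size_t i = type.sizes.size(); i > 0 && isUnitDim(type, i - 1); --i)
    ++count;
  return count;
}
```

Under this model, `vector<[4]x1xf32>` has one droppable inner dimension, while `vector<4x[1]xf32>` has none, which matches the positive and negative tests added in this patch.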

@llvmbot
Member

llvmbot commented May 16, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Benjamin Maxwell (MacDue)

Changes

Previously, this rewrite dropped scalable dimensions and treated [1] (a scalable dimension of size one) as a unit dimension. This patch propagates scalable dimensions and ensures that [1] is not treated as a unit dimension.


Full diff: https://github.com/llvm/llvm-project/pull/92402.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+7-3)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir (+30)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 69c497264fd1e..720e638a74b55 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1237,6 +1237,10 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
   if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
     return failure();
 
+  auto isUnitDim = [](VectorType type, int dim) {
+    return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
+  };
+
   // According to vector.transfer_read/write semantics, the vector can be a
   // slice. Thus, we have to offset the check index with `rankDiff` in
   // `srcStrides` and source dim sizes.
@@ -1247,8 +1251,7 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
     // It can be folded only if they are 1 and the stride is 1.
     int dim = vectorType.getRank() - i - 1;
     if (srcStrides[dim + rankDiff] != 1 ||
-        srcType.getDimSize(dim + rankDiff) != 1 ||
-        vectorType.getDimSize(dim) != 1)
+        srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
       break;
     result++;
   }
@@ -1292,7 +1295,8 @@ class DropInnerMostUnitDimsTransferRead
 
     auto resultTargetVecType =
         VectorType::get(targetType.getShape().drop_back(dimsToDrop),
-                        targetType.getElementType());
+                        targetType.getElementType(),
+                        targetType.getScalableDims().drop_back(dimsToDrop));
 
     auto loc = readOp.getLoc();
     SmallVector<OpFoldResult> sizes =
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index 477755b66c020..ddfae5590e4c4 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -174,3 +174,33 @@ func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], o
 // The inner most unit dims can not be dropped if the strides are not ones.
 // CHECK:     func.func @non_unit_strides
 // CHECK-NOT:   memref.subview
+
+// -----
+
+func.func @leading_scalable_dimension_transfer_read(%dest : memref<24x1xf32>, %index: index) -> vector<[4]x1xf32> {
+  %c0 = arith.constant 0 : index
+  %cst_0 = arith.constant 0.000000e+00 : f32
+  %4 = vector.transfer_read %dest[%index, %c0], %cst_0 {in_bounds = [true, true]} : memref<24x1xf32>, vector<[4]x1xf32>
+  return %4 : vector<[4]x1xf32>
+}
+// CHECK:      func.func @leading_scalable_dimension_transfer_read
+// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[IDX:[a-zA-Z0-9]+]]
+// CHECK:        %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0] [24, 1] [1, 1] : memref<24x1xf32> to memref<24xf32, strided<[1]>>
+// CHECK:        %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]][%[[IDX]]], %{{.*}} {in_bounds = [true]} : memref<24xf32, strided<[1]>>, vector<[4]xf32>
+// CHECK:        %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<[4]xf32> to vector<[4]x1xf32>
+// CHECK:        return %[[CAST]]
+
+// -----
+
+// Negative test: [1] (scalable 1) is _not_ a unit dimension.
+func.func @trailing_scalable_one_dim_transfer_read(%dest : memref<24x1xf32>, %index: index) -> vector<4x[1]xf32> {
+  %c0 = arith.constant 0 : index
+  %cst_0 = arith.constant 0.000000e+00 : f32
+  %4 = vector.transfer_read %dest[%index, %c0], %cst_0 {in_bounds = [true, true]} : memref<24x1xf32>, vector<4x[1]xf32>
+  return %4 : vector<4x[1]xf32>
+}
+// CHECK:      func.func @trailing_scalable_one_dim_transfer_read
+// CHECK-NOT:    vector.shape_cast
+// CHECK:        vector.transfer_read {{.*}} : memref<24x1xf32>, vector<4x[1]xf32>
+// CHECK-NOT:    vector.shape_cast

Collaborator

@c-rhodes c-rhodes left a comment

Left a couple of nits. I also noticed that DropInnerMostUnitDimsTransferWrite is dropping flags? I'd fix that in this patch as well, given how similar the two patterns are. Accepting anyway, as it's not a blocker.
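
For context, the read-side change in the diff above passes `targetType.getScalableDims().drop_back(dimsToDrop)` when building the reduced vector type. A standalone sketch of that idea follows; `VecShape` and `dropInnerDims` are hypothetical names used only for illustration, not MLIR APIs. The point is that the shape and the scalable flags have to be truncated together, otherwise a leading `[4]` silently becomes a fixed `4`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a vector type: a shape plus one scalable flag
// per dimension. For example, vector<[4]x1xf32> is {{4, 1}, {true, false}}.
struct VecShape {
  std::vector<int64_t> sizes;
  std::vector<bool> scalable;
};

// Drops the innermost `dimsToDrop` dimensions and keeps the scalable flags
// of the remaining dimensions, so that [4]x1 becomes [4] rather than 4.
VecShape dropInnerDims(const VecShape &type, std::size_t dimsToDrop) {
  assert(type.sizes.size() == type.scalable.size());
  assert(dimsToDrop <= type.sizes.size());
  const std::size_t rank = type.sizes.size() - dimsToDrop;
  VecShape result;
  result.sizes.assign(type.sizes.begin(), type.sizes.begin() + rank);
  result.scalable.assign(type.scalable.begin(), type.scalable.begin() + rank);
  return result;
}
```

Applied to the model of `vector<[4]x1xf32>` with `dimsToDrop = 1`, this yields a rank-1 shape of size 4 whose single flag is still scalable, which corresponds to the `vector<[4]xf32>` expected by the new transfer_read test.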

MacDue added 2 commits May 17, 2024 14:04
… unit dims

Previously, this rewrite dropped scalable dimensions and treated `[1]`
(a scalable dimension of size one) as a unit dimension. This patch propagates
scalable dimensions and ensures that `[1]` is not treated as a unit dimension.
@MacDue MacDue force-pushed the fix_dropped_scalablity branch from 2035225 to 26255c1 Compare May 17, 2024 14:51
@MacDue MacDue changed the title [mlir][vector] Fix scalability issues in drop innermost transfer_read unit dims [mlir][vector] Fix scalability issues in drop innermost transfer unit dims patterns May 17, 2024
@MacDue MacDue changed the title [mlir][vector] Fix scalability issues in drop innermost transfer unit dims patterns [mlir][vector] Fix scalability issues in drop innermost unit dims transfer patterns May 17, 2024
@MacDue MacDue merged commit e01ff82 into llvm:main May 17, 2024
3 of 4 checks passed
@MacDue MacDue deleted the fix_dropped_scalablity branch May 17, 2024 16:14
MacDue added a commit to MacDue/iree that referenced this pull request May 21, 2024
We should also test peeling here, but that is currently waiting on
llvm/llvm-project#92402 (to fix one of the small
cases).
MacDue added a commit to MacDue/iree that referenced this pull request May 21, 2024
We should also test peeling here, but that is currently waiting on
llvm/llvm-project#92402 (to fix one of the small
cases).

Signed-off-by: Benjamin Maxwell <benjamin.maxwell@arm.com>
MacDue added a commit to MacDue/iree that referenced this pull request May 22, 2024
We should also test peeling here, but that is currently waiting on
llvm/llvm-project#92402 (to fix one of the small
cases).

Signed-off-by: Benjamin Maxwell <benjamin.maxwell@arm.com>