Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][memref] Add memref alias folding for masked transfers #71476

Merged
merged 2 commits into from
Nov 7, 2023

Conversation

qedawkins
Copy link
Contributor

@qedawkins qedawkins commented Nov 7, 2023

The contents of a mask on a masked transfer are unaffected by the particular region of memory being read/stored to, so just forward the mask in subview folding patterns.

Because masking of vector.transfer ops semantically apply to the
unpermuted input vector (for reads) and permuted output vector (for
writes), they apply independently of any subviews.
@llvmbot
Copy link
Collaborator

llvmbot commented Nov 7, 2023

@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir

Author: Quinn Dawkins (qedawkins)

Changes

Because masking of vector.transfer ops apply to the unpermuted input vector (for reads) and permuted output vector (for writes), they apply independently of any subviews and can thus be forwarded to the folded transfers.


Full diff: https://github.com/llvm/llvm-project/pull/71476.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp (+2-4)
  • (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+57)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 043e8fbcdd2f6fb..b78c4510ff88585 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -346,8 +346,6 @@ preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
       "must be a vector transfer op");
   if (xferOp.hasOutOfBoundsDim())
     return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
-  if (xferOp.getMask())
-    return rewriter.notifyMatchFailure(xferOp, "masked transfer");
   if (!subviewOp.hasUnitStride()) {
     return rewriter.notifyMatchFailure(
         xferOp, "non-1 stride subview, need to track strides in folded memref");
@@ -428,7 +426,7 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
             AffineMapAttr::get(expandDimsToRank(
                 op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                 subViewOp.getDroppedDims())),
-            op.getPadding(), /*mask=*/Value(), op.getInBoundsAttr());
+            op.getPadding(), op.getMask(), op.getInBoundsAttr());
       })
       .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
         rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
@@ -557,7 +555,7 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
             AffineMapAttr::get(expandDimsToRank(
                 op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                 subViewOp.getDroppedDims())),
-            op.getInBoundsAttr());
+            op.getMask(), op.getInBoundsAttr());
       })
       .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
         rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 3f11e22749bb16d..2e1319420fb3eaf 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -266,6 +266,63 @@ func.func @fold_vector_transfer_write_with_inner_rank_reduced_subview(
 
 // -----
 
+func.func @fold_masked_vector_transfer_read_with_subview(
+    %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
+    %arg1: index, %arg2 : index, %arg3 : index, %arg4: index, %arg5 : index,
+    %arg6 : index, %mask : vector<4xi1>) -> vector<4xf32> {
+  %cst = arith.constant 0.0 : f32
+  %0 = memref.subview %arg0[%arg1, %arg2] [%arg3, %arg4] [1, 1]
+      : memref<?x?xf32, strided<[?, ?], offset: ?>> to
+        memref<?x?xf32, strided<[?, ?], offset: ?>>
+  %1 = vector.transfer_read %0[%arg5, %arg6], %cst, %mask {in_bounds = [true]}
+      : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4xf32>
+  return %1 : vector<4xf32>
+}
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+//       CHECK: func @fold_masked_vector_transfer_read_with_subview
+//  CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32, strided<[?, ?], offset: ?>>
+//  CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG4:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG5:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG6:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG7:[a-zA-Z0-9]+]]: vector<4xi1>
+//   CHECK-DAG:    %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG1]], %[[ARG5]]]
+//   CHECK-DAG:    %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
+//       CHECK:    vector.transfer_read %[[ARG0]][%[[IDX0]], %[[IDX1]]], %{{.*}}, %[[ARG7]] {{.*}} : memref<?x?xf32
+
+// -----
+
+func.func @fold_masked_vector_transfer_write_with_subview(
+    %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
+    %arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index,
+    %arg5: index, %arg6 : index, %arg7 : index, %mask : vector<4xi1>) {
+  %cst = arith.constant 0.0 : f32
+  %0 = memref.subview %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1]
+      : memref<?x?xf32, strided<[?, ?], offset: ?>> to
+        memref<?x?xf32, strided<[?, ?], offset: ?>>
+  vector.transfer_write %arg1, %0[%arg6, %arg7], %mask {in_bounds = [true]}
+      : vector<4xf32>, memref<?x?xf32, strided<[?, ?], offset: ?>>
+  return
+}
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+//       CHECK: func @fold_masked_vector_transfer_write_with_subview
+//  CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32, strided<[?, ?], offset: ?>>
+//  CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32>
+//  CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG4:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG5:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG6:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG7:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG8:[a-zA-Z0-9]+]]: vector<4xi1>
+//   CHECK-DAG:    %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
+//   CHECK-DAG:    %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG7]]]
+//   CHECK-DAG:    vector.transfer_write %[[ARG1]], %[[ARG0]][%[[IDX0]], %[[IDX1]]], %[[ARG8]] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32
+
+// -----
+
 //  Test with affine.load/store ops. We only do a basic test here since the
 //  logic is identical to that with memref.load/store ops. The same affine.apply
 //  ops would be generated.

@antiagainst
Copy link
Member

A tricky part with subview ops in general is rank reducing behavior. Thinking a bit actually it should be fine given that folding rank reducing subview ops into the original memref would just mean we see extra dims in the transfer op permuation map. That will be compressed when inferTransferOpMaskType, so we can still use the same mask.

@nicolasvasilache
Copy link
Contributor

A tricky part with subview ops in general is rank reducing behavior. Thinking a bit actually it should be fine given that folding rank reducing subview ops into the original memref would just mean we see extra dims in the transfer op permuation map. That will be compressed when inferTransferOpMaskType, so we can still use the same mask.

Let's add some rank-reducing tests please

@qedawkins qedawkins merged commit 48f980c into llvm:main Nov 7, 2023
3 checks passed
@qedawkins qedawkins deleted the fold_alias_masked_transfer branch November 7, 2023 13:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants