
[mlir][vector] Add leading unit dim folding patterns for masked transfers #71466

Merged: 1 commit into llvm:main on Nov 7, 2023

Conversation

qedawkins (Contributor)

This handles `vector.transfer_read`, `vector.transfer_write`, and `vector.constant_mask`. The unit dims are only relevant for masks created by `create_mask` and `constant_mask` when the mask size for the unit dim is not one, in which case all subsequent sizes must also be zero. From the perspective of the vector transfers, however, these unit dims can simply be dropped.
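
For illustration, a minimal before/after sketch of the intended rewrite on a masked `transfer_read` (it mirrors the new test cases in the diff below; `%src`, `%pad`, and `%mask` are placeholder names):

```mlir
// Before: masked transfer with a leading unit dim in the vector and mask types.
%0 = vector.transfer_read %src[%c0, %c0], %pad, %mask
    {in_bounds = [true, true]} : memref<1x4xf16>, vector<1x4xf16>

// After: the unit dim is dropped from both the transfer and the mask,
// and a broadcast restores the original result type.
%m = vector.extract %mask[0] : vector<4xi1> from vector<1x4xi1>
%r = vector.transfer_read %src[%c0, %c0], %pad, %m
    {in_bounds = [true]} : memref<1x4xf16>, vector<4xf16>
%1 = vector.broadcast %r : vector<4xf16> to vector<1x4xf16>
```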

Commit: [mlir][vector] Add leading unit dim folding patterns for masked transfers
llvmbot (Collaborator) commented on Nov 7, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Quinn Dawkins (qedawkins)


Full diff: https://github.com/llvm/llvm-project/pull/71466.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp (+59-10)
  • (modified) mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir (+35)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 6bbb293fa2a6b5c..75f32b23e57b0d6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -6,6 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include <numeric>
+
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -208,9 +210,6 @@ struct CastAwayTransferReadLeadingOneDim
     if (read.getTransferRank() == 0)
       return failure();
 
-    if (read.getMask())
-      return failure();
-
     auto shapedType = cast<ShapedType>(read.getSource().getType());
     if (shapedType.getElementType() != read.getVectorType().getElementType())
       return failure();
@@ -233,10 +232,18 @@ struct CastAwayTransferReadLeadingOneDim
       inBoundsAttr = rewriter.getArrayAttr(
           read.getInBoundsAttr().getValue().take_back(newType.getRank()));
 
+    Value mask = Value();
+    if (read.getMask()) {
+      // The mask shape must always match the shape of the read vector, so we
+      // can safely use the same extraction indices.
+      int64_t dropDim = oldType.getRank() - newType.getRank();
+      mask = rewriter.create<vector::ExtractOp>(read.getLoc(), read.getMask(),
+                                                splatZero(dropDim));
+    }
+
     auto newRead = rewriter.create<vector::TransferReadOp>(
         read.getLoc(), newType, read.getSource(), read.getIndices(),
-        AffineMapAttr::get(newMap), read.getPadding(), /*mask=*/Value(),
-        inBoundsAttr);
+        AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr);
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
 
     return success();
@@ -256,9 +263,6 @@ struct CastAwayTransferWriteLeadingOneDim
     if (write.getTransferRank() == 0)
       return failure();
 
-    if (write.getMask())
-      return failure();
-
     auto shapedType = dyn_cast<ShapedType>(write.getSource().getType());
     if (shapedType.getElementType() != write.getVectorType().getElementType())
       return failure();
@@ -283,10 +287,21 @@ struct CastAwayTransferWriteLeadingOneDim
 
     auto newVector = rewriter.create<vector::ExtractOp>(
         write.getLoc(), write.getVector(), splatZero(dropDim));
+
+    if (write.getMask()) {
+      // The mask shape must always match the shape of the written vector, so we
+      // can safely use the same extraction indices.
+      auto newMask = rewriter.create<vector::ExtractOp>(
+          write.getLoc(), write.getMask(), splatZero(dropDim));
+      rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+          write, newVector, write.getSource(), write.getIndices(),
+          AffineMapAttr::get(newMap), newMask, inBoundsAttr);
+      return success();
+    }
+
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
         write, newVector, write.getSource(), write.getIndices(),
         AffineMapAttr::get(newMap), inBoundsAttr);
-
     return success();
   }
 };
@@ -467,6 +482,40 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern {
   }
 };
 
+// Drops leading 1 dimensions from vector.constant_mask and inserts a
+// vector.broadcast back to the original shape.
+struct CastAwayConstantMaskLeadingOneDim
+    : public OpRewritePattern<vector::ConstantMaskOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ConstantMaskOp mask,
+                                PatternRewriter &rewriter) const override {
+    VectorType oldType = mask.getType();
+    VectorType newType = trimLeadingOneDims(oldType);
+
+    if (newType == oldType)
+      return failure();
+
+    int64_t dropDim = oldType.getRank() - newType.getRank();
+    SmallVector<int64_t> dimSizes;
+    for (auto attr : mask.getMaskDimSizes())
+      dimSizes.push_back(llvm::cast<IntegerAttr>(attr).getInt());
+
+    // If any of the dropped unit dims has a size of `0`, the entire mask is a
+    // zero mask, else the unit dim has no effect on the mask.
+    int64_t flatLeadingSize =
+        std::accumulate(dimSizes.begin(), dimSizes.begin() + dropDim + 1,
+                        static_cast<int64_t>(1), std::multiplies<int64_t>());
+    SmallVector<int64_t> newDimSizes({flatLeadingSize});
+    newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());
+
+    auto newMask = rewriter.create<vector::ConstantMaskOp>(
+        mask.getLoc(), newType, rewriter.getI64ArrayAttr(newDimSizes));
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
@@ -474,7 +523,7 @@ void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
   patterns
       .add<CastAwayExtractStridedSliceLeadingOneDim,
            CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
-           CastAwayTransferReadLeadingOneDim,
+           CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim,
            CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
            CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
   populateShapeCastFoldingPatterns(patterns, benefit);
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index e5b27b04dcc8096..5de30206927db2f 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -209,6 +209,20 @@ func.func @cast_away_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>)
   return %0: vector<1x4xf16>
 }
 
+// CHECK-LABEL: func @cast_away_masked_transfer_read_leading_one_dims
+func.func @cast_away_masked_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xi1>) -> vector<1x4xf16> {
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
+  %f0 = arith.constant 0. : f16
+  // CHECK: %[[MASK_CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
+  // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]], %[[MASK_CAST]] {in_bounds = [true]} : memref<1x4x8x16xf16>, vector<4xf16>
+  // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x4xf16>
+  %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0, %arg1 {in_bounds = [true, true]} : memref<1x4x8x16xf16>, vector<1x4xf16>
+  // CHECK: return %[[CAST]]
+  return %0: vector<1x4xf16>
+}
+
 // CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims_one_element
 func.func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>) -> vector<1x1xf16> {
   %c0 = arith.constant 0 : index
@@ -229,6 +243,18 @@ func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>
   return
 }
 
+// CHECK-LABEL: func @cast_away_masked_transfer_write_leading_one_dims
+func.func @cast_away_masked_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>, %arg2: vector<1x4xi1>) {
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xf16> from vector<1x4xf16>
+  // CHECK: %[[MASK_CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
+  // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[MASK_CAST]] {in_bounds = [true]} : vector<4xf16>, memref<1x4x8x16xf16>
+
+  vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0], %arg2 {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x4x8x16xf16>
+  return
+}
+
 // CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims_one_element
 func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>, %arg1: vector<1x1xf16>) {
   %c0 = arith.constant 0 : index
@@ -410,3 +436,12 @@ func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(%s: vector<1x
   %0 = vector.insert %s, %v [0, 0, 7] : vector<1x[8]xi1> into vector<1x1x8x1x[8]xi1>
   return %0: vector<1x1x8x1x[8]xi1>
 }
+
+// CHECK-LABEL:   func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> {
+// CHECK:           %[[MASK:.*]] = vector.constant_mask [6, 1, 1] : vector<8x2x1xi1>
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[MASK]] : vector<8x2x1xi1> to vector<1x1x8x2x1xi1>
+// CHECK:           return %[[BCAST]] : vector<1x1x8x2x1xi1>
+func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> {
+  %0 = vector.constant_mask [1, 1, 6, 1, 1] : vector<1x1x8x2x1xi1>
+  return %0: vector<1x1x8x2x1xi1>
+}
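
For context, a minimal sketch of how these patterns might be driven (the helper name is hypothetical; `populateCastAwayVectorLeadingOneDimPatterns` and `applyPatternsAndFoldGreedily` are the existing MLIR entry points):

```cpp
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical driver: collect the cast-away patterns (which, with this
// PR, also handle masked transfers and vector.constant_mask) and apply
// them greedily to a function.
static LogicalResult castAwayLeadingUnitDims(func::FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
  // The populate call also registers shape_cast folding patterns, which
  // clean up the broadcasts these rewrites introduce.
  return applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
```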

@dcaballe (Contributor) left a comment

Thanks!

@antiagainst (Member) left a comment

Nice, thanks!

qedawkins merged commit 796d48b into llvm:main on Nov 7, 2023 (6 checks passed).
qedawkins deleted the masked_leading_unit_dim_transfer branch on November 7, 2023 at 01:40.