Conversation
…fers. This handles `vector.transfer_read`, `vector.transfer_write`, and `vector.constant_mask`. The unit dims are only relevant for masks created by `create_mask` and `constant_mask` if the mask size for the unit dim is non-one, in which case all subsequent mask sizes must also be zero. From the perspective of the vector transfers, however, these unit dims can just be dropped directly.
Member
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: Quinn Dawkins (qedawkins). Changes: This handles `vector.transfer_read`, `vector.transfer_write`, and `vector.constant_mask`. Full diff: https://github.com/llvm/llvm-project/pull/71466.diff — 2 files affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 6bbb293fa2a6b5c..75f32b23e57b0d6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -6,6 +6,8 @@
//
//===----------------------------------------------------------------------===//
+#include <numeric>
+
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -208,9 +210,6 @@ struct CastAwayTransferReadLeadingOneDim
if (read.getTransferRank() == 0)
return failure();
- if (read.getMask())
- return failure();
-
auto shapedType = cast<ShapedType>(read.getSource().getType());
if (shapedType.getElementType() != read.getVectorType().getElementType())
return failure();
@@ -233,10 +232,18 @@ struct CastAwayTransferReadLeadingOneDim
inBoundsAttr = rewriter.getArrayAttr(
read.getInBoundsAttr().getValue().take_back(newType.getRank()));
+ Value mask = Value();
+ if (read.getMask()) {
+ // The mask shape must always match the shape of the written vector, so we
+ // can safely use the same extraction indices.
+ int64_t dropDim = oldType.getRank() - newType.getRank();
+ mask = rewriter.create<vector::ExtractOp>(read.getLoc(), read.getMask(),
+ splatZero(dropDim));
+ }
+
auto newRead = rewriter.create<vector::TransferReadOp>(
read.getLoc(), newType, read.getSource(), read.getIndices(),
- AffineMapAttr::get(newMap), read.getPadding(), /*mask=*/Value(),
- inBoundsAttr);
+ AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr);
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
return success();
@@ -256,9 +263,6 @@ struct CastAwayTransferWriteLeadingOneDim
if (write.getTransferRank() == 0)
return failure();
- if (write.getMask())
- return failure();
-
auto shapedType = dyn_cast<ShapedType>(write.getSource().getType());
if (shapedType.getElementType() != write.getVectorType().getElementType())
return failure();
@@ -283,10 +287,21 @@ struct CastAwayTransferWriteLeadingOneDim
auto newVector = rewriter.create<vector::ExtractOp>(
write.getLoc(), write.getVector(), splatZero(dropDim));
+
+ if (write.getMask()) {
+ // The mask shape must always match the shape of the written vector, so we
+ // can safely use the same extraction indices.
+ auto newMask = rewriter.create<vector::ExtractOp>(
+ write.getLoc(), write.getMask(), splatZero(dropDim));
+ rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+ write, newVector, write.getSource(), write.getIndices(),
+ AffineMapAttr::get(newMap), newMask, inBoundsAttr);
+ return success();
+ }
+
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
write, newVector, write.getSource(), write.getIndices(),
AffineMapAttr::get(newMap), inBoundsAttr);
-
return success();
}
};
@@ -467,6 +482,40 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern {
}
};
+// Drops leading 1 dimensions from vector.constant_mask and inserts a
+// vector.broadcast back to the original shape.
+struct CastAwayConstantMaskLeadingOneDim
+ : public OpRewritePattern<vector::ConstantMaskOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ConstantMaskOp mask,
+ PatternRewriter &rewriter) const override {
+ VectorType oldType = mask.getType();
+ VectorType newType = trimLeadingOneDims(oldType);
+
+ if (newType == oldType)
+ return failure();
+
+ int64_t dropDim = oldType.getRank() - newType.getRank();
+ SmallVector<int64_t> dimSizes;
+ for (auto attr : mask.getMaskDimSizes())
+ dimSizes.push_back(llvm::cast<IntegerAttr>(attr).getInt());
+
+ // If any of the dropped unit dims has a size of `0`, the entire mask is a
+ // zero mask, else the unit dim has no effect on the mask.
+ int64_t flatLeadingSize =
+ std::accumulate(dimSizes.begin(), dimSizes.begin() + dropDim + 1,
+ static_cast<int64_t>(1), std::multiplies<int64_t>());
+ SmallVector<int64_t> newDimSizes({flatLeadingSize});
+ newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());
+
+ auto newMask = rewriter.create<vector::ConstantMaskOp>(
+ mask.getLoc(), newType, rewriter.getI64ArrayAttr(newDimSizes));
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask);
+ return success();
+ }
+};
+
} // namespace
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
@@ -474,7 +523,7 @@ void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
patterns
.add<CastAwayExtractStridedSliceLeadingOneDim,
CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
- CastAwayTransferReadLeadingOneDim,
+ CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim,
CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
populateShapeCastFoldingPatterns(patterns, benefit);
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index e5b27b04dcc8096..5de30206927db2f 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -209,6 +209,20 @@ func.func @cast_away_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>)
return %0: vector<1x4xf16>
}
+// CHECK-LABEL: func @cast_away_masked_transfer_read_leading_one_dims
+func.func @cast_away_masked_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xi1>) -> vector<1x4xf16> {
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
+ %f0 = arith.constant 0. : f16
+ // CHECK: %[[MASK_CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
+ // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]], %[[MASK_CAST]] {in_bounds = [true]} : memref<1x4x8x16xf16>, vector<4xf16>
+ // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x4xf16>
+ %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0, %arg1 {in_bounds = [true, true]} : memref<1x4x8x16xf16>, vector<1x4xf16>
+ // CHECK: return %[[CAST]]
+ return %0: vector<1x4xf16>
+}
+
// CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims_one_element
func.func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>) -> vector<1x1xf16> {
%c0 = arith.constant 0 : index
@@ -229,6 +243,18 @@ func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>
return
}
+// CHECK-LABEL: func @cast_away_masked_transfer_write_leading_one_dims
+func.func @cast_away_masked_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>, %arg2: vector<1x4xi1>) {
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xf16> from vector<1x4xf16>
+ // CHECK: %[[MASK_CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
+ // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[MASK_CAST]] {in_bounds = [true]} : vector<4xf16>, memref<1x4x8x16xf16>
+
+ vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0], %arg2 {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x4x8x16xf16>
+ return
+}
+
// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims_one_element
func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>, %arg1: vector<1x1xf16>) {
%c0 = arith.constant 0 : index
@@ -410,3 +436,12 @@ func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(%s: vector<1x
%0 = vector.insert %s, %v [0, 0, 7] : vector<1x[8]xi1> into vector<1x1x8x1x[8]xi1>
return %0: vector<1x1x8x1x[8]xi1>
}
+
+// CHECK-LABEL: func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> {
+// CHECK: %[[MASK:.*]] = vector.constant_mask [6, 1, 1] : vector<8x2x1xi1>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[MASK]] : vector<8x2x1xi1> to vector<1x1x8x2x1xi1>
+// CHECK: return %[[BCAST]] : vector<1x1x8x2x1xi1>
+func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> {
+ %0 = vector.constant_mask [1, 1, 6, 1, 1] : vector<1x1x8x2x1xi1>
+ return %0: vector<1x1x8x2x1xi1>
+}
|
This was referenced Nov 8, 2023
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit. This suggestion is invalid because no changes were made to the code. Suggestions cannot be applied while the pull request is closed. Suggestions cannot be applied while viewing a subset of changes. Only one suggestion per line can be applied in a batch. Add this suggestion to a batch that can be applied as a single commit. Applying suggestions on deleted lines is not supported. You must change the existing code in this line in order to create a valid suggestion. Outdated suggestions cannot be applied. This suggestion has been applied or marked resolved. Suggestions cannot be applied from pending reviews. Suggestions cannot be applied on multi-line comments. Suggestions cannot be applied while the pull request is queued to merge. Suggestion cannot be applied right now. Please check back later.
This handles `vector.transfer_read`, `vector.transfer_write`, and `vector.constant_mask`. The unit dims are only relevant for masks created by `create_mask` and `constant_mask` if the mask size for the unit dim is non-one, in which case all subsequent sizes must also be zero. From the perspective of the vector transfers, however, these unit dims can just be dropped directly.