[mlir][vector] Sink vector.extract/splat into load/store ops #134389
Conversation
@llvm/pr-subscribers-mlir-arith @llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector

Author: Ivan Butygin (Hardcode84)

Changes

Sink vector.extract into vector.load, and vector.splat/vector.broadcast into vector.store, rewriting them as plain memref.load/memref.store ops (or narrower vector load/store ops). The before/after IR is shown below.
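Both rewrites, as given in the pattern documentation and the squashed commit message:
```
vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
vector.extract %0[1] : f32 from vector<4xf32>
```
Gets converted to:
```
%c1 = arith.constant 1 : index
%0 = arith.addi %arg1, %c1 overflow<nsw> : index
%1 = memref.load %arg0[%0] : memref<?xf32>
```
and
```
%0 = vector.splat %arg2 : vector<1xf32>
vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
```
Gets converted to:
```
memref.store %arg2, %arg0[%arg1] : memref<?xf32>
```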
Patch is 20.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/134389.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index f46aa0428f12f..7fbb437908866 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -469,8 +469,28 @@ def ApplySinkVectorPatternsOp : Op<Transform_Dialect,
%0 = arith.addf %a, %b : vector<4x2xf32>
%r = vector.transpose %0, [1, 0] : vector<2x4xf32>
```
- At the moment, these patterns are limited to vector.broadcast and
- vector.transpose.
+ At the moment, these patterns are limited to vector.broadcast,
+ vector.transpose and vector.extract.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
+def ApplySinkVectorMemPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.vector.sink_mem_ops",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Patterns that remove redundant Vector Ops by merging them with load/store
+ ops
+ ```
+ vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+ vector.extract %0[1] : f32 from vector<4xf32>
+ ```
+ Gets converted to:
+ ```
+ %c1 = arith.constant 1 : index
+ %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+ %1 = memref.load %arg0[%0] : memref<?xf32>
}];
let assemblyFormat = "attr-dict";
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 7de4a6a315750..2d8b12c871be7 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -161,6 +161,20 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Patterns that remove redundant Vector Ops by merging them with load/store
+/// ops
+/// ```
+/// vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+/// vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %c1 = arith.constant 1 : index
+/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+/// %1 = memref.load %arg0[%0] : memref<?xf32>
+void populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
/// Patterns that fold chained vector reductions. These patterns assume that
/// elementwise operations (e.g., `arith.addf` with vector operands) are
/// cheaper than vector reduction.
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 12dcf768dd928..a888d745be443 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -212,6 +212,11 @@ void transform::ApplySinkVectorPatternsOp::populatePatterns(
vector::populateSinkVectorOpsPatterns(patterns);
}
+void transform::ApplySinkVectorMemPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::populateSinkVectorMemOpsPatterns(patterns);
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index b6fac80d871e6..697a4228b3a53 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1103,6 +1103,127 @@ class ExtractOpFromElementwise final
}
};
+/// Pattern to rewrite vector.extract(vector.load) -> vector/memref.load.
+/// ```
+/// vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+/// vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %c1 = arith.constant 1 : index
+/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+/// %1 = memref.load %arg0[%0] : memref<?xf32>
+/// ```
+class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ExtractOp op,
+ PatternRewriter &rewriter) const override {
+ auto loadOp = op.getVector().getDefiningOp<vector::LoadOp>();
+ if (!loadOp)
+ return rewriter.notifyMatchFailure(op, "not a load op");
+
+ if (!loadOp->hasOneUse())
+ return rewriter.notifyMatchFailure(op, "expected single op use");
+
+ VectorType memVecType = loadOp.getVectorType();
+ if (memVecType.isScalable())
+ return rewriter.notifyMatchFailure(op,
+ "scalable vectors are not supported");
+
+ MemRefType memType = loadOp.getMemRefType();
+ if (isa<VectorType>(memType.getElementType()))
+ return rewriter.notifyMatchFailure(
+ op, "memrefs of vectors are not supported");
+
+ int64_t rankOffset = memType.getRank() - memVecType.getRank();
+ if (rankOffset < 0)
+ return rewriter.notifyMatchFailure(op, "unsupported ranks combination");
+
+ auto resVecType = dyn_cast<VectorType>(op.getResult().getType());
+ int64_t finalRank = 0;
+ if (resVecType)
+ finalRank = resVecType.getRank();
+
+ SmallVector<Value> indices = loadOp.getIndices();
+ SmallVector<OpFoldResult> extractPos = op.getMixedPosition();
+
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(loadOp);
+ Location loc = loadOp.getLoc();
+ for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
+ OpFoldResult pos = extractPos[i - rankOffset];
+ if (isConstantIntValue(pos, 0))
+ continue;
+
+ Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
+
+ auto ovf = arith::IntegerOverflowFlags::nsw;
+ indices[i] = rewriter.create<arith::AddIOp>(loc, indices[i], offset, ovf);
+ }
+
+ Value base = loadOp.getBase();
+ if (resVecType) {
+ rewriter.replaceOpWithNewOp<vector::LoadOp>(op, resVecType, base,
+ indices);
+ } else {
+ rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
+ }
+ rewriter.eraseOp(loadOp);
+ return success();
+ }
+};
+
+/// Pattern to rewrite vector.store(vector.splat) -> vector/memref.store.
+/// ```
+/// %0 = vector.splat %arg2 : vector<1xf32>
+/// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
+/// ```
+class StoreFromSplat final : public OpRewritePattern<vector::StoreOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::StoreOp op,
+ PatternRewriter &rewriter) const override {
+ VectorType vecType = op.getVectorType();
+ if (vecType.isScalable())
+ return rewriter.notifyMatchFailure(op,
+ "scalable vectors are not supported");
+
+ if (isa<VectorType>(op.getMemRefType().getElementType()))
+ return rewriter.notifyMatchFailure(
+ op, "memrefs of vectors are not supported");
+
+ if (vecType.getNumElements() != 1)
+ return rewriter.notifyMatchFailure(
+ op, "only 1-element, vectors are supported");
+
+ Operation *splat = op.getValueToStore().getDefiningOp();
+ if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
+ return rewriter.notifyMatchFailure(op, "not a splat");
+
+ if (!splat->hasOneUse())
+ return rewriter.notifyMatchFailure(op, "expected single op use");
+
+ Value source = splat->getOperand(0);
+ Value base = op.getBase();
+ ValueRange indices = op.getIndices();
+
+ if (isa<VectorType>(source.getType())) {
+ rewriter.replaceOpWithNewOp<vector::StoreOp>(op, source, base, indices);
+ } else {
+ rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
+ }
+ rewriter.eraseOp(splat);
+ return success();
+ }
+};
+
// Helper that returns a vector comparison that constructs a mask:
// mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//
@@ -2175,6 +2296,12 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
patterns.getContext(), benefit);
}
+void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit) {
+ patterns.add<ExtractOpFromLoad, StoreFromSplat>(patterns.getContext(),
+ benefit);
+}
+
void mlir::vector::populateChainedVectorReductionFoldingPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<ChainedReduction>(patterns.getContext(), benefit);
diff --git a/mlir/test/Dialect/Vector/vector-sink-transform.mlir b/mlir/test/Dialect/Vector/vector-sink-transform.mlir
index ef17b69b2444c..4d04276742164 100644
--- a/mlir/test/Dialect/Vector/vector-sink-transform.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink-transform.mlir
@@ -7,6 +7,7 @@ module attributes {transform.with_named_sequence} {
%func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.vector.sink_ops
+ transform.apply_patterns.vector.sink_mem_ops
} : !transform.any_op
transform.yield
}
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index 8c8f1797aaab6..ad4fdbe0a7b5a 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -513,3 +513,192 @@ func.func @negative_extract_vec_fma(%arg0: vector<4xf32>, %arg1: vector<4xf32>,
%1 = vector.extract %0[1] : f32 from vector<4xf32>
return %1 : f32
}
+
+//-----------------------------------------------------------------------------
+// [Pattern: ExtractOpFromLoad]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL: @extract_load_scalar
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @extract_load_scalar(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>
+// CHECK: return %[[RES]] : f32
+ %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+ %1 = vector.extract %0[0] : f32 from vector<4xf32>
+ return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_scalar_non_zero_off
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @extract_load_scalar_non_zero_off(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
+// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
+// CHECK: return %[[RES]] : f32
+ %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+ %1 = vector.extract %0[1] : f32 from vector<4xf32>
+ return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_scalar_dyn_off
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_scalar_dyn_off(%arg0: memref<?xf32>, %arg1: index, %arg2: index) -> f32 {
+// CHECK: %[[OFF:.*]] = arith.addi %[[ARG1]], %[[ARG2]] overflow<nsw> : index
+// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
+// CHECK: return %[[RES]] : f32
+ %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+ %1 = vector.extract %0[%arg2] : f32 from vector<4xf32>
+ return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_vec
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_vec(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> vector<4xf32> {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
+// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[OFF]], %[[ARG2]]] : memref<?x?xf32>, vector<4xf32>
+// CHECK: return %[[RES]] : vector<4xf32>
+ %0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<2x4xf32>
+ %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+ return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @extract_load_scalar_high_rank
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_scalar_high_rank(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> f32 {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
+// CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[OFF]]] : memref<?x?xf32>
+// CHECK: return %[[RES]] : f32
+ %0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<4xf32>
+ %1 = vector.extract %0[1] : f32 from vector<4xf32>
+ return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_vec_high_rank
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+func.func @extract_load_vec_high_rank(%arg0: memref<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> vector<4xf32> {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
+// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[OFF]], %[[ARG3]]] : memref<?x?x?xf32>, vector<4xf32>
+// CHECK: return %[[RES]] : vector<4xf32>
+ %0 = vector.load %arg0[%arg1, %arg2, %arg3] : memref<?x?x?xf32>, vector<2x4xf32>
+ %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+ return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @negative_load_scalar_from_vec_memref
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xvector<4xf32>>, %[[ARG1:.*]]: index)
+func.func @negative_load_scalar_from_vec_memref(%arg0: memref<?xvector<4xf32>>, %arg1: index) -> f32 {
+// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xvector<4xf32>>, vector<4xf32>
+// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
+// CHECK: return %[[EXT]] : f32
+ %0 = vector.load %arg0[%arg1] : memref<?xvector<4xf32>>, vector<4xf32>
+ %1 = vector.extract %0[0] : f32 from vector<4xf32>
+ return %1 : f32
+}
+
+// CHECK-LABEL: @negative_extract_load_no_single_use
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_no_single_use(%arg0: memref<?xf32>, %arg1: index) -> (f32, vector<4xf32>) {
+// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
+// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
+// CHECK: return %[[EXT]], %[[RES]] : f32, vector<4xf32>
+ %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+ %1 = vector.extract %0[0] : f32 from vector<4xf32>
+ return %1, %0 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @negative_load_scalable
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @negative_load_scalable(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
+// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<[1]xf32>
+// CHECK: return %[[EXT]] : f32
+ %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<[1]xf32>
+ %1 = vector.extract %0[0] : f32 from vector<[1]xf32>
+ return %1 : f32
+}
+
+// CHECK-LABEL: @negative_extract_load_unsupported_ranks
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_unsupported_ranks(%arg0: memref<?xf32>, %arg1: index) -> vector<4xf32> {
+// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<2x4xf32>
+// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK: return %[[EXT]] : vector<4xf32>
+ %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<2x4xf32>
+ %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+ return %1 : vector<4xf32>
+}
+
+//-----------------------------------------------------------------------------
+// [Pattern: StoreFromSplat]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL: @store_splat
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @store_splat(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+// CHECK: memref.store %[[ARG2]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>
+ %0 = vector.splat %arg2 : vector<1xf32>
+ vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+ return
+}
+
+// CHECK-LABEL: @store_broadcast
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @store_broadcast(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+// CHECK: memref.store %[[ARG2]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>
+ %0 = vector.broadcast %arg2 : f32 to vector<1xf32>
+ vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+ return
+}
+
+// CHECK-LABEL: @store_broadcast_1d_2d
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: vector<1xf32>)
+func.func @store_broadcast_1d_2d(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index, %arg3: vector<1xf32>) {
+// CHECK: vector.store %[[ARG3]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<?x?xf32>, vector<1xf32>
+ %0 = vector.broadcast %arg3 : vector<1xf32> to vector<1x1xf32>
+ vector.store %0, %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<1x1xf32>
+ return
+}
+
+// CHECK-LABEL: @negative_store_scalable
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_scalable(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+// CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<[1]xf32>
+// CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
+ %0 = vector.splat %arg2 : vector<[1]xf32>
+ vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<[1]xf32>
+ return
+}
+
+// CHECK-LABEL: @negative_store_vec_memref
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xvector<1xf32>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_vec_memref(%arg0: memref<?xvector<1xf32>>, %arg1: index, %arg2: f32) {
+// CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<1xf32>
+// CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xvector<1xf32>>, vector<1xf32>
+ %0 = vector.splat %arg2 : vector<1xf32>
+ vector.store %0, %arg0[%arg1] : memref<?xvector<1xf32>>, vector<1xf32>
+ return
+}
+
+// CHECK-LABEL: @negative_store_non_1
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_non_1(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+// CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<4xf32>
+// CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
+ %0 = vector.splat %arg2 : vector<4xf32>
+ vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+ return
+}
+
+// CHECK-LABEL: @negative_store_no_single_use
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_no_single_use(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) -> vector<1xf32> {
+// CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<1xf32>
+// CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<1xf32>
+// CHECK: return %[[RES:.*]] : vector<1xf32>
+ %0 = vector.splat %arg2 : vector<1xf32>
+ vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+ return %0 : vector<1xf32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index a54ae816570a8..03f907e46c2c6 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -395,6 +395,7 @@ struct TestVectorSinkPatterns
void runOnOperation() override {
RewritePatternSet pa...
[truncated]
Nice, thank you for contributing this!
I was a bit surprised that vector.load and vector.store only accept MemRef(s). @dcaballe, since you added vector.store, do you know the context? That would help us position ApplySinkVectorMemPatternsOp accordingly (e.g., do we need a dedicated TD Op?).
More comments inline.
Nice transformation! As mentioned inline, I think we should move these patterns to the canonicalizer. Other than that, LGTM!
/// ```
/// %c1 = arith.constant 1 : index
/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
/// %1 = memref.load %arg0[%0] : memref<?xf32>
I think this should be a canonicalization pattern iff there's only one use which is a vector.extract. I can't think of a reason why we would want to load the redundant elements. I would clearly document that this only applies to cases with one use/extract op.
I think this should be a canonicalization pattern iff there's only one use which is a vector.extract.
No objections from me, but from a purely maintenance point of view, I'd leave the implementation and most of the tests where they are. Otherwise, we risk "bloating" canonicalization.mlir and e.g. VectorOps.cpp.
One potential use case where keeping vector.load + extract may be useful is when we load a vector at an aligned address for perf reasons and then use extract with an offset to get unaligned data. I don't have such examples in practice, though.
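A hypothetical sketch of that use case (the memref shape, names, and alignment assumption are made up for illustration): the vector load is issued at a vector-aligned base index, and the extract then picks an element that is not at a vector-aligned offset. The pattern would rewrite this into a single memref.load at %base + 3, losing the aligned vector-shaped access.
```
func.func @aligned_load_unaligned_extract(%src: memref<1024xf32>, %base: index) -> f32 {
  // Assume %base is vector-aligned; the wide load is cheap, while the element
  // of interest lives 3 elements past that aligned base.
  %v = vector.load %src[%base] : memref<1024xf32>, vector<4xf32>
  %e = vector.extract %v[3] : f32 from vector<4xf32>
  return %e : f32
}
```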
I can't think of a reason why we would want to load the redundant elements
Increasing the granularity of memory accesses may prevent you from using wider load/store instructions, and undoing this later on and proving that you can use a wider memory access may be hard. We'd be losing information about how many bits are dereferenceable and potentially misaligning the access.
This may also change the semantics of load instructions that support OOB behavior -- you can turn an OOB access into an in-bounds access.
For this reason, I don't think this should be on by default.
Also vector.extract ... [5] : vector<8xi1>. Applying the pattern in this case means losing byte alignment, which also doesn't seem like a good fit for a canonicalization to me.
This pattern is constrained to using/extracting only one element so we wouldn't be dropping access pattern information for that case, right? Do you have something else in mind?
Say one of your memory regions is dword-sized but your memory accesses take byte offsets:
%x = vector.load ... : vector<4xi8>
%y = vector.extract %x [2]: i8
The original load is efficient because you are accessing a full dword. However, if you turn it into memref.load ... : i8, you may no longer know, once the index calculation simplifies with something else, that to get an aligned dword load you need to also load the preceding bytes vs. only the bytes following this i8 (unaligned). You could resolve that with some masking + shifting, but that comes with some overhead.
Could you help me understand this? We should be able to remove any load operation that is dead, regardless of whether it's in-bounds or OOB, right? What makes this case different?
For example, the buffer instruction on amdgpu allows you to get a default value for any OOB accesses. Looking at the example above, it could be that only the last byte is OOB, but this alone makes the whole vector<4xi8> have the default value. If you no longer load that last byte, the access would be in-bounds and you would observe a different value.
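A hypothetical illustration of that scenario (buffer size and indices are made up): the original 4-element load can run past the end of the buffer, while the rewritten single-element access stays in-bounds, so a target that assigns a default value to the whole vector on OOB could observe a different result.
```
func.func @oob_becomes_inbounds(%buf: memref<7xi8>, %i: index) -> (i8, i8) {
  // Before: with %i == 4 the load touches %buf[4..7], but %buf has only 7
  // elements, so the last lane is out of bounds.
  %x = vector.load %buf[%i] : memref<7xi8>, vector<4xi8>
  %y = vector.extract %x[2] : i8 from vector<4xi8>

  // After the rewrite only %buf[%i + 2] (= %buf[6]) is accessed: in-bounds.
  %c2 = arith.constant 2 : index
  %off = arith.addi %i, %c2 overflow<nsw> : index
  %y2 = memref.load %buf[%off] : memref<7xi8>
  return %y, %y2 : i8, i8
}
```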
I don't think we actually need any special handling or tests for sub-byte types. The only ways we can have load ... vector<8xi1> are either loading from memref<...xi1>, for which the semantics are fully consistent, or loading from memref<...xvector<8xi1>>, which is ignored by the current pattern.
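A sketch of the two forms mentioned above (memref shapes are illustrative): the first loads vector<8xi1> from an i1 memref, the second from a memref of vectors, which the pattern already skips via its element-type check.
```
func.func @i1_load_forms(%m0: memref<128xi1>, %m1: memref<16xvector<8xi1>>, %i: index) -> (i1, i1) {
  // i1 element memref: semantics are consistent; the pattern would match here
  // (modulo the later restriction to byte-aligned element types).
  %a = vector.load %m0[%i] : memref<128xi1>, vector<8xi1>
  %e0 = vector.extract %a[5] : i1 from vector<8xi1>

  // Memref of vectors: hits the "memrefs of vectors are not supported" bail-out.
  %b = vector.load %m1[%i] : memref<16xvector<8xi1>>, vector<8xi1>
  %e1 = vector.extract %b[5] : i1 from vector<8xi1>
  return %e0, %e1 : i1, i1
}
```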
Applying this pattern to a vector of bits would lead to memref.load %src[%idx] : memref<8xi1>
, i.e. a load of a single bit. That doesn't feel sane.
Also, in cases like this:
%x = vector.load ... : vector<8xi1>
%y = vector.extract %x [5]: i1
vector load is probably just a scalar load anyway.
My suggestion is to restrict this pattern to multi-byte element types (*) and rely on "narrow-type-emulation" to help with sub-bytes.
(*) Multi-byte - at least one byte.
Thanks @kuhar, those examples were helpful! I'm still kind of borderline but let’s move forward with this as an independent pattern. The proliferation of dangling “populate” methods is concerning but this case may be worth it.
The original load is efficient because you are accessing a full dword. However, if you turn it into memref.load ... : i8, you may no longer know,
For that example, I would expect the alignment information to be explicit somewhere as vector.load doesn't have any default alignment. In the presence of no alignment information, I'm still not sure this transformation is dropping information.
For example, the buffer instruction on amdgpu allow you to get a default value for any OOB accesses. Looking at the example above, it could be that only the last byte is OOB, but this alone makes the whole vector<4xi8> have the default value. If you no longer load that last byte, the access would be in-bounds and you would observe a different value.
Yes but we can't attribute hardware-specific semantics to vector.load. We allow OOB reads to accommodate those targets that can "handle" OOB accesses. However, we can't make assumptions on what the target will do or the actual values of those OOB elements. Doc may need some refinement but we defined it along those lines:
Representation-wise, the ‘vector.load’ operation permits out-of-bounds reads. Support and implementation of out-of-bounds vector loads is target-specific. No assumptions should be made on the value of elements loaded out of bounds. Not all targets may support out-of-bounds vector loads.
A valid lowering of vector.load could be a scalarized version of it that checks element by element whether it's OOB and only loads the in-bounds elements, so the OOB accesses might not happen. I'd even say that OOB accesses are not observable as using the OOB elements should be poison, right? I think the behavior you are describing would better fit a masked vector load where the masked-off elements (OOB) are replaced with a padding value.
I don't think we actually need any special handling or tests for sub-byte types. The only ways we can have load ... vector<8xi1> are either loading from memref<...xi1>, for which the semantics are fully consistent, or loading from memref<...xvector<8xi1>>, which is ignored by the current pattern.
I'd be surprised if there is no issue with the data layout, as the vector one assumes a packed layout and the scalar one would be unpacked. Looking at the generated LLVM IR for both cases would help.
if (!loadOp)
  return rewriter.notifyMatchFailure(op, "not a load op");

if (!loadOp->hasOneUse())
If moving this to canonicalization, I would add a comment here stating that this condition is the one that makes this a canonicalization pattern and shouldn't be changed.
Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);

auto ovf = arith::IntegerOverflowFlags::nsw;
I guess if the vector load is loading A[+0, +1, +2, +3], it's safe to say that the address of any independent element won't overflow... so nsw sounds ok to me?
Regarding making it a canonicalization, this is a somewhat controversial topic and there are some folks who would probably disagree. I would prefer to merge it as separate patterns first and then have a dedicated PR and discussion on promoting them to canonicalizations. Also, this way we will be able to revert them independently, if needed.
(I will update the code to address the other comments after I'm back from EuroLLVM)
LGTM
@Hardcode84 maybe add a test case that extracts a single bit from
@banach-space @dcaballe updated the code, please take a look
Thanks for the updates - this looks great!
I've left a few final nits and a suggestion to avoid allowing this for vectors of bits. I hope that makes sense and is sufficient for your use case.
if (rankOffset < 0)
  return rewriter.notifyMatchFailure(op, "unsupported ranks combination");

auto resVecType = dyn_cast<VectorType>(op.getResult().getType());
OK, that wasn't obvious to me. Perhaps resVecType -> extractVecType? This would make it contrast with loadVecType quite nicely.
Side comment from the offline discussion:
@banach-space @dcaballe I disabled the pattern for non-byte-aligned types for now
```
vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
vector.extract %0[1] : f32 from vector<4xf32>
```
Gets converted to:
```
%c1 = arith.constant 1 : index
%0 = arith.addi %arg1, %c1 overflow<nsw> : index
%1 = memref.load %arg0[%0] : memref<?xf32>
```
```
%0 = vector.splat %arg2 : vector<1xf32>
vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
```
Gets converted to:
```
memref.store %arg2, %arg0[%arg1] : memref<?xf32>
```
This looks good to me, thanks for addressing all my comments! I've left a few minor suggestions that I’d appreciate you addressing before landing, but nothing major.
AFAIK, Diego is currently traveling, but from what I can tell, he’s fine with not treating these patterns as canonicalizations for now:
I'm still kind of borderline, but let’s move forward with this as an independent pattern.
With that in mind, would you mind adding a TODO to revisit the idea of making these canonicalizations in the future? Just to reflect that some of us (myself included) would be in favor — we’re just not blocking this for now.
Once these are addressed, this should be good to land (provided there are no new comments).