[mlir][vector] Sink vector.extract/splat into load/store ops #134389

Merged
merged 9 commits into llvm:main from sink-vec-mem-ops on Apr 22, 2025

Conversation

Hardcode84
Contributor

```
vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
vector.extract %0[1] : f32 from vector<4xf32>
```

Gets converted to:

```
%c1 = arith.constant 1 : index
%0 = arith.addi %arg1, %c1 overflow<nsw> : index
%1 = memref.load %arg0[%0] : memref<?xf32>
```

```
%0 = vector.splat %arg2 : vector<1xf32>
vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
```

Gets converted to:

```
memref.store %arg2, %arg0[%arg1] : memref<?xf32>
```
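
For context, these rewrites are exposed both through a C++ populate function (populateSinkVectorMemOpsPatterns) and through a transform dialect op. A minimal transform script that enables them, mirroring the vector-sink-transform.mlir test updated in this patch and assuming the usual @__transform_main entry point, would look roughly like this:

```
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %func {
      // Existing sinking patterns plus the new load/store merging patterns.
      transform.apply_patterns.vector.sink_ops
      transform.apply_patterns.vector.sink_mem_ops
    } : !transform.any_op
    transform.yield
  }
}
```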

@llvmbot
Member

llvmbot commented Apr 4, 2025

@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

Changes
```
vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
vector.extract %0[1] : f32 from vector<4xf32>
```

Gets converted to:

```
%c1 = arith.constant 1 : index
%0 = arith.addi %arg1, %c1 overflow<nsw> : index
%1 = memref.load %arg0[%0] : memref<?xf32>
```

```
%0 = vector.splat %arg2 : vector<1xf32>
vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
```

Gets converted to:

```
memref.store %arg2, %arg0[%arg1] : memref<?xf32>
```

Patch is 20.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/134389.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td (+22-2)
  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+14)
  • (modified) mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp (+5)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+127)
  • (modified) mlir/test/Dialect/Vector/vector-sink-transform.mlir (+1)
  • (modified) mlir/test/Dialect/Vector/vector-sink.mlir (+189)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+1)
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index f46aa0428f12f..7fbb437908866 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -469,8 +469,28 @@ def ApplySinkVectorPatternsOp : Op<Transform_Dialect,
     %0 = arith.addf %a, %b : vector<4x2xf32>
     %r = vector.transpose %0, [1, 0] : vector<2x4xf32>
     ```
-    At the moment, these patterns are limited to vector.broadcast and
-    vector.transpose.
+    At the moment, these patterns are limited to vector.broadcast,
+    vector.transpose and vector.extract.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+def ApplySinkVectorMemPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.sink_mem_ops",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Patterns that remove redundant Vector Ops by merging them with load/store
+    ops
+    ```
+    vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+    vector.extract %0[1] : f32 from vector<4xf32>
+    ```
+    Gets converted to:
+    ```
+    %c1 = arith.constant 1 : index
+    %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+    %1 = memref.load %arg0[%0] : memref<?xf32>
   }];
 
   let assemblyFormat = "attr-dict";
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 7de4a6a315750..2d8b12c871be7 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -161,6 +161,20 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
 void populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                                    PatternBenefit benefit = 1);
 
+/// Patterns that remove redundant Vector Ops by merging them with load/store
+/// ops
+/// ```
+/// vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+/// vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %c1 = arith.constant 1 : index
+/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+/// %1 = memref.load %arg0[%0] : memref<?xf32>
+void populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
+                                      PatternBenefit benefit = 1);
+
 /// Patterns that fold chained vector reductions. These patterns assume that
 /// elementwise operations (e.g., `arith.addf` with vector operands) are
 /// cheaper than vector reduction.
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 12dcf768dd928..a888d745be443 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -212,6 +212,11 @@ void transform::ApplySinkVectorPatternsOp::populatePatterns(
   vector::populateSinkVectorOpsPatterns(patterns);
 }
 
+void transform::ApplySinkVectorMemPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::populateSinkVectorMemOpsPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index b6fac80d871e6..697a4228b3a53 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1103,6 +1103,127 @@ class ExtractOpFromElementwise final
   }
 };
 
+/// Pattern to rewrite vector.extract(vector.load) -> vector/memref.load.
+/// ```
+///  vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+///  vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %c1 = arith.constant 1 : index
+/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+/// %1 = memref.load %arg0[%0] : memref<?xf32>
+/// ```
+class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loadOp = op.getVector().getDefiningOp<vector::LoadOp>();
+    if (!loadOp)
+      return rewriter.notifyMatchFailure(op, "not a load op");
+
+    if (!loadOp->hasOneUse())
+      return rewriter.notifyMatchFailure(op, "expected single op use");
+
+    VectorType memVecType = loadOp.getVectorType();
+    if (memVecType.isScalable())
+      return rewriter.notifyMatchFailure(op,
+                                         "scalable vectors are not supported");
+
+    MemRefType memType = loadOp.getMemRefType();
+    if (isa<VectorType>(memType.getElementType()))
+      return rewriter.notifyMatchFailure(
+          op, "memrefs of vectors are not supported");
+
+    int64_t rankOffset = memType.getRank() - memVecType.getRank();
+    if (rankOffset < 0)
+      return rewriter.notifyMatchFailure(op, "unsupported ranks combination");
+
+    auto resVecType = dyn_cast<VectorType>(op.getResult().getType());
+    int64_t finalRank = 0;
+    if (resVecType)
+      finalRank = resVecType.getRank();
+
+    SmallVector<Value> indices = loadOp.getIndices();
+    SmallVector<OpFoldResult> extractPos = op.getMixedPosition();
+
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(loadOp);
+    Location loc = loadOp.getLoc();
+    for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
+      OpFoldResult pos = extractPos[i - rankOffset];
+      if (isConstantIntValue(pos, 0))
+        continue;
+
+      Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
+
+      auto ovf = arith::IntegerOverflowFlags::nsw;
+      indices[i] = rewriter.create<arith::AddIOp>(loc, indices[i], offset, ovf);
+    }
+
+    Value base = loadOp.getBase();
+    if (resVecType) {
+      rewriter.replaceOpWithNewOp<vector::LoadOp>(op, resVecType, base,
+                                                  indices);
+    } else {
+      rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
+    }
+    rewriter.eraseOp(loadOp);
+    return success();
+  }
+};
+
+/// Pattern to rewrite vector.store(vector.splat) -> vector/memref.store.
+/// ```
+/// %0 = vector.splat %arg2 : vector<1xf32>
+/// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
+/// ```
+class StoreFromSplat final : public OpRewritePattern<vector::StoreOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::StoreOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType vecType = op.getVectorType();
+    if (vecType.isScalable())
+      return rewriter.notifyMatchFailure(op,
+                                         "scalable vectors are not supported");
+
+    if (isa<VectorType>(op.getMemRefType().getElementType()))
+      return rewriter.notifyMatchFailure(
+          op, "memrefs of vectors are not supported");
+
+    if (vecType.getNumElements() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "only 1-element, vectors are supported");
+
+    Operation *splat = op.getValueToStore().getDefiningOp();
+    if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
+      return rewriter.notifyMatchFailure(op, "not a splat");
+
+    if (!splat->hasOneUse())
+      return rewriter.notifyMatchFailure(op, "expected single op use");
+
+    Value source = splat->getOperand(0);
+    Value base = op.getBase();
+    ValueRange indices = op.getIndices();
+
+    if (isa<VectorType>(source.getType())) {
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(op, source, base, indices);
+    } else {
+      rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
+    }
+    rewriter.eraseOp(splat);
+    return success();
+  }
+};
+
 // Helper that returns a vector comparison that constructs a mask:
 //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
 //
@@ -2175,6 +2296,12 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
       patterns.getContext(), benefit);
 }
 
+void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
+                                                    PatternBenefit benefit) {
+  patterns.add<ExtractOpFromLoad, StoreFromSplat>(patterns.getContext(),
+                                                  benefit);
+}
+
 void mlir::vector::populateChainedVectorReductionFoldingPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<ChainedReduction>(patterns.getContext(), benefit);
diff --git a/mlir/test/Dialect/Vector/vector-sink-transform.mlir b/mlir/test/Dialect/Vector/vector-sink-transform.mlir
index ef17b69b2444c..4d04276742164 100644
--- a/mlir/test/Dialect/Vector/vector-sink-transform.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink-transform.mlir
@@ -7,6 +7,7 @@ module attributes {transform.with_named_sequence} {
     %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
       transform.apply_patterns.vector.sink_ops
+      transform.apply_patterns.vector.sink_mem_ops
     } : !transform.any_op
     transform.yield
   }
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index 8c8f1797aaab6..ad4fdbe0a7b5a 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -513,3 +513,192 @@ func.func @negative_extract_vec_fma(%arg0: vector<4xf32>, %arg1: vector<4xf32>,
   %1 = vector.extract %0[1] : f32 from vector<4xf32>
   return %1 : f32
 }
+
+//-----------------------------------------------------------------------------
+// [Pattern: ExtractOpFromLoad]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL: @extract_load_scalar
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @extract_load_scalar(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>
+// CHECK:   return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[0] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_scalar_non_zero_off
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @extract_load_scalar_non_zero_off(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
+// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
+// CHECK:   return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_scalar_dyn_off
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_scalar_dyn_off(%arg0: memref<?xf32>, %arg1: index, %arg2: index) -> f32 {
+// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG1]], %[[ARG2]] overflow<nsw> : index
+// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
+// CHECK:   return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[%arg2] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_vec
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_vec(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> vector<4xf32> {
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[OFF]], %[[ARG2]]] : memref<?x?xf32>, vector<4xf32>
+// CHECK:   return %[[RES]] : vector<4xf32>
+  %0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @extract_load_scalar_high_rank
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_scalar_high_rank(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> f32 {
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
+// CHECK:   %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[OFF]]] : memref<?x?xf32>
+// CHECK:   return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_vec_high_rank
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+func.func @extract_load_vec_high_rank(%arg0: memref<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> vector<4xf32> {
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[OFF]], %[[ARG3]]] : memref<?x?x?xf32>, vector<4xf32>
+// CHECK:   return %[[RES]] : vector<4xf32>
+  %0 = vector.load %arg0[%arg1, %arg2, %arg3] : memref<?x?x?xf32>, vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @negative_load_scalar_from_vec_memref
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xvector<4xf32>>, %[[ARG1:.*]]: index)
+func.func @negative_load_scalar_from_vec_memref(%arg0: memref<?xvector<4xf32>>, %arg1: index) -> f32 {
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xvector<4xf32>>, vector<4xf32>
+// CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
+// CHECK:   return %[[EXT]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xvector<4xf32>>, vector<4xf32>
+  %1 = vector.extract %0[0] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @negative_extract_load_no_single_use
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_no_single_use(%arg0: memref<?xf32>, %arg1: index) -> (f32, vector<4xf32>) {
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
+// CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
+// CHECK:   return %[[EXT]], %[[RES]] : f32, vector<4xf32>
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[0] : f32 from vector<4xf32>
+  return %1, %0 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @negative_load_scalable
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @negative_load_scalable(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
+// CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<[1]xf32>
+// CHECK:   return %[[EXT]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<[1]xf32>
+  %1 = vector.extract %0[0] : f32 from vector<[1]xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @negative_extract_load_unsupported_ranks
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_unsupported_ranks(%arg0: memref<?xf32>, %arg1: index) -> vector<4xf32> {
+// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<2x4xf32>
+// CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK:   return %[[EXT]] : vector<4xf32>
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+//-----------------------------------------------------------------------------
+// [Pattern: StoreFromSplat]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL: @store_splat
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @store_splat(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+// CHECK:   memref.store %[[ARG2]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>
+  %0 = vector.splat %arg2 : vector<1xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+  return
+}
+
+// CHECK-LABEL: @store_broadcast
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @store_broadcast(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+// CHECK:   memref.store %[[ARG2]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>
+  %0 = vector.broadcast %arg2 : f32 to vector<1xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+  return
+}
+
+// CHECK-LABEL: @store_broadcast_1d_2d
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: vector<1xf32>)
+func.func @store_broadcast_1d_2d(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index, %arg3: vector<1xf32>) {
+// CHECK:   vector.store %[[ARG3]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<?x?xf32>, vector<1xf32>
+  %0 = vector.broadcast %arg3 : vector<1xf32> to vector<1x1xf32>
+  vector.store %0, %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<1x1xf32>
+  return
+}
+
+// CHECK-LABEL: @negative_store_scalable
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_scalable(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+// CHECK:   %[[RES:.*]] = vector.splat %[[ARG2]] : vector<[1]xf32>
+// CHECK:   vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
+  %0 = vector.splat %arg2 : vector<[1]xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<[1]xf32>
+  return
+}
+
+// CHECK-LABEL: @negative_store_vec_memref
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xvector<1xf32>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_vec_memref(%arg0: memref<?xvector<1xf32>>, %arg1: index, %arg2: f32) {
+// CHECK:   %[[RES:.*]] = vector.splat %[[ARG2]] : vector<1xf32>
+// CHECK:   vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xvector<1xf32>>, vector<1xf32>
+  %0 = vector.splat %arg2 : vector<1xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xvector<1xf32>>, vector<1xf32>
+  return
+}
+
+// CHECK-LABEL: @negative_store_non_1
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_non_1(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+// CHECK:   %[[RES:.*]] = vector.splat %[[ARG2]] : vector<4xf32>
+// CHECK:   vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
+  %0 = vector.splat %arg2 : vector<4xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  return
+}
+
+// CHECK-LABEL: @negative_store_no_single_use
+//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_no_single_use(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) -> vector<1xf32> {
+// CHECK:   %[[RES:.*]] = vector.splat %[[ARG2]] : vector<1xf32>
+// CHECK:   vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<1xf32>
+// CHECK:   return %[[RES:.*]] : vector<1xf32>
+  %0 = vector.splat %arg2 : vector<1xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+  return %0 : vector<1xf32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index a54ae816570a8..03f907e46c2c6 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -395,6 +395,7 @@ struct TestVectorSinkPatterns
   void runOnOperation() override {
     RewritePatternSet pa...
[truncated]

@llvmbot
Member

llvmbot commented Apr 4, 2025

@llvm/pr-subscribers-mlir-vector


Contributor

@banach-space banach-space left a comment

Nice, thank you for contributing this!

I was a bit surprised that vector.load and vector.store only accept MemRef(s). @dcaballe , since you added vector.store, do you know the context? That would help us position ApplySinkVectorMemPatternsOp accordingly (e.g., do we need a dedicated TD Op?).

More comments inline.

Contributor

@dcaballe dcaballe left a comment

Nice transformation! As mentioned inline, I think we should move these patterns to the canonicalizer. Other than that, LGTM!

/// ```
/// %c1 = arith.constant 1 : index
/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
/// %1 = memref.load %arg0[%0] : memref<?xf32>
Contributor

I think this should be a canonicalization pattern iff there's only one use which is a vector.extract. I can't think of a reason why we would want to load the redundant elements. I would clearly document that this only applies to cases with one use/extract op.

Contributor

I think this should be a canonicalization pattern iff there's only one use which is a vector.extract.

No objections from me, but from a purely maintenance point of view, I'd leave the implementation and most of the tests where they are. Otherwise, we risk "bloating" canonicalization.mlir and e.g. VectorOps.cpp.

Contributor Author

One potential use case where keeping vector.load + extract may be useful is when we are loading a vector at an aligned address for perf reasons and then using extract with an offset to get unaligned data. I don't have such examples in practice, though.

Member

@kuhar kuhar Apr 11, 2025

I can't think of a reason why we would want to load the redundant elements

Increasing the granularity of memory accesses may cause you not to be able to use wider load/store instructions, and undoing this later on and proving that you can use a wider memory access may be hard. We'd be losing information about how many bits are dereferenceable and potentially misaligning the access.
This may also change the semantics of load instructions that support OOB behavior -- you can turn an OOB access into an in-bounds access.

For this reason, I don't think this should be on by default.

Contributor

Also vector.extract ... [5] : vector<8xi1>. Applying the pattern in this case means losing byte alignment, which also doesn't seem like a good fit for a canonicalization to me.

Member

@kuhar kuhar Apr 11, 2025

This pattern is constrained to using/extracting only one element so we wouldn't be dropping access pattern information for that case, right? Do you have something else in mind?

Say one of your memory regions is dword-sized but your memory accesses take byte offsets:

```
%x = vector.load ... : vector<4xi8>
%y = vector.extract %x [2] : i8
```

The original load is efficient because you are accessing a full dword. However, if you turn it into memref.load ... : i8, you may no longer know, once the index calculation simplifies with something else, that to get an aligned dword load you need to also load the preceding bytes vs. only the bytes following this i8 (unaligned). You could resolve that with some masking + shifting, but that comes with some overhead.

Could you help me understand this? We should be able to remove any load operation that is dead, regardless of whether it's in-bounds or OOB, right? What makes this case different?

For example, the buffer instruction on amdgpu allow you to get a default value for any OOB accesses. Looking at the example above, it could be that only the last byte is OOB, but this alone makes the whole vector<4xi8> have the default value. If you no longer load that last byte, the access would be in-bounds and you would observe a different value.

Contributor Author

I don't think we actually need any special handling or tests for sub-byte types. The only ways we can have load ... vector<8xi1> are either loading from memref<...xi1>, for which the semantics is fully consistent, or loading from memref<...xvector<8xi1>>, which is ignored by the current pattern.

Contributor

Applying this pattern to a vector of bits would lead to memref.load %src[%idx] : memref<8xi1>, i.e. a load of a single bit. That doesn't feel sane.

Also, in cases like this:

```
%x = vector.load ... : vector<8xi1>
%y = vector.extract %x [5] : i1
```

vector load is probably just a scalar load anyway.

My suggestion is to restrict this pattern to multi-byte element types (*) and rely on "narrow-type-emulation" to help with sub-bytes.

(*) Multi-byte - at least one byte.
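
To make the sub-byte concern concrete, this is roughly what the rewrite would have produced for the i1 case (hypothetical output only; the pattern was later restricted to byte-aligned element types, so it leaves such IR untouched):

```
// Input: a packed load of 8 bits followed by a single-bit extract.
%x = vector.load %src[%idx] : memref<8xi1>, vector<8xi1>
%y = vector.extract %x[5] : i1 from vector<8xi1>

// Hypothetical rewrite: a load of a single bit, where the scalar (unpacked)
// layout no longer matches the packed layout assumed by the vector load.
%c5 = arith.constant 5 : index
%off = arith.addi %idx, %c5 overflow<nsw> : index
%y2 = memref.load %src[%off] : memref<8xi1>
```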

Contributor

Thanks @kuhar, those examples were helpful! I'm still kind of borderline but let’s move forward with this as an independent pattern. The proliferation of dangling “populate” methods is concerning but this case may be worth it.

The original load is efficient because you are accessing a full dword. However, if you turn it into memref.load ... : i8, you may no longer know,

For that example, I would expect the alignment information to be explicit somewhere as vector.load doesn’t have any default alignment. In the presence of no alignment information, I’m still not sure this transformation is dropping information.

For example, the buffer instruction on amdgpu allow you to get a default value for any OOB accesses. Looking at the example above, it could be that only the last byte is OOB, but this alone makes the whole vector<4xi8> have the default value. If you no longer load that last byte, the access would be in-bounds and you would observe a different value.

Yes but we can’t attribute hardware-specific semantics to vector.load. We allow OOB reads to accommodate those targets that can “handle” OOB accesses. However, we can’t make assumptions on what the target will do or the actual values of those OOB elements. Doc may need some refinement but we defined it along those lines:

Representation-wise, the ‘vector.load’ operation permits out-of-bounds reads. Support and implementation of out-of-bounds vector loads is target-specific. No assumptions should be made on the value of elements loaded out of bounds. Not all targets may support out-of-bounds vector loads. 

A valid lowering of vector.load could be a scalarized version of it that checks element by element whether it's OOB and only loads in-bounds elements, so the OOB accesses might not happen. I'd even say that OOB accesses are not observable, as using the OOB elements should be poison, right? I think the behavior you are describing would better fit a masked vector load where the masked-off elements (OOB) are replaced with a padding value.
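
As an illustration of that point, a scalarized lowering of a small vector.load that skips OOB lanes could look roughly like the sketch below (illustrative only: element-wise bounds checks against the memref size with a padding value for OOB lanes; this is not the in-tree lowering):

```
func.func @scalarized_load(%base: memref<?xf32>, %i: index) -> vector<2xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %pad = arith.constant 0.0 : f32
  %init = arith.constant dense<0.0> : vector<2xf32>
  %dim = memref.dim %base, %c0 : memref<?xf32>
  // Lane 0: load only if the index is in bounds, otherwise use the padding value.
  %inb0 = arith.cmpi ult, %i, %dim : index
  %v0 = scf.if %inb0 -> (f32) {
    %l = memref.load %base[%i] : memref<?xf32>
    scf.yield %l : f32
  } else {
    scf.yield %pad : f32
  }
  %vec0 = vector.insert %v0, %init [0] : f32 into vector<2xf32>
  // Lane 1: same check for the next element.
  %i1p = arith.addi %i, %c1 : index
  %inb1 = arith.cmpi ult, %i1p, %dim : index
  %v1 = scf.if %inb1 -> (f32) {
    %l = memref.load %base[%i1p] : memref<?xf32>
    scf.yield %l : f32
  } else {
    scf.yield %pad : f32
  }
  %vec1 = vector.insert %v1, %vec0 [1] : f32 into vector<2xf32>
  return %vec1 : vector<2xf32>
}
```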

Contributor

I don't think we actually need any special handling or tests for sub-byte types. The only ways we can have load ... vector<8xi1> are either loading from memref<...xi1>, for which the semantics is fully consistent, or loading from memref<...xvector<8xi1>>, which is ignored by the current pattern.

I'd be surprised if there is no issue with the data layout as the vector one assumes a packed layout and the scalar one would be unpacked. Looking at the generated LLVM IR for both cases would help

if (!loadOp)
return rewriter.notifyMatchFailure(op, "not a load op");

if (!loadOp->hasOneUse())
Contributor

If moving this to canonicalization, I would add a comment here stating that this condition is the one that makes this a canonicalization pattern and shouldn't be changed.


Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);

auto ovf = arith::IntegerOverflowFlags::nsw;
Contributor

I guess if the vector load is loading A[+0, +1, +2, +3], it's safe to say that the address of any independent element won't overflow... so nsw sounds ok to me?

Contributor Author

@Hardcode84 Hardcode84 left a comment

Regarding making it a canonicalization, this is a bit of a controversial topic and there are some folks who would probably disagree. I would prefer to merge it as separate patterns first and then have a dedicated PR and discussion on promoting them to canonicalizations. Also, this way we will be able to revert them independently, if needed.

(I will update code to the other comments after I'm back from euroLLVM)


Member

@kuhar kuhar left a comment

LGTM

@kuhar
Member

kuhar commented Apr 11, 2025

@Hardcode84 maybe add a test case that extracts a single bit from vector<8xi1> to show what happens in the sub-byte case?

@Hardcode84
Contributor Author

@banach-space @dcaballe updated the code, please take a look

Contributor

@banach-space banach-space left a comment

Thanks for the updates - this looks great!

I've left a few final nits and a suggestion to avoid allowing this for vectors of bits. I hope that makes sense and is sufficient for your use case.

if (rankOffset < 0)
return rewriter.notifyMatchFailure(op, "unsupported ranks combination");

auto resVecType = dyn_cast<VectorType>(op.getResult().getType());
Contributor

OK, that wasn't obvious to me. Perhaps resVecType -> extractVecType? This would make it contrast with loadVecType quite nicely.


@Hardcode84
Contributor Author

Side comment from the offline discussion: vector.load/store currently support optional OOB semantics, while the memref ops I will generate don't. So, technically, these patterns change the semantics of the code, but I believe this semantics is not very useful in its current vague form; we should have an explicit flag in vector ops to enable/disable it. An explicit flag can also be useful during lowering (e.g., for AMDGPU buffer ops you need to explicitly decide whether you want to enable or disable OOB handling during lowering).

@Hardcode84
Contributor Author

@banach-space @dcaballe I disabled the pattern for non-byte-aligned types for now
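
For reference, a negative test covering the sub-byte case could look something like the sketch below (hypothetical test name and shape, in the style of the existing negative tests; the actual tests added for this restriction may differ):

```
// CHECK-LABEL: @negative_extract_load_i1
//  CHECK-SAME:   (%[[ARG0:.*]]: memref<?xi1>, %[[ARG1:.*]]: index)
func.func @negative_extract_load_i1(%arg0: memref<?xi1>, %arg1: index) -> i1 {
// CHECK:   %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xi1>, vector<8xi1>
// CHECK:   %[[EXT:.*]] = vector.extract %[[RES]][0] : i1 from vector<8xi1>
// CHECK:   return %[[EXT]] : i1
  %0 = vector.load %arg0[%arg1] : memref<?xi1>, vector<8xi1>
  %1 = vector.extract %0[0] : i1 from vector<8xi1>
  return %1 : i1
}
```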

Contributor

@banach-space banach-space left a comment

This looks good to me, thanks for addressing all my comments! I've left a few minor suggestions that I’d appreciate you addressing before landing, but nothing major.

AFAIK, Diego is currently traveling, but from what I can tell, he’s fine with not treating these patterns as canonicalizations for now:

I'm still kind of borderline, but let’s move forward with this as an independent pattern.

With that in mind, would you mind adding a TODO to revisit the idea of making these canonicalizations in the future? Just to reflect that some of us (myself included) would be in favor — we’re just not blocking this for now.

Once these are addressed, this should be good to land (provided there are no new comments).

@Hardcode84 Hardcode84 merged commit e87aa0c into llvm:main Apr 22, 2025
11 checks passed
@Hardcode84 Hardcode84 deleted the sink-vec-mem-ops branch April 22, 2025 14:18
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
…4389)

```
vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
vector.extract %0[1] : f32 from vector<4xf32>
```
Gets converted to:
```
%c1 = arith.constant 1 : index
%0 = arith.addi %arg1, %c1 overflow<nsw> : index
%1 = memref.load %arg0[%0] : memref<?xf32>
```

```
%0 = vector.splat %arg2 : vector<1xf32>
vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
```
Gets converted to:
```
memref.store %arg2, %arg0[%arg1] : memref<?xf32>
```