[mlir][linalg] Refactor vectorization hooks to improve code reuse #141244

Open

banach-space wants to merge 3 commits into main

Conversation

banach-space
Contributor

@banach-space banach-space commented May 23, 2025

This patch refactors two vectorization hooks in Vectorization.cpp:

  • createWriteOrMaskedWrite gains a new parameter for write indices,
    aligning it with its counterpart createReadOrMaskedRead.
  • vectorizeAsInsertSliceOp is updated to reuse both of the above
    hooks, rather than re-implementing similar logic.

CONTEXT

This is effectively a refactoring of the logic for vectorizing
tensor.insert_slice. Recent updates added masking support:

At the time, reuse of the shared create* hooks wasn't feasible due to
missing parameters and overly rigid assumptions. This patch resolves
that and moves us closer to a more maintainable structure.

CHANGES IN createWriteOrMaskedWrite

  • Introduces a clear distinction between the destination tensor and the
    vector to store, via named variables like destType/vecToStoreType,
    destShape/vecToStoreShape, etc.
  • Ensures the correct rank and shape are used for attributes like
    in_bounds. For example, the size of the in_bounds attr now matches
    the source vector rank, not the tensor rank.
  • Drops the assumption that vecToStoreRank == destRank - this doesn't
    hold in many real examples.
  • Deduces mask dimensions from vecToStoreShape (vector) instead of
    destShape (tensor). (Eventually we should not require
    inputVecSizesForLeadingDims at all - mask shape should be inferred.)

NEW HELPER: isMaskTriviallyFoldable

Adds a utility to detect when masking is unnecessary. This avoids
inserting redundant masks and reduces the burden on canonicalization to
clean them up later.

Example where masking is provably unnecessary:

%2 = vector.mask %1 {
  vector.transfer_write %0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0]
    {in_bounds = [true, true, true]}
    : vector<1x2x3xf32>, tensor<9x8x7x1x2x3xf32>
} : vector<1x2x3xi1> -> tensor<9x8x7x1x2x3xf32>
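
For contrast, here is a sketch (adapted from EXAMPLE 2 in the new helper's
documentation; the SSA names are illustrative) where the mask cannot be folded
away, because the vector extent exceeds the destination extent along the
leading dimension:

%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c5 = arith.constant 5 : index
%mask = vector.create_mask %c5, %c1 : vector<8x1xi1>
%res = vector.mask %mask {
  vector.transfer_write %vec, %dest[%c0, %c0]
    {in_bounds = [true, true]}
    : vector<8x1xi32>, tensor<5x1xi32>
} : vector<8x1xi1> -> tensor<5x1xi32>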

Also, without this helper, tests become more complicated and require more
FileCheck matching.

VECTORIZATION BEHAVIOUR

This patch preserves the current behaviour around masking and the use of
the in_bounds attribute. Specifically:

  • useInBoundsInsteadOfMasking is set when no input vector sizes are
    available.
  • The vectorizer continues to infer vector sizes where needed.
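
As a sketch of that unmasked path (the values, %vec and %dest are placeholders,
not taken from a test): when a destination dim is dynamic, or its size is not
matched by the corresponding vector size, the in_bounds flag for that dim is
set to false instead of emitting a mask:

%c0 = arith.constant 0 : index
%res = vector.transfer_write %vec, %dest[%c0, %c0]
    {in_bounds = [true, false]}
    : vector<8x4xf32>, tensor<8x?xf32>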

Note: the computation of the in_bounds attribute is not always correct. That
issue is tracked here:

This will be addressed separately.

TEST CHANGES

This patch only affects vectorization of:

  • tensor.insert_slice (now refactored to use shared hooks; see the sketch
    below)
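
As a sketch of the refactored path (hypothetical input; the SSA names, types
and offsets are made up, and the exact output depends on the supplied vector
sizes), the read goes through createReadOrMaskedRead and the write through
createWriteOrMaskedWrite, with the slice offsets passed as write indices:

// Input (hypothetical):
%r = tensor.insert_slice %src into %dest[0, 2] [5, 1] [1, 1]
    : tensor<5x1xi32> into tensor<8x8xi32>

// Vectorized (sketch). With static shapes and constant offsets the write mask
// is trivially foldable, so no vector.mask is emitted:
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%pad = arith.constant 0 : i32
%v = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]}
    : tensor<5x1xi32>, vector<5x1xi32>
%w = vector.transfer_write %v, %dest[%c0, %c2] {in_bounds = [true, true]}
    : vector<5x1xi32>, tensor<8x8xi32>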

Test diffs involve additional arith.constant Ops due to increased reuse of
shared helpers (which generate their own constants). This will be cleaned up
via constant caching (see #138265).

NOTE FOR REVIEWERS

This is a fairly substantial rewrite. You may find it easier to review
createWriteOrMaskedWrite as a new method rather than diffing
line-by-line.

TODOs (future PRs)

Further alignment of createWriteOrMaskedWrite and
createReadOrMaskedRead:

  • Move createWriteOrMaskedWrite next to createReadOrMaskedRead (in
    VectorUtils.cpp) (*)
  • Make createReadOrMaskedRead leverage isMaskTriviallyFoldable.

(*) This method will eventually be moved out of Vectorization.cpp, which
isn't the right long-term home for it.

@llvmbot
Member

llvmbot commented May 23, 2025

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes

This patch refactors two vectorization hooks in Vectorization.cpp:

  • createWriteOrMaskedWrite gains a new parameter for write indices,
    aligning it with its counterpart createReadOrMaskedRead.
  • vectorizeAsInsertSliceOp is updated to reuse both of the above
    hooks, rather than re-implementing similar logic.

CONTEXT

This is effectively a refactoring of the logic for vectorizing
tensor.insert_slice. Recent updates added masking support:

At the time, reuse of the shared create* hooks wasn't feasible due to
missing parameters and overly rigid assumptions. This patch resolves
that and moves us closer to a more maintainable structure.

CHANGES IN createWriteOrMaskedWrite

  • Introduces a clear distinction between the destination tensor and the
    vector to store, via named variables like destType/vecToStoreType,
    destShape/vecToStoreShape, etc.
  • Ensures the correct rank and shape are used for attributes like
    in_bounds. For example, the size of the in_bounds array now matches
    the source vector rank, not the tensor rank.
  • Drops the assumption that vecToStoreRank == destRank — this doesn't
    hold in many real examples.
  • Deduces mask dimensions from vecToStoreShape (vector) instead of
    destShape (tensor). (Eventually we should not require
    inputVecSizesForLeadingDims at all — mask shape should be inferred.)

NEW HELPER: isMaskTriviallyFoldable

Adds a utility to detect when masking is unnecessary. This avoids
inserting redundant masks and reduces the burden on canonicalization to
clean them up later.

Example where masking is provably unnecessary:

%2 = vector.mask %1 {
  vector.transfer_write %0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0]
    {in_bounds = [true, true, true]}
    : vector<1x2x3xf32>, tensor<9x8x7x1x2x3xf32>
} : vector<1x2x3xi1> -> tensor<9x8x7x1x2x3xf32>

Also, without this helper, tests become more complicated and require more
FileCheck matching.

TEST CHANGES

This patch primarily affects vectorization of:

  • tensor.insert_slice, now refactored to use shared hooks.

tensor.pad vectorization patterns, which internally use
tensor.insert_slice, are also effectively updated. Note that only
pad-with-patterns.mlir is affected.

Most test updates involve the insertion of masks that were previously
missing — this reflects a correctness fix, not a regression. In all
cases, the added masks are indeed required.

You’ll also notice more repeated constants (arith.constant 0 : index),
due to increased use of helper hooks. This will be cleaned up separately
via a constant cache (see #138265 for discussion).

NOTE FOR REVIEWERS

This is a fairly substantial rewrite. You may find it easier to review
createWriteOrMaskedWrite as a new method rather than diffing
line-by-line.

TODOs (future PRs)

Further alignment of createWriteOrMaskedWrite and
createReadOrMaskedRead:

  • Move createWriteOrMaskedWrite next to createReadOrMaskedRead (in
    VectorUtils.cpp) (*)
  • Make createReadOrMaskedRead leverage isMaskTriviallyFoldable.

(*) This method will eventually be moved out of Vectorization.cpp, which isn't the right long-term home for it.


Patch is 34.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/141244.diff

7 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+166-92)
  • (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+3-1)
  • (modified) mlir/test/Dialect/LLVM/transform-e2e.mlir (+6-4)
  • (modified) mlir/test/Dialect/Linalg/vectorization.mlir (-1)
  • (modified) mlir/test/Dialect/Linalg/vectorization/insert-slice-with-patterns.mlir (+10-1)
  • (modified) mlir/test/Dialect/Linalg/vectorization/insert-slice.mlir (+51-30)
  • (modified) mlir/test/Dialect/Linalg/vectorization/pad-with-patterns.mlir (+17-10)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c5b62227777a7..0113ba86a5ae3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1506,20 +1506,104 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
   return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
 }
 
+/// Determines whether the mask for a corresponding `vector.transfer_write` op
+/// is trivially foldable (i.e., guaranteed to be all true).
+///
+/// Requirements:
+///   * All involved shapes (destination, mask) are static.
+///   * All write indices are constant.
+///   * All mask sizes are constant.
+///
+/// Once verified, the method checks for each destination dimension `d`:
+///   (1) destDimSize[rankDiff + d] <= maskShape[d]
+///   (2) destDimSize[rankDiff + d] <= writeIndex[d] + maskSize[d]
+///
+/// rankDiff = rank(dest) - rank(mask).
+///
+/// This method takes a conservative view: it may return false even if the mask
+/// is technically foldable.
+///
+/// EXAMPLE 1 (trivially foldable):
+///   %c0 = arith.constant 0 : index
+///   vector.transfer_write %vecToStore_1, %dest[%c0, %c0]
+///     {in_bounds = [true, true]}
+///   : vector<5x1xi32>, tensor<5x1xi32>
+///
+/// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape):
+///   %c0 = arith.constant 0 : index
+///   vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
+///     {in_bounds = [true, true]}
+///   : vector<8x1xi32>, tensor<5x1xi32>
+///
+/// TODO: Re-use in createReadOrMaskedRead
+static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
+                                    SmallVector<Value> &writeIdxs,
+                                    ArrayRef<int64_t> destShape,
+                                    ArrayRef<int64_t> maskShape) {
+  // Masking is unavoidable in the case of dynamic tensors.
+  if (ShapedType::isDynamicShape(destShape))
+    return false;
+
+  // Collect all constant mask sizes.
+  SmallVector<int64_t, 4> cstMaskSizes;
+  for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
+    if (auto intSize = getConstantIntValue(dimSize)) {
+      cstMaskSizes.push_back(*intSize);
+    }
+  }
+
+  // If any of the mask sizes is non-constant, bail out.
+  if (cstMaskSizes.size() != maskShape.size())
+    return false;
+
+  // Collect all constant write indices.
+  SmallVector<int64_t, 4> cstWriteIdxs;
+  for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
+    APSInt intVal;
+    if (matchPattern(idx, m_ConstantInt(&intVal))) {
+      cstWriteIdxs.push_back(intVal.getSExtValue());
+    }
+  }
+
+  // If any of the write indices is non-constant, bail out.
+  if (cstWriteIdxs.size() != destShape.size())
+    return false;
+
+  // Go over all destination dims and check (1) and (2). Take into account that:
+  //  * The number of mask sizes will match the rank of the vector to store.
+  //    This could be lower than the rank of the destination tensor.
+  //  * Mask sizes could be larger than the corresponding mask shape (hence
+  //  `clamp`).
+  // TODO: The 2nd item should be rejected by the verifier.
+  int64_t rankDiff = destShape.size() - cstMaskSizes.size();
+  for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
+    if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] ||
+        /*(2)*/ destShape[rankDiff + i] <
+            (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
+             cstWriteIdxs[i]))
+      return false;
+  }
+
+  return true;
+}
+
 /// Creates an optionally masked TransferWriteOp
 ///
 /// Generates the following operation:
 ///   %res = vector.transfer_write %vectorToStore into %dest
 ///
-/// If the leading N dimensions of the destination tensor do not match
+/// If the leading N dimensions of the vector to store do not match
 /// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
 /// masking is applied to ensure correctness:
 ///
-///   %mask = vector.create_mask(%destShape)
+///   %mask = vector.create_mask(%destShape) : %vectorToStoreShape
 ///   %res = vector.mask %mask {
 ///     vector.transfer_write %vectorToStore into %dest
 ///   }
 ///
+/// The mask shape is identical to `vectorToStore` (with the element type ==
+/// i1), and the mask values are based on the shape of the `dest` tensor.
+///
 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
 /// is used instead of masking:
 ///
@@ -1528,75 +1612,99 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
 ///   %res = vector.transfer_write %input into %dest
 ///       {in_bounds = in_bounds_flags}
 ///
-/// NOTE: All write offsets are set to 0.
-/// TODO: Allow specyfying write offsets.
-/// NOTE: When N < rank(input), the missing vector sizes are effectively
-/// extracted from the trailing sizes of `destSizes`. This means those sizes
-/// must be static.
-/// TODO: Support cases where an arbitrary dim is dynamic - this will require
-/// specifying all the vector sizes.
+/// `writeIndices` specifies the offsets to use. If empty, all indices are set
+/// to 0.
+///
+/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
+/// `valueToStore`.
+/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
+/// already provided in `vectorToStore`.
 static Operation *
 createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
                          Value dest,
                          ArrayRef<int64_t> inputVecSizesForLeadingDims,
+                         SmallVector<Value> writeIndices = {},
                          bool useInBoundsInsteadOfMasking = false) {
 
   ShapedType destType = cast<ShapedType>(dest.getType());
-  assert(cast<VectorType>(vectorToStore.getType()).getRank() ==
-             static_cast<int64_t>(destType.getRank()) &&
-         "Rank mismatch!");
-  (void)destType;
+  int64_t destRank = destType.getRank();
+  auto destShape = destType.getShape();
 
-  int64_t rank = cast<ShapedType>(dest.getType()).getRank();
-  auto destShape = cast<ShapedType>(dest.getType()).getShape();
+  VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
+  int64_t vecToStoreRank = vecToStoreType.getRank();
+  auto vecToStoreShape = vecToStoreType.getShape();
 
   // Compute the in_bounds attribute
-  SmallVector<bool> inBoundsVal(rank, true);
+  SmallVector<bool> inBoundsVal(vecToStoreRank, true);
   if (useInBoundsInsteadOfMasking) {
     // In this case, assume that all the required vector sizes have been
     // provided.
     assert(inputVecSizesForLeadingDims.size() ==
-               static_cast<size_t>(destType.getRank()) &&
+               static_cast<size_t>(vecToStoreType.getRank()) &&
            "Insufficient number of input vector sizes!");
     // Update the inBounds attribute.
-    for (unsigned i = 0; i < rank; i++)
+    for (unsigned i = 0; i < destRank; i++)
       inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
                        !ShapedType::isDynamic(destShape[i]);
   }
 
+  // If missing, initialize the write indices to 0.
+  assert(writeIndices.empty() ||
+         writeIndices.size() == static_cast<size_t>(destRank) &&
+             "Invalid number of write indices!");
+  if (writeIndices.empty()) {
+    auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+    writeIndices = SmallVector<Value>(destRank, zero);
+  }
+
   // Generate the xfer_write Op
-  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
-  Operation *write = builder.create<vector::TransferWriteOp>(
-      loc,
-      /*vector=*/vectorToStore,
-      /*source=*/dest,
-      /*indices=*/SmallVector<Value>(rank, zero),
-      /*inBounds=*/inBoundsVal);
-  assert(llvm::none_of(
-             destShape.drop_front(inputVecSizesForLeadingDims.size()),
-             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
-         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
+  Operation *write =
+      builder.create<vector::TransferWriteOp>(loc,
+                                              /*vector=*/vectorToStore,
+                                              /*source=*/dest,
+                                              /*indices=*/writeIndices,
+                                              /*inBounds=*/inBoundsVal);
 
   // If masking is disabled, exit.
   if (useInBoundsInsteadOfMasking)
     return write;
 
+  assert(llvm::none_of(
+             destShape.drop_front(inputVecSizesForLeadingDims.size()),
+             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
+         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
+
   // Check if masking is needed.
   bool needMaskForWrite =
       !llvm::equal(inputVecSizesForLeadingDims,
-                   destShape.take_front(inputVecSizesForLeadingDims.size()));
+                   destShape.take_front(destRank - vecToStoreRank +
+                                        inputVecSizesForLeadingDims.size()));
 
   // If masking is needed, generate the mask and mask the operation.
   if (needMaskForWrite) {
+    // Get the mask shape + type. Missing mask dimensions are taken from
+    // `vectorToStore`.
     SmallVector<int64_t> writeMaskShape;
     writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
                           inputVecSizesForLeadingDims.end());
-    writeMaskShape.append(destShape.begin() +
-                              inputVecSizesForLeadingDims.size(),
-                          destShape.end());
+    if (vecToStoreRank >
+        static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
+      writeMaskShape.append(vecToStoreShape.begin() +
+                                inputVecSizesForLeadingDims.size(),
+                            vecToStoreShape.end());
     auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-    Value maskForWrite = builder.create<vector::CreateMaskOp>(
-        loc, writeMaskType, tensor::getMixedSizes(builder, loc, dest));
+
+    SmallVector<OpFoldResult> destSizes =
+        tensor::getMixedSizes(builder, loc, dest);
+    SmallVector<OpFoldResult> maskSizes(destSizes.end() - writeMaskShape.size(),
+                                        destSizes.end());
+
+    if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+                                writeMaskShape))
+      return write;
+
+    Value maskForWrite = builder.createOrFold<vector::CreateMaskOp>(
+        loc, writeMaskType, maskSizes);
     write = mlir::vector::maskOperation(builder, write, maskForWrite);
   }
 
@@ -1700,10 +1808,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0],
       transposeOp.getResult().getType().getElementType());
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
-                               /*inputVecSizesForLeadingDims=*/inputVectorSizes,
-                               /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, transposeOp.getResult(), dest,
+      /*inputVecSizesForLeadingDims=*/inputVectorSizes, /*writeIndices=*/{},
+      /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1839,10 +1947,10 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedRetShapes[0],
       shapeCastOp.getResult().getType().getElementType());
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(), dest,
-                               /*inputVecSizesForLeadingDims=*/writeVectorSizes,
-                               useInBoundsInsteadOfMasking);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, shapeCastOp.getResult(), dest,
+      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
+      /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1874,10 +1982,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   // Create Xfer write Op
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest,
-                               /*inputVecSizesForLeadingDims=*/inputVectorSizes,
-                               /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, maskedRead, dest,
+      /*inputVecSizesForLeadingDims=*/inputVectorSizes, {},
+      /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -2922,53 +3030,19 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   auto vecType = VectorType::get(vecShape, sourceType.getElementType());
 
   // 3. Generate TransferReadOp + TransferWriteOp
-  ReifiedRankedShapedTypeDims reifiedSrcSizes;
-  Value maskOp;
-
-  // If vector sizes are user provided, make sure to mask. First, generate the
-  // mask.
-  if (!inputVectorSizes.empty()) {
-    auto *srcDefOp = source.getDefiningOp();
-    if (!srcDefOp) {
-      LDBG("Unable to get the defining Op of " << sliceOp);
-      return failure();
-    }
-
-    LogicalResult status =
-        cast<ReifyRankedShapedTypeOpInterface>(srcDefOp).reifyResultShapes(
-            rewriter, reifiedSrcSizes);
-    if (status.failed()) {
-      LDBG("Unable to reify result shapes of " << srcDefOp);
-      return failure();
-    }
-
-    // Create the mask
-    auto readMaskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
-    maskOp = rewriter.create<vector::CreateMaskOp>(
-        sliceOp.getLoc(), readMaskType, reifiedSrcSizes[0]);
-  }
+  auto loc = sliceOp.getLoc();
 
+  // Create read
   SmallVector<Value> readIndices(
-      vecType.getRank(),
-      rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
-  Operation *read = rewriter.create<vector::TransferReadOp>(
-      sliceOp.getLoc(), vecType, source, readIndices, padValue,
-      ArrayRef<bool>{readInBounds});
-
-  if (maskOp) {
-    read = mlir::vector::maskOperation(rewriter, read, maskOp);
-  }
-
-  auto writeIndices = getValueOrCreateConstantIndexOp(
-      rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
-
-  Operation *write = rewriter.create<vector::TransferWriteOp>(
-      sliceOp.getLoc(), read->getResult(0), sliceOp.getDest(), writeIndices,
-      ArrayRef<bool>{writeInBounds});
-
-  if (maskOp) {
-    write = mlir::vector::maskOperation(rewriter, write, maskOp);
-  }
+      vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
+  Value read = mlir::vector::createReadOrMaskedRead(
+      rewriter, loc, source, vecType.getShape(), padValue);
+
+  // Create write
+  auto writeIndices =
+      getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices);
 
   // 4. Finalize
   newResults.push_back(write->getResult(0));
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index d5dd6f2027be8..dda4856596bba 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -337,13 +337,13 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
   auto sourceShape = sourceShapedType.getShape();
   assert(sourceShape.size() == inputVectorSizes.size() &&
          "expected same ranks.");
-  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
   auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
   assert(padValue.getType() == sourceShapedType.getElementType() &&
          "expected same pad element type to match source element type");
   int64_t readRank = inputVectorSizes.size();
   auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
   SmallVector<bool> inBoundsVal(readRank, true);
+
   if (useInBoundsInsteadOfMasking) {
     // Update the inBounds attribute.
     for (unsigned i = 0; i < readRank; i++)
@@ -362,6 +362,8 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
     return transferReadOp;
   SmallVector<OpFoldResult> mixedSourceDims =
       tensor::getMixedSizes(builder, loc, source);
+
+  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
   Value mask =
       builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
   return mlir::vector::maskOperation(builder, transferReadOp, mask)
diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
index c00b47fb936e9..98cfaf249c898 100644
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -18,16 +18,14 @@ module attributes {transform.with_named_sequence} {
     %1, %loops:3 = transform.structured.tile_using_for %0 tile_sizes [2, 2, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
     %2 = transform.get_parent_op %1 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     transform.structured.vectorize_children_and_apply_patterns %2 : (!transform.any_op) -> !transform.any_op
-    %b = transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap}
-        %module_op {bufferize_function_boundaries = true}
-        : (!transform.any_op) -> !transform.any_op
 
-    %f = transform.structured.match ops{["func.func"]} in %b
+    %f = transform.structured.match ops{["func.func"]} in %module_op
       : (!transform.any_op) -> !transform.any_op
 
     // TODO: group these lower-level controls into various properly named vector
     // lowering TD macros.
     transform.apply_patterns to %f {
+      transform.apply_patterns.vector.lower_masked_transfers
       transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
       transform.apply_patterns.vector.transfer_permutation_patterns
       transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel"
@@ -37,6 +35,10 @@ module attributes {transform.with_named_sequence} {
       transform.apply_patterns.vector.lower_shape_cast
       transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d"
     } : !transform.any_op
+
+    %b = transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap}
+        %module_op {bufferize_function_boundaries = true}
+        : (!transform.any_op) -> !transform.any_op
     transform.yield
   }
 }
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 8c6760fa50325..9a18f040d57cd 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1027,4 +1027,3 @@ func.func @test_vectorize_unpack_no_vector_sizes_permute(%source: tensor<4x7x4xf
     transform.yield
   }
  }
-
diff --git a/mlir/test/Dialect/Linalg/vectorization/insert-slice-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/insert-slice-with-patterns.mlir
index f7764be9be73f..d1f2ed194f6ce 100644
--- a/mlir/test/Dialect/Linalg/vectorization/insert-slice-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/insert-slice-with-patterns.mlir
@@ -67,10 +67,19 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:      %[[ARG_0:.*]]: tensor<1x?x3xf32>,
 // CHECK-SAME:      %[[PAD:.*]]: f32,
 // CHECK-SAME:      %[[SIZE:.*]]: index) -> tensor<9x8x7x1x2x3xf32> {
+// CHECK:           %[[C3:.*]] = arith.constant 3 : index
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[EMPTY:.*]] = tensor.empty() : tensor<9x8x7x1x2x3xf32>
 // CHECK:           %[[BC:.*]] = vector.broadcast %[[PAD]] : f32 to vector<9x8x7x1x2x3xf32>
 // CHECK:           %[[WRITE:.*]] = vector.transfer_write %[[BC]], %[[EMPTY]]{{.*}} {in_bounds = [true, true, true, true, true, true]} : vector<9x8x7...
[truncated]

@llvmbot
Member

llvmbot commented May 23, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Andrzej Warzyński (banach-space)


@llvmbot
Member

llvmbot commented May 23, 2025

@llvm/pr-subscribers-mlir-llvm

Author: Andrzej Warzyński (banach-space)

Changes

This patch refactors two vectorization hooks in Vectorization.cpp:

  • createWriteOrMaskedWrite gains a new parameter for write indices,
    aligning it with its counterpart createReadOrMaskedRead.
  • vectorizeAsInsertSliceOp is updated to reuse both of the above
    hooks, rather than re-implementing similar logic.

CONTEXT

This is effectively a refactoring of the logic for vectorizing
tensor.insert_slice. Recent updates added masking support:

At the time, reuse of the shared create* hooks wasn't feasible due to
missing parameters and overly rigid assumptions. This patch resolves
that and moves us closer to a more maintainable structure.

CHANGES IN vectorizeAsInsertSliceOp

  • Introduces a clear distinction between the destination tensor and the
    vector to store, via named variables like destType/vecToStoreType,
    destShape/vecToStoreShape, etc.
  • Ensures the correct rank and shape are used for attributes like
    in_bounds. For example, the size of the in_bounds array now matches
    the source vector rank, not the tensor rank.
  • Drops the assumption that vecToStoreRank == destRank — this doesn't
    hold in many real examples.
  • Deduces mask dimensions from vecToStoreShape (vector) instead of
    destShape (tensor). (Eventually we should not require
    inputVecSizesForLeadingDims at all — mask shape should be inferred.)

NEW HELPER: isMaskTriviallyFoldable

Adds a utility to detect when masking is unnecessary. This avoids
inserting redundant masks and reduces the burden on canonicalization to
clean them up later.

Example where masking is provably unnecessary:

%2 = vector.mask %1 {
  vector.transfer_write %0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0]
    {in_bounds = [true, true, true]}
    : vector&lt;1x2x3xf32&gt;, tensor&lt;9x8x7x1x2x3xf32&gt;
} : vector&lt;1x2x3xi1&gt; -&gt; tensor&lt;9x8x7x1x2x3xf32&gt;

Also, without this hook, tests are more complicated and require more
matching.

TEST CHANGES

This patch primarily affects vectorization of:

  • tensor.insert_slice, now refactored to use shared hooks.

tensor.pad vectorization patterns, which internally use
tensor.insert_slice, are also effectively updated. Note, only
pad-with-patterns.mlir is affected.

Most test updates involve the insertion of masks that were previously
missing — this reflects a correctness fix, not a regression. In all
cases, the added masks are indeed required.

You’ll also notice more repeated constants (arith.constant 0 : index),
due to increased use of helper hooks. This will be cleaned up separately
via a constant cache (see #138265 for discussion).

NOTE FOR REVIEWERS

This is a fairly substantial rewrite. You may find it easier to review
createWriteOrMaskedWrite as a new method rather than diffing
line-by-line.

TODOs (future PRs)

Further alignment of createWriteOrMaskedWrite and
createReadOrMaskedRead:

  • Move createWriteOrMaskedWrite next to createReadOrMaskedRead (in
    VectorUtils.cpp)
  • Make createReadOrMaskedRead leverage isMaskTriviallyFoldable.

(* This method will eventually be moved out of Vectorization.cpp, which isn't the right long-term home for it.)


Patch is 34.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/141244.diff

7 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+166-92)
  • (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+3-1)
  • (modified) mlir/test/Dialect/LLVM/transform-e2e.mlir (+6-4)
  • (modified) mlir/test/Dialect/Linalg/vectorization.mlir (-1)
  • (modified) mlir/test/Dialect/Linalg/vectorization/insert-slice-with-patterns.mlir (+10-1)
  • (modified) mlir/test/Dialect/Linalg/vectorization/insert-slice.mlir (+51-30)
  • (modified) mlir/test/Dialect/Linalg/vectorization/pad-with-patterns.mlir (+17-10)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c5b62227777a7..0113ba86a5ae3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1506,20 +1506,104 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
   return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
 }
 
+/// Determines whether the mask for a corresponding `vector.transfer_write` op
+/// is trivially foldable (i.e., guaranteed to be all true).
+///
+/// Requirements:
+///   * All involved shapes (destination, mask) are static.
+///   * All write indices are constant.
+///   * All mask sizes are constant.
+///
+/// Once verified, the method checks for each destination dimension `d`:
+///   (1) destDimSize[rankDiff + d] <= maskShape[d]
+///   (2) destDimSize[rankDiff + d] <= writeIndex[d] + maskSize[d]
+///
+/// rankDiff = rank(dest) - rank(mask).
+///
+/// This method takes a conservative view: it may return false even if the mask
+/// is technically foldable.
+///
+/// EXAMPLE 1 (trivially foldable):
+///   %c0 = arith.constant 0 : index
+///   vector.transfer_write %vecToStore_1, %dest[%c0, %c0]
+///     {in_bounds = [true, true]}
+///   : vector<5x1xi32>, tensor<5x1xi32>
+///
+/// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape):
+///   %c0 = arith.constant 0 : index
+///   vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
+///     {in_bounds = [true, true]}
+///   : vector<8x1xi32>, tensor<5x1xi32>
+///
+/// TODO: Re-use in createReadOrMaskedRead
+static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
+                                    SmallVector<Value> &writeIdxs,
+                                    ArrayRef<int64_t> destShape,
+                                    ArrayRef<int64_t> maskShape) {
+  // Masking is unavoidable in the case of dynamic tensors.
+  if (ShapedType::isDynamicShape(destShape))
+    return false;
+
+  // Collect all constant mask sizes.
+  SmallVector<int64_t, 4> cstMaskSizes;
+  for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
+    if (auto intSize = getConstantIntValue(dimSize)) {
+      cstMaskSizes.push_back(*intSize);
+    }
+  }
+
+  // If any of the mask sizes is non-constant, bail out.
+  if (cstMaskSizes.size() != maskShape.size())
+    return false;
+
+  // Collect all constant write indices.
+  SmallVector<int64_t, 4> cstWriteIdxs;
+  for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
+    APSInt intVal;
+    if (matchPattern(idx, m_ConstantInt(&intVal))) {
+      cstWriteIdxs.push_back(intVal.getSExtValue());
+    }
+  }
+
+  // If any of the write indices is non-constant, bail out.
+  if (cstWriteIdxs.size() != destShape.size())
+    return false;
+
+  // Go over all destination dims and check (1) and (2). Take into account that:
+  //  * The number of mask sizes will match the rank of the vector to store.
+  //    This could be lower than the rank of the destination tensor.
+  //  * Mask sizes could be larger than the corresponding mask shape (hence
+  //  `clamp`).
+  // TODO: The 2nd item should be rejected by the verifier.
+  int64_t rankDiff = destShape.size() - cstMaskSizes.size();
+  for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
+    if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] ||
+        /*(2)*/ destShape[rankDiff + i] <
+            (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
+             cstWriteIdxs[i]))
+      return false;
+  }
+
+  return true;
+}
+
 /// Creates an optionally masked TransferWriteOp
 ///
 /// Generates the following operation:
 ///   %res = vector.transfer_write %vectorToStore into %dest
 ///
-/// If the leading N dimensions of the destination tensor do not match
+/// If the leading N dimensions of the vector to store do not match
 /// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
 /// masking is applied to ensure correctness:
 ///
-///   %mask = vector.create_mask(%destShape)
+///   %mask = vector.create_mask(%destShape) : %vectorToStoreShape
 ///   %res = vector.mask %mask {
 ///     vector.transfer_write %vectorToStore into %dest
 ///   }
 ///
+/// The mask shape is identical to `vectorToStore` (with the element type ==
+/// i1), and the mask values are based on the shape of the `dest` tensor.
+///
 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
 /// is used instead of masking:
 ///
@@ -1528,75 +1612,99 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
 ///   %res = vector.transfer_write %input into %dest
 ///       {in_bounds = in_bounds_flags}
 ///
-/// NOTE: All write offsets are set to 0.
-/// TODO: Allow specyfying write offsets.
-/// NOTE: When N < rank(input), the missing vector sizes are effectively
-/// extracted from the trailing sizes of `destSizes`. This means those sizes
-/// must be static.
-/// TODO: Support cases where an arbitrary dim is dynamic - this will require
-/// specifying all the vector sizes.
+/// `writeIndices` specifies the offsets to use. If empty, all indices are set
+/// to 0.
+///
+/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
+/// `valueToStore`.
+/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
+/// already provided in `vectorToStore`.
 static Operation *
 createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
                          Value dest,
                          ArrayRef<int64_t> inputVecSizesForLeadingDims,
+                         SmallVector<Value> writeIndices = {},
                          bool useInBoundsInsteadOfMasking = false) {
 
   ShapedType destType = cast<ShapedType>(dest.getType());
-  assert(cast<VectorType>(vectorToStore.getType()).getRank() ==
-             static_cast<int64_t>(destType.getRank()) &&
-         "Rank mismatch!");
-  (void)destType;
+  int64_t destRank = destType.getRank();
+  auto destShape = destType.getShape();
 
-  int64_t rank = cast<ShapedType>(dest.getType()).getRank();
-  auto destShape = cast<ShapedType>(dest.getType()).getShape();
+  VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
+  int64_t vecToStoreRank = vecToStoreType.getRank();
+  auto vecToStoreShape = vecToStoreType.getShape();
 
   // Compute the in_bounds attribute
-  SmallVector<bool> inBoundsVal(rank, true);
+  SmallVector<bool> inBoundsVal(vecToStoreRank, true);
   if (useInBoundsInsteadOfMasking) {
     // In this case, assume that all the required vector sizes have been
     // provided.
     assert(inputVecSizesForLeadingDims.size() ==
-               static_cast<size_t>(destType.getRank()) &&
+               static_cast<size_t>(vecToStoreType.getRank()) &&
            "Insufficient number of input vector sizes!");
     // Update the inBounds attribute.
-    for (unsigned i = 0; i < rank; i++)
+    for (unsigned i = 0; i < destRank; i++)
       inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
                        !ShapedType::isDynamic(destShape[i]);
   }
 
+  // If missing, initialize the write indices to 0.
+  assert(writeIndices.empty() ||
+         writeIndices.size() == static_cast<size_t>(destRank) &&
+             "Invalid number of write indices!");
+  if (writeIndices.empty()) {
+    auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+    writeIndices = SmallVector<Value>(destRank, zero);
+  }
+
   // Generate the xfer_write Op
-  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
-  Operation *write = builder.create<vector::TransferWriteOp>(
-      loc,
-      /*vector=*/vectorToStore,
-      /*source=*/dest,
-      /*indices=*/SmallVector<Value>(rank, zero),
-      /*inBounds=*/inBoundsVal);
-  assert(llvm::none_of(
-             destShape.drop_front(inputVecSizesForLeadingDims.size()),
-             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
-         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
+  Operation *write =
+      builder.create<vector::TransferWriteOp>(loc,
+                                              /*vector=*/vectorToStore,
+                                              /*source=*/dest,
+                                              /*indices=*/writeIndices,
+                                              /*inBounds=*/inBoundsVal);
 
   // If masking is disabled, exit.
   if (useInBoundsInsteadOfMasking)
     return write;
 
+  assert(llvm::none_of(
+             destShape.drop_front(inputVecSizesForLeadingDims.size()),
+             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
+         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
+
   // Check if masking is needed.
   bool needMaskForWrite =
       !llvm::equal(inputVecSizesForLeadingDims,
-                   destShape.take_front(inputVecSizesForLeadingDims.size()));
+                   destShape.take_front(destRank - vecToStoreRank +
+                                        inputVecSizesForLeadingDims.size()));
 
   // If masking is needed, generate the mask and mask the operation.
   if (needMaskForWrite) {
+    // Get the mask shape + type. Missing mask dimensions are taken from
+    // `vectorToStore`.
     SmallVector<int64_t> writeMaskShape;
     writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
                           inputVecSizesForLeadingDims.end());
-    writeMaskShape.append(destShape.begin() +
-                              inputVecSizesForLeadingDims.size(),
-                          destShape.end());
+    if (vecToStoreRank >
+        static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
+      writeMaskShape.append(vecToStoreShape.begin() +
+                                inputVecSizesForLeadingDims.size(),
+                            vecToStoreShape.end());
     auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-    Value maskForWrite = builder.create<vector::CreateMaskOp>(
-        loc, writeMaskType, tensor::getMixedSizes(builder, loc, dest));
+
+    SmallVector<OpFoldResult> destSizes =
+        tensor::getMixedSizes(builder, loc, dest);
+    SmallVector<OpFoldResult> maskSizes(destSizes.end() - writeMaskShape.size(),
+                                        destSizes.end());
+
+    if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+                                writeMaskShape))
+      return write;
+
+    Value maskForWrite = builder.createOrFold<vector::CreateMaskOp>(
+        loc, writeMaskType, maskSizes);
     write = mlir::vector::maskOperation(builder, write, maskForWrite);
   }
 
@@ -1700,10 +1808,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0],
       transposeOp.getResult().getType().getElementType());
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
-                               /*inputVecSizesForLeadingDims=*/inputVectorSizes,
-                               /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, transposeOp.getResult(), dest,
+      /*inputVecSizesForLeadingDims=*/inputVectorSizes, /*writeIndices=*/{},
+      /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1839,10 +1947,10 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedRetShapes[0],
       shapeCastOp.getResult().getType().getElementType());
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(), dest,
-                               /*inputVecSizesForLeadingDims=*/writeVectorSizes,
-                               useInBoundsInsteadOfMasking);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, shapeCastOp.getResult(), dest,
+      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
+      /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1874,10 +1982,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   // Create Xfer write Op
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest,
-                               /*inputVecSizesForLeadingDims=*/inputVectorSizes,
-                               /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, maskedRead, dest,
+      /*inputVecSizesForLeadingDims=*/inputVectorSizes, {},
+      /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -2922,53 +3030,19 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   auto vecType = VectorType::get(vecShape, sourceType.getElementType());
 
   // 3. Generate TransferReadOp + TransferWriteOp
-  ReifiedRankedShapedTypeDims reifiedSrcSizes;
-  Value maskOp;
-
-  // If vector sizes are user provided, make sure to mask. First, generate the
-  // mask.
-  if (!inputVectorSizes.empty()) {
-    auto *srcDefOp = source.getDefiningOp();
-    if (!srcDefOp) {
-      LDBG("Unable to get the defining Op of " << sliceOp);
-      return failure();
-    }
-
-    LogicalResult status =
-        cast<ReifyRankedShapedTypeOpInterface>(srcDefOp).reifyResultShapes(
-            rewriter, reifiedSrcSizes);
-    if (status.failed()) {
-      LDBG("Unable to reify result shapes of " << srcDefOp);
-      return failure();
-    }
-
-    // Create the mask
-    auto readMaskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
-    maskOp = rewriter.create<vector::CreateMaskOp>(
-        sliceOp.getLoc(), readMaskType, reifiedSrcSizes[0]);
-  }
+  auto loc = sliceOp.getLoc();
 
+  // Create read
   SmallVector<Value> readIndices(
-      vecType.getRank(),
-      rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
-  Operation *read = rewriter.create<vector::TransferReadOp>(
-      sliceOp.getLoc(), vecType, source, readIndices, padValue,
-      ArrayRef<bool>{readInBounds});
-
-  if (maskOp) {
-    read = mlir::vector::maskOperation(rewriter, read, maskOp);
-  }
-
-  auto writeIndices = getValueOrCreateConstantIndexOp(
-      rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
-
-  Operation *write = rewriter.create<vector::TransferWriteOp>(
-      sliceOp.getLoc(), read->getResult(0), sliceOp.getDest(), writeIndices,
-      ArrayRef<bool>{writeInBounds});
-
-  if (maskOp) {
-    write = mlir::vector::maskOperation(rewriter, write, maskOp);
-  }
+      vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
+  Value read = mlir::vector::createReadOrMaskedRead(
+      rewriter, loc, source, vecType.getShape(), padValue);
+
+  // Create write
+  auto writeIndices =
+      getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices);
 
   // 4. Finalize
   newResults.push_back(write->getResult(0));
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index d5dd6f2027be8..dda4856596bba 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -337,13 +337,13 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
   auto sourceShape = sourceShapedType.getShape();
   assert(sourceShape.size() == inputVectorSizes.size() &&
          "expected same ranks.");
-  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
   auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
   assert(padValue.getType() == sourceShapedType.getElementType() &&
          "expected same pad element type to match source element type");
   int64_t readRank = inputVectorSizes.size();
   auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
   SmallVector<bool> inBoundsVal(readRank, true);
+
   if (useInBoundsInsteadOfMasking) {
     // Update the inBounds attribute.
     for (unsigned i = 0; i < readRank; i++)
@@ -362,6 +362,8 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
     return transferReadOp;
   SmallVector<OpFoldResult> mixedSourceDims =
       tensor::getMixedSizes(builder, loc, source);
+
+  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
   Value mask =
       builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
   return mlir::vector::maskOperation(builder, transferReadOp, mask)
diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
index c00b47fb936e9..98cfaf249c898 100644
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -18,16 +18,14 @@ module attributes {transform.with_named_sequence} {
     %1, %loops:3 = transform.structured.tile_using_for %0 tile_sizes [2, 2, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
     %2 = transform.get_parent_op %1 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     transform.structured.vectorize_children_and_apply_patterns %2 : (!transform.any_op) -> !transform.any_op
-    %b = transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap}
-        %module_op {bufferize_function_boundaries = true}
-        : (!transform.any_op) -> !transform.any_op
 
-    %f = transform.structured.match ops{["func.func"]} in %b
+    %f = transform.structured.match ops{["func.func"]} in %module_op
       : (!transform.any_op) -> !transform.any_op
 
     // TODO: group these lower-level controls into various properly named vector
     // lowering TD macros.
     transform.apply_patterns to %f {
+      transform.apply_patterns.vector.lower_masked_transfers
       transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
       transform.apply_patterns.vector.transfer_permutation_patterns
       transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel"
@@ -37,6 +35,10 @@ module attributes {transform.with_named_sequence} {
       transform.apply_patterns.vector.lower_shape_cast
       transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d"
     } : !transform.any_op
+
+    %b = transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap}
+        %module_op {bufferize_function_boundaries = true}
+        : (!transform.any_op) -> !transform.any_op
     transform.yield
   }
 }
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 8c6760fa50325..9a18f040d57cd 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1027,4 +1027,3 @@ func.func @test_vectorize_unpack_no_vector_sizes_permute(%source: tensor<4x7x4xf
     transform.yield
   }
  }
-
diff --git a/mlir/test/Dialect/Linalg/vectorization/insert-slice-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/insert-slice-with-patterns.mlir
index f7764be9be73f..d1f2ed194f6ce 100644
--- a/mlir/test/Dialect/Linalg/vectorization/insert-slice-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/insert-slice-with-patterns.mlir
@@ -67,10 +67,19 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:      %[[ARG_0:.*]]: tensor<1x?x3xf32>,
 // CHECK-SAME:      %[[PAD:.*]]: f32,
 // CHECK-SAME:      %[[SIZE:.*]]: index) -> tensor<9x8x7x1x2x3xf32> {
+// CHECK:           %[[C3:.*]] = arith.constant 3 : index
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[EMPTY:.*]] = tensor.empty() : tensor<9x8x7x1x2x3xf32>
 // CHECK:           %[[BC:.*]] = vector.broadcast %[[PAD]] : f32 to vector<9x8x7x1x2x3xf32>
 // CHECK:           %[[WRITE:.*]] = vector.transfer_write %[[BC]], %[[EMPTY]]{{.*}} {in_bounds = [true, true, true, true, true, true]} : vector<9x8x7...
[truncated]

@llvmbot
Member

llvmbot commented May 23, 2025

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)


@banach-space banach-space changed the title [[mlir][linalg] Refactor vectorization hooks to improve code reuse [mlir][linalg] Refactor vectorization hooks to improve code reuse May 23, 2025
@banach-space banach-space force-pushed the users/banach-space/vector/update_vectorize_insert_slice branch from eccff09 to f0922e9 Compare May 24, 2025 11:49
banach-space added a commit that referenced this pull request May 27, 2025
This patch removes `inputVecSizesForLeadingDims` from the parameter list
of `createWriteOrMaskedWrite`. That argument is unnecessary — vector sizes
can be obtained from the `vecToStore` parameter. Since this doesn't change
behavior or test results, it's marked as NFC.

Additional cleanups:
  * Renamed `vectorToStore` to `vecToStore` for consistency and brevity.
  * Rewrote a conditional at the end of the function to use early exit,
    improving readability:

```cpp
  // BEFORE:
  if (maskingRequired) {
    Value maskForWrite = ...;
    write = maskOperation(write, maskForWrite);
  }
  return write;

  // AFTER
  if (!maskingRequired)
    return write;

  Value maskForWrite = ...;
  return vector::maskOperation(builder, write, maskForWrite);
```

This change addresses a TODO from #141244.
banach-space added a commit that referenced this pull request May 29, 2025
This patch removes `inputVecSizesForLeadingDims` from the parameter list
of `createWriteOrMaskedWrite`. That argument is unnecessary — vector sizes
can be obtained from the `vecToStore` parameter. Since this doesn't change
behavior or test results, it's marked as NFC.

Additional cleanups:
  * Renamed `vectorToStore` to `vecToStore` for consistency and brevity.
  * Rewrote a conditional at the end of the function to use early exit,
    improving readability:

```cpp
  // BEFORE:
  if (maskingRequired) {
    Value maskForWrite = ...;
    write = maskOperation(write, maskForWrite);
  }
  return write;

  // AFTER
  if (!maskingRequired)
    return write;

  Value maskForWrite = ...;
  return vector::maskOperation(builder, write, maskForWrite);
```

This change addresses a TODO from #141244.
@banach-space banach-space force-pushed the users/banach-space/vector/update_vectorize_insert_slice branch from f0922e9 to 82cc2fe Compare May 29, 2025 16:31
Contributor

@Max191 Max191 left a comment


Overall looks good to me, but there are some cases that used to use in_bounds, and now use masking, which I'm not 100% sure about. Maybe it would be easier to split that part out into a second PR, so it can be reviewed separately by others with more context?

@banach-space
Contributor Author

Hey @Max191, thanks for taking a look and for your insightful comments - that's much appreciated 🙏🏻

I totally agree with your suggestions and have made the changes accordingly. As you will notice, all the changes in the following test files have been reverted:

  • insert-slice-with-patterns.mlir
  • pad-with-patterns.mlir.

That's expected - these are the test files in which we do not specify the vector sizes, hence there should be no masking.

Re future steps:

If it's unclear whether or not we should be doing this

Same. I intend to investigate this in the near future and might propose some changes. But I totally agree that this is quite nuanced and that we should approach this in small, incremental steps.

Contributor

@Max191 Max191 left a comment


Nice, this LGTM now! Thanks for addressing all the comments

This patch refactors two vectorization hooks in Vectorization.cpp:
 * `createWriteOrMaskedWrite` gains a new parameter for write indices,
   aligning it with its counterpart `createReadOrMaskedRead`.
 * `vectorizeAsInsertSliceOp` is updated to reuse both of the above
   hooks, rather than re-implementing similar logic.

CONTEXT
-------
This is effectively a refactoring of the logic for vectorizing
`tensor.insert_slice`. Recent updates added masking support:
  * #122927
  * #123031

At the time, reuse of the shared `create*` hooks wasn't feasible due to
missing parameters and overly rigid assumptions. This patch resolves
that and moves us closer to a more maintainable structure.

CHANGES IN `vectorizeAsInsertSliceOp`
-------------------------------------
* Introduces a clear distinction between the destination tensor and the
  vector to store, via named variables like `destType`/`vecToStoreType`,
  `destShape`/`vecToStoreShape`, etc.
* Ensures the correct rank and shape are used for attributes like
  in_bounds. For example, the size of the in_bounds array now matches
  the source vector rank, not the tensor rank.
* Drops the assumption that `vecToStoreRank == destRank` — this doesn't
  hold in many real examples.
*  Deduces mask dimensions from `vecToStoreShape` (vector) instead of
   `destShape` (tensor). (Eventually we should not require
   `inputVecSizesForLeadingDims` at all — mask shape should be inferred.)

NEW HELPER: `isMaskTriviallyFoldable`
-------------------------------------
Adds a utility to detect when masking is unnecessary. This avoids
inserting redundant masks and reduces the burden on canonicalization to
clean them up later.

Example where masking is provably unnecessary:
```mlir
%2 = vector.mask %1 {
  vector.transfer_write %0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0]
    {in_bounds = [true, true, true]}
    : vector<1x2x3xf32>, tensor<9x8x7x1x2x3xf32>
} : vector<1x2x3xi1> -> tensor<9x8x7x1x2x3xf32>
```

Also, without this hook, tests are more complicated and require more
matching.

TEST CHANGES
-----------
This patch primarily affects vectorization of:
  * `tensor.insert_slice`, now refactored to use shared hooks.

`tensor.pad` vectorization patterns, which internally use
`tensor.insert_slice`, are also _effectively_ updated. Note, only
pad-with-patterns.mlir is affected.

Most test updates involve the insertion of masks that were previously
missing — this reflects a correctness fix, not a regression. In all
cases, the added masks are indeed required.

You’ll also notice more repeated constants (`arith.constant 0 : index`),
due to increased use of helper hooks. This will be cleaned up separately
via a constant cache (see #138265 for discussion).

NOTE FOR REVIEWERS
------------------
This is a fairly substantial rewrite. You may find it easier to review
`createWriteOrMaskedWrite` as a new method rather than diffing
line-by-line.

TODOs (future PRs)
------------------
Further alignment of `createWriteOrMaskedWrite` and
`createReadOrMaskedRead`:
  * Move `createWriteOrMaskedWrite` next to `createReadOrMaskedRead` (in
    VectorUtils.cpp)
  * Make `createReadOrMaskedRead` leverage `isMaskTriviallyFoldable`.
  * Extend `isMaskTriviallyFoldable` with value-bounds-analysis. See the
    updated test in transform-vector.mlir for an example that would
    benefit from this.

(* This method will eventually be moved out of Vectorization.cpp, which isn't the right long-term home for it.)
* Restore the original behaviour in `vectorizeAsInsertSliceOp`, whereby
  the `in_bounds` attribute was used to identify potentially
  out-of-bounds accesses. Masks are only used when input vector sizes
  are specified.
* Revert the changes in insert-slice-with-patterns.mlir and
  pad-with-patterns.mlir, i.e. the tests in which we don't specify
  vector sizes.
* Other minor updates.
…code reuse

* Restore the changes in transform-e2e.mlir + transform-vector.mlir
* Updated in_bounds attribute calculation in `createWriteOrMaskedWrite`
  - otherwise transform-e2e.mlir goes into an infinite loop. I will create
    a repro and open a GitHub issue before landing this.
* The in_bounds attribute calculation is incorrect and I will create a
  GitHub ticket to fix it before merging this. See the comments in this
  patch.
@banach-space banach-space force-pushed the users/banach-space/vector/update_vectorize_insert_slice branch from 42b1783 to 373036e Compare May 30, 2025 10:17
@banach-space
Contributor Author

Update 30/5/25

Updated the summary and rebased. I’ve made some small tweaks so that I could revert the changes made to transform-vector.mlir and transform-e2e.mlir in my first commit. This way, only insert_slice.mlir is updated, making it clear that:

  • These changes only affect vectorizeAsInsertSliceOp.
  • The vectorizer’s behaviour is otherwise preserved.

As noted in the summary, I did identify some issues - these are tracked here:

I’ll be away next week, so I plan to wait until I’m back before landing this. That also gives other potential reviewers some time to take a look :) If there are no new comments by then, I’ll go ahead and merge.

Contributor

@dcaballe dcaballe left a comment


Nice!
