Commit 82cc2fe
[mlir][linalg] Refactor vectorization hooks to improve code reuse
This patch refactors two vectorization hooks in Vectorization.cpp:

 * `createWriteOrMaskedWrite` gains a new parameter for write indices,
   aligning it with its counterpart `createReadOrMaskedRead`.
 * `vectorizeAsInsertSliceOp` is updated to reuse both of the above hooks,
   rather than re-implementing similar logic.

CONTEXT
-------
This is effectively a refactoring of the logic for vectorizing
`tensor.insert_slice`. Recent updates added masking support:

 * #122927
 * #123031

At the time, reuse of the shared `create*` hooks wasn't feasible due to
missing parameters and overly rigid assumptions. This patch resolves that
and moves us closer to a more maintainable structure.

CHANGES IN `vectorizeAsInsertSliceOp`
-------------------------------------
* Introduces a clear distinction between the destination tensor and the
  vector to store, via named variables like `destType`/`vecToStoreType`,
  `destShape`/`vecToStoreShape`, etc.
* Ensures the correct rank and shape are used for attributes like
  `in_bounds`. For example, the size of the `in_bounds` array now matches
  the source vector rank, not the tensor rank.
* Drops the assumption that `vecToStoreRank == destRank` - this doesn't
  hold in many real examples.
* Deduces mask dimensions from `vecToStoreShape` (vector) instead of
  `destShape` (tensor). (Eventually we should not require
  `inputVecSizesForLeadingDims` at all - the mask shape should be
  inferred.)

NEW HELPER: `isMaskTriviallyFoldable`
-------------------------------------
Adds a utility to detect when masking is unnecessary. This avoids
inserting redundant masks and reduces the burden on canonicalization to
clean them up later.

Example where masking is provably unnecessary:

```mlir
%2 = vector.mask %1 {
  vector.transfer_write %0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0]
    {in_bounds = [true, true, true]}
    : vector<1x2x3xf32>, tensor<9x8x7x1x2x3xf32>
} : vector<1x2x3xi1> -> tensor<9x8x7x1x2x3xf32>
```

Also, without this hook, tests are more complicated and require more
matching.

TEST CHANGES
------------
This patch primarily affects vectorization of:

 * `tensor.insert_slice`, now refactored to use shared hooks.

`tensor.pad` vectorization patterns, which internally use
`tensor.insert_slice`, are also _effectively_ updated. Note that only
pad-with-patterns.mlir is affected.

Most test updates involve the insertion of masks that were previously
missing - this reflects a correctness fix, not a regression. In all cases,
the added masks are indeed required.

You’ll also notice more repeated constants (`arith.constant 0 : index`),
due to increased use of helper hooks. This will be cleaned up separately
via a constant cache (see #138265 for discussion).

NOTE FOR REVIEWERS
------------------
This is a fairly substantial rewrite. You may find it easier to review
`createWriteOrMaskedWrite` as a new method rather than diffing
line-by-line.

TODOs (future PRs)
------------------
Further alignment of `createWriteOrMaskedWrite` and
`createReadOrMaskedRead`:

 * Move `createWriteOrMaskedWrite` next to `createReadOrMaskedRead`
   (in VectorUtils.cpp). (*)
 * Make `createReadOrMaskedRead` leverage `isMaskTriviallyFoldable`.
 * Extend `isMaskTriviallyFoldable` with value-bounds analysis. See the
   updated test in transform-vector.mlir for an example that would
   benefit from this.

(*) This method will eventually be moved out of Vectorization.cpp, which
isn't the right long-term home for it.
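To make the test changes above concrete, the sketch below shows the kind of
mask that now guards a previously unmasked write. The shapes and the
`%vec`/`%dest` values are hypothetical, chosen only to illustrate the
pattern; they are not copied from the updated tests:

```mlir
// Hypothetical shapes: writing vector<8x1xf32> into tensor<?x1xf32>. The
// leading destination dim is dynamic, so the write may run out of bounds;
// the refactored hook now emits a mask derived from the tensor sizes.
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim = tensor.dim %dest, %c0 : tensor<?x1xf32>
%mask = vector.create_mask %dim, %c1 : vector<8x1xi1>
%res = vector.mask %mask {
  vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]}
      : vector<8x1xf32>, tensor<?x1xf32>
} : vector<8x1xi1> -> tensor<?x1xf32>
```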
1 parent 28eb66b, commit 82cc2fe

8 files changed: +299, -149 lines

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (182 additions, 92 deletions)
```diff
@@ -1506,20 +1506,120 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
   return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
 }
 
+/// Determines whether a mask for xfer_write is trivially "all true"
+///
+/// Given all the inputs required to generate a mask (mask sizes and shapes),
+/// and an xfer_write operation (write indices and the destination tensor
+/// shape), determines whether the corresponding mask would be trivially
+/// foldable (i.e., trivially "all true").
+///
+/// Use this method to avoid generating spurious masks and relying on
+/// vectorization post-processing to remove them.
+///
+/// Pre-conditions for a mask to be trivially foldable:
+///   * All involved shapes (mask + destination tensor) are static.
+///   * All write indices are constant.
+///   * All mask sizes are constant (including `arith.constant`).
+///
+/// If the pre-conditions are met, the method checks for each destination
+/// dimension `d`:
+///   (1) destDimSize[rankDiff + d] <= maskShape[d]
+///   (2) destDimSize[rankDiff + d] <= writeIndex[d] + maskSize[d]
+///
+/// rankDiff = rank(dest) - rank(mask).
+///
+/// This method takes a conservative view: it may return false even if the mask
+/// is technically foldable.
+///
+/// EXAMPLE 1 (trivially foldable, all shapes match, mask sizes match the shape
+/// of the dest tensor):
+///   %c0 = arith.constant 0 : index
+///   %mask = vector.create_mask 5, 1
+///   vector.mask %mask {
+///     vector.transfer_write %vecToStore_1, %dest[%c0, %c0]
+///       {in_bounds = [true, true]}
+///       : vector<5x1xi32>, tensor<5x1xi32>
+///   }
+///
+/// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape,
+/// mask is required to avoid out-of-bounds write):
+///   %c0 = arith.constant 0 : index
+///   %mask = vector.create_mask 5, 1
+///   vector.mask %mask {
+///     vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
+///       {in_bounds = [true, true]}
+///       : vector<8x1xi32>, tensor<5x1xi32>
+///   }
+///
+/// TODO: Re-use in createReadOrMaskedRead
+static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
+                                    SmallVector<Value> &writeIdxs,
+                                    ArrayRef<int64_t> destShape,
+                                    ArrayRef<int64_t> maskShape) {
+  // Masking is unavoidable in the case of dynamic tensors.
+  if (ShapedType::isDynamicShape(destShape))
+    return false;
+
+  // Collect all constant mask sizes.
+  SmallVector<int64_t, 4> cstMaskSizes;
+  for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
+    if (auto intSize = getConstantIntValue(dimSize)) {
+      cstMaskSizes.push_back(*intSize);
+    }
+  }
+
+  // If any of the mask sizes is non-constant, bail out.
+  if (cstMaskSizes.size() != maskShape.size())
+    return false;
+
+  // Collect all constant write indices.
+  SmallVector<int64_t, 4> cstWriteIdxs;
+  for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
+    APSInt intVal;
+    if (matchPattern(idx, m_ConstantInt(&intVal))) {
+      cstWriteIdxs.push_back(intVal.getSExtValue());
+    }
+  }
+
+  // If any of the write indices is non-constant, bail out.
+  if (cstWriteIdxs.size() != destShape.size())
+    return false;
+
+  // Go over all destination dims and check (1) and (2). Take into account that:
+  //  * The number of mask sizes will match the rank of the vector to store.
+  //    This could be lower than the rank of the destination tensor.
+  //  * Mask sizes could be larger than the corresponding mask shape (hence
+  //    `clamp`).
+  // TODO: The 2nd item should be rejected by the verifier.
+  int64_t rankDiff = destShape.size() - cstMaskSizes.size();
+  for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
+    if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] ||
+        /*(2)*/ destShape[rankDiff + i] <
+                    (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
+                     cstWriteIdxs[i]))
+      return false;
+  }
+
+  return true;
+}
+
 /// Creates an optionally masked TransferWriteOp
 ///
 /// Generates the following operation:
 ///   %res = vector.transfer_write %vectorToStore into %dest
 ///
-/// If the leading N dimensions of the destination tensor do not match
+/// If the leading N dimensions of the vector to store do not match
 /// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
 /// masking is applied to ensure correctness:
 ///
-///   %mask = vector.create_mask(%destShape)
+///   %mask = vector.create_mask(%destShape) : %vectorToStoreShape
 ///   %res = vector.mask %mask {
 ///     vector.transfer_write %vectorToStore into %dest
 ///   }
 ///
+/// The mask shape is identical to `vectorToStore` (with the element type ==
+/// i1), and the mask values are based on the shape of the `dest` tensor.
+///
 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
 /// is used instead of masking:
 ///
```
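For intuition, here is how the two documentation examples above play out in
the checks as implemented in the final loop (rankDiff = 0 in both cases).
The arithmetic below is our own walkthrough, not part of the patch:

```mlir
// EXAMPLE 1: maskShape = [5, 1], destShape = [5, 1],
//            maskSizes = (5, 1), writeIdxs = (0, 0).
//   Check (1): maskShape[d] <= destShape[d]            -> 5 <= 5, 1 <= 1
//   Check (2): destShape[d] >= maskSize[d] + writeIdx[d]
//                                                      -> 5 >= 5 + 0, 1 >= 1 + 0
//   Every lane is active, so the mask is trivially foldable and no
//   vector.create_mask / vector.mask is emitted.
//
// EXAMPLE 2: maskShape = [8, 1], destShape = [5, 1].
//   Check (1) fails for d = 0: maskShape[0] = 8 > destShape[0] = 5.
//   Not foldable: the mask is kept, disabling the last 3 rows of the
//   vector<8x1xi32> so that the write stays inside tensor<5x1xi32>.
```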
```diff
@@ -1528,75 +1628,99 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
 ///   %res = vector.transfer_write %input into %dest
 ///     {in_bounds = in_bounds_flags}
 ///
-/// NOTE: All write offsets are set to 0.
-/// TODO: Allow specifying write offsets.
-/// NOTE: When N < rank(input), the missing vector sizes are effectively
-/// extracted from the trailing sizes of `destSizes`. This means those sizes
-/// must be static.
-/// TODO: Support cases where an arbitrary dim is dynamic - this will require
-/// specifying all the vector sizes.
+/// `writeIndices` specifies the offsets to use. If empty, all indices are set
+/// to 0.
+///
+/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
+/// `valueToStore`.
+/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
+/// already provided in `vectorToStore`.
 static Operation *
 createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
                          Value dest,
                          ArrayRef<int64_t> inputVecSizesForLeadingDims,
+                         SmallVector<Value> writeIndices = {},
                          bool useInBoundsInsteadOfMasking = false) {
 
   ShapedType destType = cast<ShapedType>(dest.getType());
-  assert(cast<VectorType>(vectorToStore.getType()).getRank() ==
-             static_cast<int64_t>(destType.getRank()) &&
-         "Rank mismatch!");
-  (void)destType;
+  int64_t destRank = destType.getRank();
+  auto destShape = destType.getShape();
 
-  int64_t rank = cast<ShapedType>(dest.getType()).getRank();
-  auto destShape = cast<ShapedType>(dest.getType()).getShape();
+  VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
+  int64_t vecToStoreRank = vecToStoreType.getRank();
+  auto vecToStoreShape = vecToStoreType.getShape();
 
   // Compute the in_bounds attribute
-  SmallVector<bool> inBoundsVal(rank, true);
+  SmallVector<bool> inBoundsVal(vecToStoreRank, true);
   if (useInBoundsInsteadOfMasking) {
     // In this case, assume that all the required vector sizes have been
     // provided.
     assert(inputVecSizesForLeadingDims.size() ==
-               static_cast<size_t>(destType.getRank()) &&
+               static_cast<size_t>(vecToStoreType.getRank()) &&
           "Insufficient number of input vector sizes!");
     // Update the inBounds attribute.
-    for (unsigned i = 0; i < rank; i++)
+    for (unsigned i = 0; i < destRank; i++)
       inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
                        !ShapedType::isDynamic(destShape[i]);
   }
 
+  // If missing, initialize the write indices to 0.
+  assert(writeIndices.empty() ||
+         writeIndices.size() == static_cast<size_t>(destRank) &&
+             "Invalid number of write indices!");
+  if (writeIndices.empty()) {
+    auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+    writeIndices = SmallVector<Value>(destRank, zero);
+  }
+
   // Generate the xfer_write Op
-  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
-  Operation *write = builder.create<vector::TransferWriteOp>(
-      loc,
-      /*vector=*/vectorToStore,
-      /*source=*/dest,
-      /*indices=*/SmallVector<Value>(rank, zero),
-      /*inBounds=*/inBoundsVal);
-  assert(llvm::none_of(
-             destShape.drop_front(inputVecSizesForLeadingDims.size()),
-             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
-         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
+  Operation *write =
+      builder.create<vector::TransferWriteOp>(loc,
+                                              /*vector=*/vectorToStore,
+                                              /*source=*/dest,
+                                              /*indices=*/writeIndices,
+                                              /*inBounds=*/inBoundsVal);
 
   // If masking is disabled, exit.
   if (useInBoundsInsteadOfMasking)
     return write;
 
+  assert(llvm::none_of(
+             destShape.drop_front(inputVecSizesForLeadingDims.size()),
+             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
+         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
+
   // Check if masking is needed.
   bool needMaskForWrite =
       !llvm::equal(inputVecSizesForLeadingDims,
-                   destShape.take_front(inputVecSizesForLeadingDims.size()));
+                   destShape.take_front(destRank - vecToStoreRank +
+                                        inputVecSizesForLeadingDims.size()));
 
   // If masking is needed, generate the mask and mask the operation.
   if (needMaskForWrite) {
+    // Get the mask shape + type. Missing mask dimensions are taken from
+    // `vectorToStore`.
     SmallVector<int64_t> writeMaskShape;
     writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
                           inputVecSizesForLeadingDims.end());
-    writeMaskShape.append(destShape.begin() +
-                              inputVecSizesForLeadingDims.size(),
-                          destShape.end());
+    if (vecToStoreRank >
+        static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
+      writeMaskShape.append(vecToStoreShape.begin() +
+                                inputVecSizesForLeadingDims.size(),
+                            vecToStoreShape.end());
     auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-    Value maskForWrite = builder.create<vector::CreateMaskOp>(
-        loc, writeMaskType, tensor::getMixedSizes(builder, loc, dest));
+
+    SmallVector<OpFoldResult> destSizes =
+        tensor::getMixedSizes(builder, loc, dest);
+    SmallVector<OpFoldResult> maskSizes(destSizes.end() - writeMaskShape.size(),
+                                        destSizes.end());
+
+    if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+                                writeMaskShape))
+      return write;
+
+    Value maskForWrite = builder.createOrFold<vector::CreateMaskOp>(
+        loc, writeMaskType, maskSizes);
     write = mlir::vector::maskOperation(builder, write, maskForWrite);
   }
 
```
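One consequence of dropping the `vecToStoreRank == destRank` assumption is
visible in the generated IR: for a lower-rank vector stored into a
higher-rank tensor, the mask now has the vector's rank, with its values
taken from the trailing destination sizes. A hypothetical sketch (`%vec`
and `%dest` are assumed values, shapes invented for illustration):

```mlir
// Hypothetical: a 2-D vector stored into a 3-D tensor with a dynamic
// trailing dim. The mask is 2-D (vector rank), not 3-D (tensor rank), and
// its sizes are the two trailing destination sizes.
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c8 = arith.constant 8 : index
%dim = tensor.dim %dest, %c2 : tensor<4x8x?xf32>
%mask = vector.create_mask %c8, %dim : vector<8x16xi1>
%res = vector.mask %mask {
  vector.transfer_write %vec, %dest[%c0, %c0, %c0] {in_bounds = [true, true]}
      : vector<8x16xf32>, tensor<4x8x?xf32>
} : vector<8x16xi1> -> tensor<4x8x?xf32>
```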
```diff
@@ -1700,10 +1824,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0],
       transposeOp.getResult().getType().getElementType());
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
-                               /*inputVecSizesForLeadingDims=*/inputVectorSizes,
-                               /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, transposeOp.getResult(), dest,
+      /*inputVecSizesForLeadingDims=*/inputVectorSizes, /*writeIndices=*/{},
+      /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1839,10 +1963,10 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedRetShapes[0],
       shapeCastOp.getResult().getType().getElementType());
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(), dest,
-                               /*inputVecSizesForLeadingDims=*/writeVectorSizes,
-                               useInBoundsInsteadOfMasking);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, shapeCastOp.getResult(), dest,
+      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
+      /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1874,10 +1998,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   // Create Xfer write Op
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest,
-                               /*inputVecSizesForLeadingDims=*/inputVectorSizes,
-                               /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, maskedRead, dest,
+      /*inputVecSizesForLeadingDims=*/inputVectorSizes, {},
+      /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
```
```diff
@@ -2922,53 +3046,19 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   auto vecType = VectorType::get(vecShape, sourceType.getElementType());
 
   // 3. Generate TransferReadOp + TransferWriteOp
-  ReifiedRankedShapedTypeDims reifiedSrcSizes;
-  Value maskOp;
-
-  // If vector sizes are user provided, make sure to mask. First, generate the
-  // mask.
-  if (!inputVectorSizes.empty()) {
-    auto *srcDefOp = source.getDefiningOp();
-    if (!srcDefOp) {
-      LDBG("Unable to get the defining Op of " << sliceOp);
-      return failure();
-    }
-
-    LogicalResult status =
-        cast<ReifyRankedShapedTypeOpInterface>(srcDefOp).reifyResultShapes(
-            rewriter, reifiedSrcSizes);
-    if (status.failed()) {
-      LDBG("Unable to reify result shapes of " << srcDefOp);
-      return failure();
-    }
-
-    // Create the mask
-    auto readMaskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
-    maskOp = rewriter.create<vector::CreateMaskOp>(
-        sliceOp.getLoc(), readMaskType, reifiedSrcSizes[0]);
-  }
+  auto loc = sliceOp.getLoc();
 
+  // Create read
   SmallVector<Value> readIndices(
-      vecType.getRank(),
-      rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
-  Operation *read = rewriter.create<vector::TransferReadOp>(
-      sliceOp.getLoc(), vecType, source, readIndices, padValue,
-      ArrayRef<bool>{readInBounds});
-
-  if (maskOp) {
-    read = mlir::vector::maskOperation(rewriter, read, maskOp);
-  }
-
-  auto writeIndices = getValueOrCreateConstantIndexOp(
-      rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
-
-  Operation *write = rewriter.create<vector::TransferWriteOp>(
-      sliceOp.getLoc(), read->getResult(0), sliceOp.getDest(), writeIndices,
-      ArrayRef<bool>{writeInBounds});
-
-  if (maskOp) {
-    write = mlir::vector::maskOperation(rewriter, write, maskOp);
-  }
+      vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
+  Value read = mlir::vector::createReadOrMaskedRead(
+      rewriter, loc, source, vecType.getShape(), padValue);
+
+  // Create write
+  auto writeIndices =
+      getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices);
 
   // 4. Finalize
   newResults.push_back(write->getResult(0));
```
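After this rewrite, `vectorizeAsInsertSliceOp` reduces to one read hook plus
one write hook. The sketch below shows the rough shape of the output for a
hypothetical `tensor.insert_slice` with vector sizes [8, 16]; `%src`,
`%dest`, and `%pad` are assumed values, and the exact masks emitted depend
on which dims are static and on `isMaskTriviallyFoldable`:

```mlir
// Hypothetical input:
//   %r = tensor.insert_slice %src into %dest[2, 3] [%h, %w] [1, 1]
//       : tensor<?x?xf32> into tensor<?x32xf32>
// Schematic output: a read masked by the source sizes, then a write at the
// slice offsets masked by the (dynamic) destination sizes.
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c32 = arith.constant 32 : index
%s0 = tensor.dim %src, %c0 : tensor<?x?xf32>
%s1 = tensor.dim %src, %c1 : tensor<?x?xf32>
%rmask = vector.create_mask %s0, %s1 : vector<8x16xi1>
%read = vector.mask %rmask {
  vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]}
      : tensor<?x?xf32>, vector<8x16xf32>
} : vector<8x16xi1> -> vector<8x16xf32>
%d0 = tensor.dim %dest, %c0 : tensor<?x32xf32>
%wmask = vector.create_mask %d0, %c32 : vector<8x16xi1>
%res = vector.mask %wmask {
  vector.transfer_write %read, %dest[%c2, %c3] {in_bounds = [true, true]}
      : vector<8x16xf32>, tensor<?x32xf32>
} : vector<8x16xi1> -> tensor<?x32xf32>
```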

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (3 additions, 1 deletion)
```diff
@@ -337,13 +337,13 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
   auto sourceShape = sourceShapedType.getShape();
   assert(sourceShape.size() == inputVectorSizes.size() &&
          "expected same ranks.");
-  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
   auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
   assert(padValue.getType() == sourceShapedType.getElementType() &&
          "expected same pad element type to match source element type");
   int64_t readRank = inputVectorSizes.size();
   auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
   SmallVector<bool> inBoundsVal(readRank, true);
+
   if (useInBoundsInsteadOfMasking) {
     // Update the inBounds attribute.
     for (unsigned i = 0; i < readRank; i++)
@@ -362,6 +362,8 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
     return transferReadOp;
   SmallVector<OpFoldResult> mixedSourceDims =
       tensor::getMixedSizes(builder, loc, source);
+
+  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
   Value mask =
       builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
   return mlir::vector::maskOperation(builder, transferReadOp, mask)
```
