@@ -1506,20 +1506,120 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
   return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
 }
 
+/// Determines whether a mask for xfer_write is trivially "all true".
+///
+/// Given all the inputs required to generate a mask (mask sizes and shapes),
+/// and an xfer_write operation (write indices and the destination tensor
+/// shape), determines whether the corresponding mask would be trivially
+/// foldable (i.e., trivially "all true").
+///
+/// Use this method to avoid generating spurious masks and relying on
+/// vectorization post-processing to remove them.
+///
+/// Pre-conditions for a mask to be trivially foldable:
+///   * All involved shapes (mask + destination tensor) are static.
+///   * All write indices are constant.
+///   * All mask sizes are constant (including `arith.constant`).
+///
+/// If the pre-conditions are met, the mask is trivially foldable when, for
+/// each destination dimension `d`:
+///   (1) maskShape[d] <= destDimSize[rankDiff + d]
+///   (2) writeIndex[d] + maskSize[d] <= destDimSize[rankDiff + d]
+///
+/// where rankDiff = rank(dest) - rank(mask).
+///
+/// This method takes a conservative view: it may return false even if the
+/// mask is technically foldable.
+///
+/// EXAMPLE 1 (trivially foldable, all shapes match, mask sizes match the
+/// shape of the dest tensor):
+///   %c0 = arith.constant 0 : index
+///   %mask = vector.create_mask 5, 1
+///   vector.mask %mask {
+///     vector.transfer_write %vecToStore_1, %dest[%c0, %c0]
+///       {in_bounds = [true, true]}
+///       : vector<5x1xi32>, tensor<5x1xi32>
+///   }
+///
+/// EXAMPLE 2 (not trivially foldable - the vector shape exceeds the tensor
+/// shape, so a mask is required to avoid an out-of-bounds write):
+///   %c0 = arith.constant 0 : index
+///   %mask = vector.create_mask 5, 1
+///   vector.mask %mask {
+///     vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
+///       {in_bounds = [true, true]}
+///       : vector<8x1xi32>, tensor<5x1xi32>
+///   }
+///
+/// TODO: Re-use in createReadOrMaskedRead
+static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
+                                    SmallVector<Value> &writeIdxs,
+                                    ArrayRef<int64_t> destShape,
+                                    ArrayRef<int64_t> maskShape) {
+  // Masking is unavoidable in the case of dynamic tensors.
+  if (ShapedType::isDynamicShape(destShape))
+    return false;
+
+  // Collect all constant mask sizes.
+  SmallVector<int64_t, 4> cstMaskSizes;
+  for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
+    if (auto intSize = getConstantIntValue(dimSize)) {
+      cstMaskSizes.push_back(*intSize);
+    }
+  }
+
+  // If any of the mask sizes is non-constant, bail out.
+  if (cstMaskSizes.size() != maskShape.size())
+    return false;
+
+  // Collect all constant write indices.
+  SmallVector<int64_t, 4> cstWriteIdxs;
+  for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
+    APSInt intVal;
+    if (matchPattern(idx, m_ConstantInt(&intVal))) {
+      cstWriteIdxs.push_back(intVal.getSExtValue());
+    }
+  }
+
+  // If any of the write indices is non-constant, bail out.
+  if (cstWriteIdxs.size() != destShape.size())
+    return false;
+
+  // Go over all destination dims and check (1) and (2). Take into account that:
+  //  * The number of mask sizes will match the rank of the vector to store.
+  //    This could be lower than the rank of the destination tensor.
+  //  * Mask sizes could be larger than the corresponding mask shape (hence
+  //    `clamp`).
+  // TODO: The 2nd item should be rejected by the verifier.
+  int64_t rankDiff = destShape.size() - cstMaskSizes.size();
+  for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
+    if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] ||
+        /*(2)*/ destShape[rankDiff + i] <
+                    (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
+                     cstWriteIdxs[i]))
+      return false;
+  }
+
+  return true;
+}
+
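
For readers following checks (1) and (2) above, here is a minimal standalone sketch (not part of the patch; plain integers stand in for the MLIR values, and the helper/`main` names are illustrative only) that mirrors the per-dimension logic using the shapes from EXAMPLE 1 and EXAMPLE 2:

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Mirrors isMaskTriviallyFoldable's per-dimension checks, assuming all
    // shapes, mask sizes and write indices are already known constants.
    static bool triviallyFoldable(const std::vector<int64_t> &maskSizes,
                                  const std::vector<int64_t> &writeIdxs,
                                  const std::vector<int64_t> &destShape,
                                  const std::vector<int64_t> &maskShape) {
      int64_t rankDiff = destShape.size() - maskSizes.size();
      for (size_t i = 0; i < maskSizes.size(); ++i) {
        // (1) The mask (== vector-to-store) shape must fit the destination dim.
        if (maskShape[i] > destShape[rankDiff + i])
          return false;
        // (2) The write must stay in bounds: writeIdx + maskSize <= destDim.
        if (destShape[rankDiff + i] <
            std::clamp(maskSizes[i], int64_t(0), maskShape[i]) + writeIdxs[i])
          return false;
      }
      return true;
    }

    int main() {
      // EXAMPLE 1: vector<5x1> into tensor<5x1> at [0, 0] -> foldable.
      std::cout << triviallyFoldable({5, 1}, {0, 0}, {5, 1}, {5, 1}) << "\n"; // 1
      // EXAMPLE 2: vector<8x1> into tensor<5x1> at [0, 0] -> mask required.
      std::cout << triviallyFoldable({5, 1}, {0, 0}, {5, 1}, {8, 1}) << "\n"; // 0
    }

With the 8x1 vector of EXAMPLE 2, check (1) fails in dimension 0 (8 > 5), so the mask must be kept.
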
 /// Creates an optionally masked TransferWriteOp
 ///
 /// Generates the following operation:
 ///   %res = vector.transfer_write %vectorToStore into %dest
 ///
-/// If the leading N dimensions of the destination tensor do not match
+/// If the leading N dimensions of the vector to store do not match
 /// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
 /// masking is applied to ensure correctness:
 ///
-///   %mask = vector.create_mask(%destShape)
+///   %mask = vector.create_mask(%destShape) : %vectorToStoreShape
 ///   %res = vector.mask %mask {
 ///     vector.transfer_write %vectorToStore into %dest
 ///   }
 ///
+/// The mask shape is identical to `vectorToStore` (with the element type ==
+/// i1), and the mask values are based on the shape of the `dest` tensor.
+///
 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
 /// is used instead of masking:
 ///
@@ -1528,75 +1628,99 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
 ///   %res = vector.transfer_write %input into %dest
 ///            {in_bounds = in_bounds_flags}
 ///
-/// NOTE: All write offsets are set to 0.
-/// TODO: Allow specyfying write offsets.
-/// NOTE: When N < rank(input), the missing vector sizes are effectively
-/// extracted from the trailing sizes of `destSizes`. This means those sizes
-/// must be static.
-/// TODO: Support cases where an arbitrary dim is dynamic - this will require
-/// specifying all the vector sizes.
+/// `writeIndices` specifies the offsets to use. If empty, all indices are set
+/// to 0.
+///
+/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
+/// `vectorToStore`.
+/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes
+/// are already provided in `vectorToStore`.
 static Operation *
 createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
                          Value dest,
                          ArrayRef<int64_t> inputVecSizesForLeadingDims,
+                         SmallVector<Value> writeIndices = {},
                          bool useInBoundsInsteadOfMasking = false) {
 
   ShapedType destType = cast<ShapedType>(dest.getType());
-  assert(cast<VectorType>(vectorToStore.getType()).getRank() ==
-             static_cast<int64_t>(destType.getRank()) &&
-         "Rank mismatch!");
-  (void)destType;
+  int64_t destRank = destType.getRank();
+  auto destShape = destType.getShape();
 
-  int64_t rank = cast<ShapedType>(dest.getType()).getRank();
-  auto destShape = cast<ShapedType>(dest.getType()).getShape();
+  VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
+  int64_t vecToStoreRank = vecToStoreType.getRank();
+  auto vecToStoreShape = vecToStoreType.getShape();
 
   // Compute the in_bounds attribute
-  SmallVector<bool> inBoundsVal(rank, true);
+  SmallVector<bool> inBoundsVal(vecToStoreRank, true);
   if (useInBoundsInsteadOfMasking) {
     // In this case, assume that all the required vector sizes have been
     // provided.
     assert(inputVecSizesForLeadingDims.size() ==
-               static_cast<size_t>(destType.getRank()) &&
+               static_cast<size_t>(vecToStoreType.getRank()) &&
            "Insufficient number of input vector sizes!");
     // Update the inBounds attribute.
-    for (unsigned i = 0; i < rank; i++)
+    for (unsigned i = 0; i < destRank; i++)
       inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
                        !ShapedType::isDynamic(destShape[i]);
   }
 
+  // If missing, initialize the write indices to 0.
+  assert((writeIndices.empty() ||
+          writeIndices.size() == static_cast<size_t>(destRank)) &&
+         "Invalid number of write indices!");
+  if (writeIndices.empty()) {
+    auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+    writeIndices = SmallVector<Value>(destRank, zero);
+  }
+
   // Generate the xfer_write Op
-  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
-  Operation *write = builder.create<vector::TransferWriteOp>(
-      loc,
-      /*vector=*/vectorToStore,
-      /*source=*/dest,
-      /*indices=*/SmallVector<Value>(rank, zero),
-      /*inBounds=*/inBoundsVal);
-  assert(llvm::none_of(
-             destShape.drop_front(inputVecSizesForLeadingDims.size()),
-             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
-         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
+  Operation *write =
+      builder.create<vector::TransferWriteOp>(loc,
+                                              /*vector=*/vectorToStore,
+                                              /*source=*/dest,
+                                              /*indices=*/writeIndices,
+                                              /*inBounds=*/inBoundsVal);
 
   // If masking is disabled, exit.
   if (useInBoundsInsteadOfMasking)
     return write;
 
+  assert(llvm::none_of(
+             destShape.drop_front(inputVecSizesForLeadingDims.size()),
+             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
+         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
+
   // Check if masking is needed.
   bool needMaskForWrite =
       !llvm::equal(inputVecSizesForLeadingDims,
-                   destShape.take_front(inputVecSizesForLeadingDims.size()));
+                   destShape.take_front(destRank - vecToStoreRank +
+                                        inputVecSizesForLeadingDims.size()));
 
   // If masking is needed, generate the mask and mask the operation.
   if (needMaskForWrite) {
+    // Get the mask shape + type. Missing mask dimensions are taken from
+    // `vectorToStore`.
     SmallVector<int64_t> writeMaskShape;
     writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
                           inputVecSizesForLeadingDims.end());
-    writeMaskShape.append(destShape.begin() +
-                              inputVecSizesForLeadingDims.size(),
-                          destShape.end());
+    if (vecToStoreRank >
+        static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
+      writeMaskShape.append(vecToStoreShape.begin() +
+                                inputVecSizesForLeadingDims.size(),
+                            vecToStoreShape.end());
     auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-    Value maskForWrite = builder.create<vector::CreateMaskOp>(
-        loc, writeMaskType, tensor::getMixedSizes(builder, loc, dest));
+
+    SmallVector<OpFoldResult> destSizes =
+        tensor::getMixedSizes(builder, loc, dest);
+    SmallVector<OpFoldResult> maskSizes(destSizes.end() - writeMaskShape.size(),
+                                        destSizes.end());
+
+    if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+                                writeMaskShape))
+      return write;
+
+    Value maskForWrite = builder.createOrFold<vector::CreateMaskOp>(
+        loc, writeMaskType, maskSizes);
     write = mlir::vector::maskOperation(builder, write, maskForWrite);
   }
 
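
As a side note, the interplay between `inputVecSizesForLeadingDims`, the vector-to-store shape and the destination sizes in the masking branch above can be illustrated with a small standalone sketch (hypothetical rank-3 shapes, plain integers instead of MLIR types; not part of the patch):

    #include <cassert>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      // Hypothetical shapes: a rank-3 vector stored into a rank-3 tensor, with
      // only the two leading vector sizes provided by the caller.
      std::vector<int64_t> inputVecSizesForLeadingDims = {4, 8};
      std::vector<int64_t> vecToStoreShape = {4, 8, 16};
      std::vector<int64_t> destShape = {3, 8, 16};

      // Mask shape: leading dims from the provided sizes, trailing dims from
      // the vector to store (mirrors the writeMaskShape computation above).
      std::vector<int64_t> writeMaskShape(inputVecSizesForLeadingDims);
      if (vecToStoreShape.size() > inputVecSizesForLeadingDims.size())
        writeMaskShape.insert(writeMaskShape.end(),
                              vecToStoreShape.begin() +
                                  inputVecSizesForLeadingDims.size(),
                              vecToStoreShape.end());
      assert(writeMaskShape.size() == vecToStoreShape.size());

      // Mask sizes: the trailing destination sizes, one per mask dimension
      // (mirrors the maskSizes slice taken from destSizes above).
      std::vector<int64_t> maskSizes(destShape.end() - writeMaskShape.size(),
                                     destShape.end());

      for (size_t i = 0; i < writeMaskShape.size(); ++i)
        std::cout << "dim " << i << ": maskShape=" << writeMaskShape[i]
                  << " maskSize=" << maskSizes[i] << "\n";
      // Prints: dim 0: maskShape=4 maskSize=3
      //         dim 1: maskShape=8 maskSize=8
      //         dim 2: maskShape=16 maskSize=16
    }

Here the leading destination size (3) is smaller than the corresponding mask dimension (4), so `needMaskForWrite` is true and the generated `vector.create_mask` cannot be folded away.
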
@@ -1700,10 +1824,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0],
       transposeOp.getResult().getType().getElementType());
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
-                               /*inputVecSizesForLeadingDims=*/inputVectorSizes,
-                               /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, transposeOp.getResult(), dest,
+      /*inputVecSizesForLeadingDims=*/inputVectorSizes, /*writeIndices=*/{},
+      /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1839,10 +1963,10 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedRetShapes[0],
       shapeCastOp.getResult().getType().getElementType());
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(), dest,
-                               /*inputVecSizesForLeadingDims=*/writeVectorSizes,
-                               useInBoundsInsteadOfMasking);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, shapeCastOp.getResult(), dest,
+      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
+      /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1874,10 +1998,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   // Create Xfer write Op
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest,
-                               /*inputVecSizesForLeadingDims=*/inputVectorSizes,
-                               /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, maskedRead, dest,
+      /*inputVecSizesForLeadingDims=*/inputVectorSizes, /*writeIndices=*/{},
+      /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -2922,53 +3046,19 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   auto vecType = VectorType::get(vecShape, sourceType.getElementType());
 
   // 3. Generate TransferReadOp + TransferWriteOp
-  ReifiedRankedShapedTypeDims reifiedSrcSizes;
-  Value maskOp;
-
-  // If vector sizes are user provided, make sure to mask. First, generate the
-  // mask.
-  if (!inputVectorSizes.empty()) {
-    auto *srcDefOp = source.getDefiningOp();
-    if (!srcDefOp) {
-      LDBG("Unable to get the defining Op of " << sliceOp);
-      return failure();
-    }
-
-    LogicalResult status =
-        cast<ReifyRankedShapedTypeOpInterface>(srcDefOp).reifyResultShapes(
-            rewriter, reifiedSrcSizes);
-    if (status.failed()) {
-      LDBG("Unable to reify result shapes of " << srcDefOp);
-      return failure();
-    }
-
-    // Create the mask
-    auto readMaskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
-    maskOp = rewriter.create<vector::CreateMaskOp>(
-        sliceOp.getLoc(), readMaskType, reifiedSrcSizes[0]);
-  }
+  auto loc = sliceOp.getLoc();
 
+  // Create read
   SmallVector<Value> readIndices(
-      vecType.getRank(),
-      rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
-  Operation *read = rewriter.create<vector::TransferReadOp>(
-      sliceOp.getLoc(), vecType, source, readIndices, padValue,
-      ArrayRef<bool>{readInBounds});
-
-  if (maskOp) {
-    read = mlir::vector::maskOperation(rewriter, read, maskOp);
-  }
-
-  auto writeIndices = getValueOrCreateConstantIndexOp(
-      rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
-
-  Operation *write = rewriter.create<vector::TransferWriteOp>(
-      sliceOp.getLoc(), read->getResult(0), sliceOp.getDest(), writeIndices,
-      ArrayRef<bool>{writeInBounds});
-
-  if (maskOp) {
-    write = mlir::vector::maskOperation(rewriter, write, maskOp);
-  }
+      vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
+  Value read = mlir::vector::createReadOrMaskedRead(
+      rewriter, loc, source, vecType.getShape(), padValue);
+
+  // Create write
+  auto writeIndices =
+      getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices);
 
   // 4. Finalize
   newResults.push_back(write->getResult(0));
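
Since the insert_slice path now forwards `sliceOp.getMixedOffsets()` as the write indices, the new foldability check can still drop the write mask when the slice is statically in bounds. A tiny standalone illustration (hypothetical shapes and offsets, not taken from the patch):

    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    int main() {
      // Hypothetical insert_slice: a vector<3x1> written into a tensor<5x1xi32>
      // destination at offsets [2, 0] (the offsets become the write indices).
      int64_t destShape[] = {5, 1};
      int64_t maskShape[] = {3, 1}; // == shape of the vector being stored
      int64_t maskSizes[] = {5, 1}; // trailing destination sizes
      int64_t writeIdxs[] = {2, 0}; // from sliceOp.getMixedOffsets()

      bool foldable = true;
      for (int i = 0; i < 2; ++i) {
        // Check (1): the vector/mask shape fits the destination dim.
        bool fits = maskShape[i] <= destShape[i];
        // Check (2): offset + (clamped) mask size stays within the destination.
        bool inBounds = writeIdxs[i] +
                            std::clamp(maskSizes[i], int64_t(0), maskShape[i]) <=
                        destShape[i];
        foldable = foldable && fits && inBounds;
      }
      std::cout << foldable << "\n"; // 1: the write mask can be dropped
    }
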