@@ -5594,6 +5594,29 @@ LogicalResult ShapeCastOp::verify() {
5594
5594
return success ();
5595
5595
}
5596
5596
5597
+ namespace {
5598
+
5599
+ // / Return true if `transpose` does not permute a pair of dimensions that are
5600
+ // / both not of size 1. By `order preserving` we mean that the flattened
5601
+ // / versions of the input and output vectors are (numerically) identical.
5602
+ // / In other words `transpose` is effectively a shape cast.
5603
+ bool isOrderPreserving (TransposeOp transpose) {
5604
+ ArrayRef<int64_t > permutation = transpose.getPermutation ();
5605
+ ArrayRef<int64_t > inShape = transpose.getSourceVectorType ().getShape ();
5606
+ int64_t current = 0 ;
5607
+ for (auto p : permutation) {
5608
+ if (inShape[p] != 1 ) {
5609
+ if (p < current) {
5610
+ return false ;
5611
+ }
5612
+ current = p;
5613
+ }
5614
+ }
5615
+ return true ;
5616
+ }
5617
+
5618
+ } // namespace
5619
+
5597
5620
OpFoldResult ShapeCastOp::fold (FoldAdaptor adaptor) {
5598
5621
5599
5622
// No-op shape cast.
@@ -5602,13 +5625,15 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
5602
5625
5603
5626
VectorType resultType = getType ();
5604
5627
5605
- // Canceling shape casts.
5606
- if (auto otherOp = getSource ().getDefiningOp <ShapeCastOp>()) {
5607
-
5608
- // Only allows valid transitive folding (expand/collapse dimensions).
5609
- VectorType srcType = otherOp.getSource ().getType ();
5628
+ // shape_cast(something(x)) -> x, or
5629
+ // -> shape_cast(x).
5630
+ //
5631
+ // Confirms that a new shape_cast will have valid semantics (expands OR
5632
+ // collapses dimensions).
5633
+ auto maybeFold = [&](TypedValue<VectorType> source) -> OpFoldResult {
5634
+ VectorType srcType = source.getType ();
5610
5635
if (resultType == srcType)
5611
- return otherOp. getSource () ;
5636
+ return source ;
5612
5637
if (srcType.getRank () < resultType.getRank ()) {
5613
5638
if (!isValidShapeCast (srcType.getShape (), resultType.getShape ()))
5614
5639
return {};
@@ -5618,8 +5643,25 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
5618
5643
} else {
5619
5644
return {};
5620
5645
}
5621
- setOperand (otherOp. getSource () );
5646
+ setOperand (source );
5622
5647
return getResult ();
5648
+ };
5649
+
5650
+ // Canceling shape casts.
5651
+ if (auto otherOp = getSource ().getDefiningOp <ShapeCastOp>()) {
5652
+ TypedValue<VectorType> source = otherOp.getSource ();
5653
+ return maybeFold (source);
5654
+ }
5655
+
5656
+ // shape_cast(transpose(x)) -> shape_cast(x)
5657
+ if (auto transpose = getSource ().getDefiningOp <TransposeOp>()) {
5658
+ if (transpose.getType ().isScalable ())
5659
+ return {};
5660
+ if (isOrderPreserving (transpose)) {
5661
+ TypedValue<VectorType> source = transpose.getVector ();
5662
+ return maybeFold (source);
5663
+ }
5664
+ return {};
5623
5665
}
5624
5666
5625
5667
// Cancelling broadcast and shape cast ops.
@@ -5646,7 +5688,7 @@ namespace {
5646
5688
// / Helper function that computes a new vector type based on the input vector
5647
5689
// / type by removing the trailing one dims:
5648
5690
// /
5649
- // / vector<4x1x1xi1> --> vector<4x1 >
5691
+ // / vector<4x1x1xi1> --> vector<4x1xi1 >
5650
5692
// /
5651
5693
static VectorType trimTrailingOneDims (VectorType oldType) {
5652
5694
ArrayRef<int64_t > oldShape = oldType.getShape ();
@@ -6113,6 +6155,34 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
6113
6155
}
6114
6156
};
6115
6157
6158
+ // / Folds transpose(shape_cast) into a new shape_cast.
6159
+ class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
6160
+ public:
6161
+ using OpRewritePattern::OpRewritePattern;
6162
+
6163
+ LogicalResult matchAndRewrite (TransposeOp transposeOp,
6164
+ PatternRewriter &rewriter) const override {
6165
+ auto shapeCastOp =
6166
+ transposeOp.getVector ().getDefiningOp <vector::ShapeCastOp>();
6167
+ if (!shapeCastOp)
6168
+ return failure ();
6169
+ if (!isOrderPreserving (transposeOp))
6170
+ return failure ();
6171
+ if (transposeOp.getType ().isScalable ())
6172
+ return failure ();
6173
+
6174
+ VectorType resultType = transposeOp.getType ();
6175
+
6176
+ // We don't need to check isValidShapeCast at this point, because it is
6177
+ // guaranteed that merging the transpose into the the shape_cast is a valid
6178
+ // shape_cast, because the transpose just inserts/removes ones.
6179
+
6180
+ rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(transposeOp, resultType,
6181
+ shapeCastOp.getSource ());
6182
+ return success ();
6183
+ }
6184
+ };
6185
+
6116
6186
// / Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
6117
6187
// / 'order preserving', where 'order preserving' means the flattened
6118
6188
// / inputs and outputs of the transpose have identical (numerical) values.
@@ -6211,8 +6281,8 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
6211
6281
6212
6282
/// Adds the canonicalization patterns of vector::TransposeOp to `results`,
/// constructing each pattern with `context`. The added (`+`) lines below
/// register FoldTransposeShapeCast alongside FoldTransposeCreateMask,
/// TransposeFolder, FoldTransposeSplat and FoldTransposeBroadcast; the removed
/// (`-`) lines are the previous registration without FoldTransposeShapeCast.
void vector::TransposeOp::getCanonicalizationPatterns (
6213
6283
RewritePatternSet &results, MLIRContext *context) {
6214
- results.add <FoldTransposeCreateMask, TransposeFolder, FoldTransposeSplat ,
6215
- FoldTransposeBroadcast>(context);
6284
+ results.add <FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder ,
6285
+ FoldTransposeSplat, FoldTransposeBroadcast>(context);
6216
6286
}
6217
6287
6218
6288
// ===----------------------------------------------------------------------===//
0 commit comments