Skip to content

Commit f4ae206

Browse files
committed
add transpose(shape_cast) and shape_cast(transpose) folders, with tests
1 parent f0a59c4 commit f4ae206

File tree

3 files changed

+145
-11
lines changed

3 files changed

+145
-11
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 80 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5594,6 +5594,29 @@ LogicalResult ShapeCastOp::verify() {
55945594
return success();
55955595
}
55965596

5597+
namespace {
5598+
5599+
/// Return true if `transpose` does not permute a pair of dimensions that are
5600+
/// both not of size 1. By `order preserving` we mean that the flattened
5601+
/// versions of the input and output vectors are (numerically) identical.
5602+
/// In other words `transpose` is effectively a shape cast.
5603+
bool isOrderPreserving(TransposeOp transpose) {
5604+
ArrayRef<int64_t> permutation = transpose.getPermutation();
5605+
ArrayRef<int64_t> inShape = transpose.getSourceVectorType().getShape();
5606+
int64_t current = 0;
5607+
for (auto p : permutation) {
5608+
if (inShape[p] != 1) {
5609+
if (p < current) {
5610+
return false;
5611+
}
5612+
current = p;
5613+
}
5614+
}
5615+
return true;
5616+
}
5617+
5618+
} // namespace
5619+
55975620
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
55985621

55995622
// No-op shape cast.
@@ -5602,13 +5625,15 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56025625

56035626
VectorType resultType = getType();
56045627

5605-
// Canceling shape casts.
5606-
if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
5607-
5608-
// Only allows valid transitive folding (expand/collapse dimensions).
5609-
VectorType srcType = otherOp.getSource().getType();
5628+
// shape_cast(something(x)) -> x, or
5629+
// -> shape_cast(x).
5630+
//
5631+
// Confirms that a new shape_cast will have valid semantics (expands OR
5632+
// collapses dimensions).
5633+
auto maybeFold = [&](TypedValue<VectorType> source) -> OpFoldResult {
5634+
VectorType srcType = source.getType();
56105635
if (resultType == srcType)
5611-
return otherOp.getSource();
5636+
return source;
56125637
if (srcType.getRank() < resultType.getRank()) {
56135638
if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
56145639
return {};
@@ -5618,8 +5643,25 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56185643
} else {
56195644
return {};
56205645
}
5621-
setOperand(otherOp.getSource());
5646+
setOperand(source);
56225647
return getResult();
5648+
};
5649+
5650+
// Canceling shape casts.
5651+
if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
5652+
TypedValue<VectorType> source = otherOp.getSource();
5653+
return maybeFold(source);
5654+
}
5655+
5656+
// shape_cast(transpose(x)) -> shape_cast(x)
5657+
if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
5658+
if (transpose.getType().isScalable())
5659+
return {};
5660+
if (isOrderPreserving(transpose)) {
5661+
TypedValue<VectorType> source = transpose.getVector();
5662+
return maybeFold(source);
5663+
}
5664+
return {};
56235665
}
56245666

56255667
// Cancelling broadcast and shape cast ops.
@@ -5646,7 +5688,7 @@ namespace {
56465688
/// Helper function that computes a new vector type based on the input vector
56475689
/// type by removing the trailing one dims:
56485690
///
5649-
/// vector<4x1x1xi1> --> vector<4x1>
5691+
/// vector<4x1x1xi1> --> vector<4x1xi1>
56505692
///
56515693
static VectorType trimTrailingOneDims(VectorType oldType) {
56525694
ArrayRef<int64_t> oldShape = oldType.getShape();
@@ -6113,6 +6155,34 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
61136155
}
61146156
};
61156157

6158+
/// Folds transpose(shape_cast) into a new shape_cast.
6159+
class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
6160+
public:
6161+
using OpRewritePattern::OpRewritePattern;
6162+
6163+
LogicalResult matchAndRewrite(TransposeOp transposeOp,
6164+
PatternRewriter &rewriter) const override {
6165+
auto shapeCastOp =
6166+
transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
6167+
if (!shapeCastOp)
6168+
return failure();
6169+
if (!isOrderPreserving(transposeOp))
6170+
return failure();
6171+
if (transposeOp.getType().isScalable())
6172+
return failure();
6173+
6174+
VectorType resultType = transposeOp.getType();
6175+
6176+
// We don't need to check isValidShapeCast at this point, because it is
6177+
// guaranteed that merging the transpose into the the shape_cast is a valid
6178+
// shape_cast, because the transpose just inserts/removes ones.
6179+
6180+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
6181+
shapeCastOp.getSource());
6182+
return success();
6183+
}
6184+
};
6185+
61166186
/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
61176187
/// 'order preserving', where 'order preserving' means the flattened
61186188
/// inputs and outputs of the transpose have identical (numerical) values.
@@ -6211,8 +6281,8 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
62116281

62126282
void vector::TransposeOp::getCanonicalizationPatterns(
62136283
RewritePatternSet &results, MLIRContext *context) {
6214-
results.add<FoldTransposeCreateMask, TransposeFolder, FoldTransposeSplat,
6215-
FoldTransposeBroadcast>(context);
6284+
results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
6285+
FoldTransposeSplat, FoldTransposeBroadcast>(context);
62166286
}
62176287

62186288
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ func.func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
88
%0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
99
return %0 : vector<4x3xi1>
1010
}
11+
1112
// -----
1213

1314
// CHECK-LABEL: create_scalable_vector_mask_to_constant_mask
@@ -3035,7 +3036,6 @@ func.func @insert_vector_poison(%a: vector<4x8xf32>)
30353036
return %1 : vector<4x8xf32>
30363037
}
30373038

3038-
30393039
// -----
30403040

30413041
// CHECK-LABEL: @insert_scalar_poison_idx

mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,67 @@ func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<
137137
return %1 : vector<3x3x3xi8>
138138
}
139139

140+
141+
// -----
142+
143+
// In this test, the permutation maps the non-one dimensions (1 and 2) as follows:
144+
// 1 -> 0
145+
// 2 -> 4
146+
// Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
147+
// CHECK-LABEL: @transpose_shape_cast
148+
// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1x1xi8>) -> vector<4x4xi8> {
149+
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
150+
// CHECK-SAME: vector<1x4x4x1x1xi8> to vector<4x4xi8>
151+
// CHECK: return %[[SHAPE_CAST]] : vector<4x4xi8>
152+
func.func @transpose_shape_cast(%arg : vector<1x4x4x1x1xi8>) -> vector<4x4xi8> {
153+
%0 = vector.transpose %arg, [1, 0, 3, 4, 2]
154+
: vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8>
155+
%1 = vector.shape_cast %0 : vector<4x1x1x1x4xi8> to vector<4x4xi8>
156+
return %1 : vector<4x4xi8>
157+
}
158+
159+
// -----
160+
161+
// In this test, the mapping of non-one indices (1 and 2) is as follows:
162+
// 1 -> 2
163+
// 2 -> 1
164+
// As this is not increasing (2 > 1), this transpose is not order
165+
// preserving and cannot be treated as a shape_cast.
166+
// CHECK-LABEL: @negative_transpose_shape_cast
167+
// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1xi8>) -> vector<4x4xi8> {
168+
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG]]
169+
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[TRANSPOSE]]
170+
// CHECK: return %[[SHAPE_CAST]] : vector<4x4xi8>
171+
func.func @negative_transpose_shape_cast(%arg : vector<1x4x4x1xi8>) -> vector<4x4xi8> {
172+
%0 = vector.transpose %arg, [0, 2, 1, 3]
173+
: vector<1x4x4x1xi8> to vector<1x4x4x1xi8>
174+
%1 = vector.shape_cast %0 : vector<1x4x4x1xi8> to vector<4x4xi8>
175+
return %1 : vector<4x4xi8>
176+
}
177+
178+
// -----
179+
180+
// CHECK-LABEL: @shape_cast_transpose
181+
// CHECK-SAME: %[[ARG:.*]]: vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
182+
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
183+
// CHECK-SAME: vector<2x3x1x1xi8> to vector<6x1x1xi8>
184+
// CHECK: return %[[SHAPE_CAST]] : vector<6x1x1xi8>
185+
func.func @shape_cast_transpose(%arg : vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
186+
%0 = vector.shape_cast %arg : vector<2x3x1x1xi8> to vector<6x1x1xi8>
187+
%1 = vector.transpose %0, [0, 2, 1]
188+
: vector<6x1x1xi8> to vector<6x1x1xi8>
189+
return %1 : vector<6x1x1xi8>
190+
}
191+
192+
// -----
193+
194+
// CHECK-LABEL: @negative_shape_cast_transpose
195+
// CHECK-SAME: %[[ARG:.*]]: vector<6xi8>) -> vector<2x3xi8> {
196+
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
197+
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]]
198+
// CHECK: return %[[TRANSPOSE]] : vector<2x3xi8>
199+
func.func @negative_shape_cast_transpose(%arg : vector<6xi8>) -> vector<2x3xi8> {
200+
%0 = vector.shape_cast %arg : vector<6xi8> to vector<3x2xi8>
201+
%1 = vector.transpose %0, [1, 0] : vector<3x2xi8> to vector<2x3xi8>
202+
return %1 : vector<2x3xi8>
203+
}

0 commit comments

Comments
 (0)