Skip to content

Commit 21f1a61

Browse files
authored
[mlir][vector] Additional transpose folding (#138347)
Fold transpose with unit-dimensions. Seen in the wild: ``` %0 = vector.transpose %arg, [0, 2, 1, 3] : vector<6x1x1x4xi8> to vector<6x1x1x4xi8> ``` This transpose can be folded because (1) it preserves the shape and (2) the shuffled dims are unit extent. Also addresses comment about static vs anonymous namespace: #135841 (comment) --------- Signed-off-by: James Newling <james.newling@gmail.com>
1 parent 0dd2c9f commit 21f1a61

File tree

4 files changed

+74
-46
lines changed

4 files changed

+74
-46
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5573,13 +5573,11 @@ LogicalResult ShapeCastOp::verify() {
55735573
return success();
55745574
}
55755575

5576-
namespace {
5577-
55785576
/// Return true if `transpose` does not permute a pair of non-unit dims.
55795577
/// By `order preserving` we mean that the flattened versions of the input and
55805578
/// output vectors are (numerically) identical. In other words `transpose` is
55815579
/// effectively a shape cast.
5582-
bool isOrderPreserving(TransposeOp transpose) {
5580+
static bool isOrderPreserving(TransposeOp transpose) {
55835581
ArrayRef<int64_t> permutation = transpose.getPermutation();
55845582
VectorType sourceType = transpose.getSourceVectorType();
55855583
ArrayRef<int64_t> inShape = sourceType.getShape();
@@ -5599,8 +5597,6 @@ bool isOrderPreserving(TransposeOp transpose) {
55995597
return true;
56005598
}
56015599

5602-
} // namespace
5603-
56045600
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56055601

56065602
VectorType resultType = getType();
@@ -5997,18 +5993,22 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
59975993
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
59985994
return ub::PoisonAttr::get(getContext());
59995995

6000-
// Eliminate identity transpose ops. This happens when the dimensions of the
6001-
// input vector remain in their original order after the transpose operation.
6002-
ArrayRef<int64_t> perm = getPermutation();
6003-
6004-
// Check if the permutation of the dimensions contains sequential values:
6005-
// {0, 1, 2, ...}.
6006-
for (int64_t i = 0, e = perm.size(); i < e; i++) {
6007-
if (perm[i] != i)
6008-
return {};
6009-
}
5996+
// Eliminate identity transposes, and more generally any transposes that
5997+
// preserve the shape without permuting elements.
5998+
//
5999+
// Examples of what to fold:
6000+
// %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
6001+
// %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
6002+
// %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
6003+
//
6004+
// Example of what NOT to fold:
6005+
// %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
6006+
//
6007+
if (getSourceVectorType() == getResultVectorType() &&
6008+
isOrderPreserving(*this))
6009+
return getVector();
60106010

6011-
return getVector();
6011+
return {};
60126012
}
60136013

60146014
LogicalResult vector::TransposeOp::verify() {

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -450,28 +450,6 @@ func.func @extract_strided_fold_insert(%a: vector<2x8xf32>, %b: vector<1x4xf32>,
450450

451451
// -----
452452

453-
// CHECK-LABEL: transpose_1D_identity
454-
// CHECK-SAME: ([[ARG:%.*]]: vector<4xf32>)
455-
func.func @transpose_1D_identity(%arg : vector<4xf32>) -> vector<4xf32> {
456-
// CHECK-NOT: transpose
457-
%0 = vector.transpose %arg, [0] : vector<4xf32> to vector<4xf32>
458-
// CHECK-NEXT: return [[ARG]]
459-
return %0 : vector<4xf32>
460-
}
461-
462-
// -----
463-
464-
// CHECK-LABEL: transpose_2D_identity
465-
// CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>)
466-
func.func @transpose_2D_identity(%arg : vector<4x3xf32>) -> vector<4x3xf32> {
467-
// CHECK-NOT: transpose
468-
%0 = vector.transpose %arg, [0, 1] : vector<4x3xf32> to vector<4x3xf32>
469-
// CHECK-NEXT: return [[ARG]]
470-
return %0 : vector<4x3xf32>
471-
}
472-
473-
// -----
474-
475453
// CHECK-LABEL: transpose_3D_identity
476454
// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
477455
func.func @transpose_3D_identity(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> {

mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
22

3-
// This file contains some canonicalizations tests involving vector.transpose.
3+
// This file contains some tests of canonicalizations and foldings involving vector.transpose.
4+
5+
// +---------------------------------------------------------------------------
6+
// Tests of FoldTransposeBroadcast: transpose(broadcast) -> broadcast
7+
// +---------------------------------------------------------------------------
48

59
// CHECK-LABEL: func @transpose_scalar_broadcast1
610
// CHECK-SAME: (%[[ARG:.+]]: vector<1xf32>)
@@ -248,3 +252,47 @@ func.func @negative_transpose_of_shape_cast(%arg : vector<6xi8>) -> vector<2x3xi
248252
%1 = vector.transpose %0, [1, 0] : vector<3x2xi8> to vector<2x3xi8>
249253
return %1 : vector<2x3xi8>
250254
}
255+
256+
// -----
257+
258+
// +-----------------------------------
259+
// Tests of TransposeOp::fold
260+
// +-----------------------------------
261+
262+
// CHECK-LABEL: transpose_1D_identity
263+
// CHECK-SAME: [[ARG:%.*]]: vector<4xf32>
264+
// CHECK-NEXT: return [[ARG]]
265+
func.func @transpose_1D_identity(%arg : vector<4xf32>) -> vector<4xf32> {
266+
%0 = vector.transpose %arg, [0] : vector<4xf32> to vector<4xf32>
267+
return %0 : vector<4xf32>
268+
}
269+
270+
// -----
271+
272+
// CHECK-LABEL: transpose_2D_identity
273+
// CHECK-SAME: [[ARG:%.*]]: vector<4x3xf32>
274+
// CHECK-NEXT: return [[ARG]]
275+
func.func @transpose_2D_identity(%arg : vector<4x3xf32>) -> vector<4x3xf32> {
276+
%0 = vector.transpose %arg, [0, 1] : vector<4x3xf32> to vector<4x3xf32>
277+
return %0 : vector<4x3xf32>
278+
}
279+
280+
// -----
281+
282+
// CHECK-LABEL: transpose_shape_and_order_preserving
283+
// CHECK-SAME: [[ARG:%.*]]: vector<6x1x1x4xi8>
284+
// CHECK-NEXT: return [[ARG]]
285+
func.func @transpose_shape_and_order_preserving(%arg : vector<6x1x1x4xi8>) -> vector<6x1x1x4xi8> {
286+
%0 = vector.transpose %arg, [0, 2, 1, 3] : vector<6x1x1x4xi8> to vector<6x1x1x4xi8>
287+
return %0 : vector<6x1x1x4xi8>
288+
}
289+
290+
// -----
291+
292+
// CHECK-LABEL: negative_transpose_fold
293+
// CHECK: [[TRANSP:%.*]] = vector.transpose
294+
// CHECK: return [[TRANSP]]
295+
func.func @negative_transpose_fold(%arg : vector<2x2xi8>) -> vector<2x2xi8> {
296+
%0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
297+
return %0 : vector<2x2xi8>
298+
}

mlir/test/Dialect/Vector/vector-transpose-lowering.mlir

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,15 @@ func.func @transpose102_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32>
6565
return %0 : vector<1x8x8xf32>
6666
}
6767

68-
// CHECK-LABEL: func @transpose1023_1x1x8x8xf32(
69-
func.func @transpose1023_1x1x8x8xf32(%arg0: vector<1x1x8x8xf32>) -> vector<1x1x8x8xf32> {
70-
// Note the single 2-D extract/insert pair since 2 and 3 are not transposed!
71-
// CHECK: vector.extract {{.*}}[0, 0] : vector<8x8xf32> from vector<1x1x8x8xf32>
72-
// CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8x8xf32> into vector<1x1x8x8xf32>
73-
%0 = vector.transpose %arg0, [1, 0, 2, 3] : vector<1x1x8x8xf32> to vector<1x1x8x8xf32>
74-
return %0 : vector<1x1x8x8xf32>
68+
// CHECK-LABEL: func @transpose1023_2x1x8x4xf32(
69+
func.func @transpose1023_2x1x8x4xf32(%arg0: vector<2x1x8x4xf32>) -> vector<1x2x8x4xf32> {
70+
// Note the 2-D extract/insert pairs since dimensions 2 and 3 are not transposed!
71+
// CHECK: vector.extract {{.*}}[0, 0] : vector<8x4xf32> from vector<2x1x8x4xf32>
72+
// CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8x4xf32> into vector<1x2x8x4xf32>
73+
// CHECK-NEXT: vector.extract {{.*}}[1, 0] : vector<8x4xf32> from vector<2x1x8x4xf32>
74+
// CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<8x4xf32> into vector<1x2x8x4xf32>
75+
%0 = vector.transpose %arg0, [1, 0, 2, 3] : vector<2x1x8x4xf32> to vector<1x2x8x4xf32>
76+
return %0 : vector<1x2x8x4xf32>
7577
}
7678

7779
/// Scalable dim should not be unrolled.

0 commit comments

Comments
 (0)