Skip to content

Commit d3fe38a

Browse files
committed
simplify and add one test
1 parent 2498d7d commit d3fe38a

File tree

2 files changed

+26
-15
lines changed

2 files changed

+26
-15
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6216,10 +6216,10 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
62166216
bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
62176217
bool groupEndFound = notOne || prevNotOne;
62186218
if (groupEndFound) {
6219-
// Return failure if all not permutation destinations for indices in
6219+
int high = inputIndex + deltaRank;
6220+
// Return failure if not all permutation destinations for indices in
62206221
// [low, high) are in [low, high), i.e. the permutation is not local to
62216222
// the group.
6222-
int high = inputIndex + deltaRank;
62236223
for (int i = low; i < high; ++i) {
62246224
if (permutation[i] < low || permutation[i] >= high) {
62256225
return rewriter.notifyMatchFailure(
@@ -6229,14 +6229,13 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
62296229
}
62306230
}
62316231

6232-
// We don't need to check the final group [low, outputRank) because
6233-
// if it is not locally bound, there must be a preceding group that
6234-
// already failed the check (impossible to have just 1 non-locally
6235-
// bound group).
6232+
// We don't need to check the final group [low, outputRank) because if it is
6233+
// not locally bound, there must be a preceding group that already failed
6234+
// the check (impossible to have just 1 non-locally bound group).
62366235

62376236
// The preceding logic also ensures that at this point, the output of the
6238-
// transpose is definitely broadcastable from the input shape, so we
6239-
// don't need to check vector::isBroadcastableTo now.
6237+
// transpose is definitely broadcastable from the input shape, so we don't
6238+
// need to check vector::isBroadcastableTo now.
62406239

62416240
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
62426241
transpose, transpose.getResultVectorType(), transpose.getVector());

mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
// This file contains some canonicalizations tests involving vector.transpose.
44

5-
// CHECK-LABEL: scalar_broadcast_transpose_to_broadcast
5+
// CHECK-LABEL: broadcast_transpose_scalar_to_broadcast
66
// CHECK-SAME: %[[ARG:.*]]: i8) -> vector<2x3x4xi8> {
7-
func.func @scalar_broadcast_transpose_to_broadcast(%arg0 : i8) -> vector<2x3x4xi8> {
7+
func.func @broadcast_transpose_scalar_to_broadcast(%arg0 : i8) -> vector<2x3x4xi8> {
88
// CHECK: %[[BC:.*]] = vector.broadcast %[[ARG]] : i8 to vector<2x3x4xi8>
99
%0 = vector.broadcast %arg0 : i8 to vector<3x4x2xi8>
1010
%1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
@@ -14,23 +14,23 @@ func.func @scalar_broadcast_transpose_to_broadcast(%arg0 : i8) -> vector<2x3x4xi
1414

1515
// -----
1616

17-
// CHECK-LABEL: ones_broadcast_transpose_to_broadcast
17+
// CHECK-LABEL: broadcast_transpose_ones_to_broadcast
1818
// CHECK-SAME: %[[ARG:.*]]: vector<1x1x1xi8>) -> vector<2x3x4xi8> {
1919
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1x1x1xi8> to vector<2x3x4xi8>
2020
// CHECK: return %[[RES]] : vector<2x3x4xi8>
21-
func.func @ones_broadcast_transpose_to_broadcast(%arg0 : vector<1x1x1xi8>) -> vector<2x3x4xi8> {
21+
func.func @broadcast_transpose_ones_to_broadcast(%arg0 : vector<1x1x1xi8>) -> vector<2x3x4xi8> {
2222
%0 = vector.broadcast %arg0 : vector<1x1x1xi8> to vector<3x4x2xi8>
2323
%1 = vector.transpose %0, [2, 0, 1] : vector<3x4x2xi8> to vector<2x3x4xi8>
2424
return %1 : vector<2x3x4xi8>
2525
}
2626

2727
// -----
2828

29-
// CHECK-LABEL: partial_ones_broadcast_transpose_to_broadcast
29+
// CHECK-LABEL: broadcast_transpose_partial_ones_to_broadcast
3030
// CHECK-SAME: %[[ARG:.*]]: vector<1xi8>) -> vector<8x1xi8> {
3131
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<1xi8> to vector<8x1xi8>
3232
// CHECK: return %[[RES]] : vector<8x1xi8>
33-
func.func @partial_ones_broadcast_transpose_to_broadcast(%arg0 : vector<1xi8>) -> vector<8x1xi8> {
33+
func.func @broadcast_transpose_partial_ones_to_broadcast(%arg0 : vector<1xi8>) -> vector<8x1xi8> {
3434
%0 = vector.broadcast %arg0 : vector<1xi8> to vector<1x8xi8>
3535
%1 = vector.transpose %0, [1, 0] : vector<1x8xi8> to vector<8x1xi8>
3636
return %1 : vector<8x1xi8>
@@ -50,6 +50,18 @@ func.func @broadcast_transpose_mixed_example(%arg0 : vector<4x1x1x7xi8>) -> vect
5050

5151
// -----
5252

53+
// CHECK-LABEL: broadcast_transpose_final_group
54+
// CHECK-SAME: %[[ARG:.*]]: vector<4x7x1x1xi8>) -> vector<4x7x2x3xi8> {
55+
// CHECK: %[[RES:.*]] = vector.broadcast %[[ARG]] : vector<4x7x1x1xi8> to vector<4x7x2x3xi8>
56+
// CHECK: return %[[RES]] : vector<4x7x2x3xi8>
57+
func.func @broadcast_transpose_final_group(%arg0 : vector<4x7x1x1xi8>) -> vector<4x7x2x3xi8> {
58+
%0 = vector.broadcast %arg0 : vector<4x7x1x1xi8> to vector<4x7x3x2xi8>
59+
%1 = vector.transpose %0, [0, 1, 3, 2] : vector<4x7x3x2xi8> to vector<4x7x2x3xi8>
60+
return %1 : vector<4x7x2x3xi8>
61+
}
62+
63+
// -----
64+
5365
// CHECK-LABEL: negative_broadcast_transpose_square
5466
// CHECK-SAME: %[[ARG:.*]]:
5567
// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
@@ -94,7 +106,7 @@ func.func @negative_broadcast_transpose_102(%arg0 : vector<3x1x3xi8>) -> vector<
94106
// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
95107
// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [0, 2, 1]
96108
// CHECK: return %[[TRP]] : vector<3x3x3xi8>
97-
func.func @neagtive_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
109+
func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
98110
%0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
99111
%1 = vector.transpose %0, [0, 2, 1] : vector<3x3x3xi8> to vector<3x3x3xi8>
100112
return %1 : vector<3x3x3xi8>

0 commit comments

Comments
 (0)