-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][vector] Handle corner cases in DropUnitDimsFromTransposeOp. #102518
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Handle corner cases in DropUnitDimsFromTransposeOp. #102518
Conversation
llvm@da8778e breaks the lowering of vector.transpose that all the dimensions are unit dimensions. The revision fixes the issue and add a test. Signed-off-by: hanhanW <hanhan0912@gmail.com>
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Han-Chung Wang (hanhanW) Changesda8778e breaks the lowering of vector.transpose that all the dimensions are unit dimensions. The revision fixes the issue and add a test. Full diff: https://github.com/llvm/llvm-project/pull/102518.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index bc0c96b32a80f..b4ae9b319343a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1771,6 +1771,14 @@ struct DropUnitDimsFromTransposeOp final
newPerm.push_back(idx - droppedDimsBefore[idx]);
}
+ // Fixup for `newPerm`. The `sourceTypeWithoutUnitDims` could be vector<1xT>
+ // type when the dimensions are unit dimensions. In this case, the newPerm
+ // should be [0].
+ if (sourceTypeWithoutUnitDims.getRank() == 1 &&
+ sourceTypeWithoutUnitDims.getShape()[0] == 1 && newPerm.empty()) {
+ newPerm.push_back(0);
+ }
+
Location loc = op.getLoc();
// Drop the unit dims via shape_cast.
auto dropDimsShapeCast = rewriter.create<vector::ShapeCastOp>(
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 937dbf22bb713..0d34d692393fd 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -737,6 +737,18 @@ func.func @transpose_with_scalable_unit_dims(%vec: vector<[1]x1x2x4x1xf32>) -> v
// -----
+func.func @transpose_with_all_unit_dims(%vec: vector<1x1x1xf32>) -> vector<1x1x1xf32> {
+ %res = vector.transpose %vec, [0, 2, 1] : vector<1x1x1xf32> to vector<1x1x1xf32>
+ return %res : vector<1x1x1xf32>
+}
+// The `vec` is returned because there are other flattening patterns fold
+// vector.shape_cast ops away.
+// CHECK-LABEL: func.func @transpose_with_all_unit_dims
+// CHECK-SAME: %[[VEC:.[a-zA-Z0-9]+]]
+// CHECK-NEXT: return %[[VEC]]
+
+// -----
+
func.func @negative_transpose_with_no_unit_dims(%vec: vector<4x2x3xf32>) -> vector<4x3x2xf32> {
%res = vector.transpose %vec, [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32>
return %res : vector<4x3x2xf32>
|
Signed-off-by: hanhanW <hanhan0912@gmail.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
Signed-off-by: hanhanW <hanhan0912@gmail.com>
da8778e breaks the lowering of vector.transpose that all the dimensions are unit dimensions. The revision fixes the issue and adds a test.