Skip to content

[mlir][vector] Handle corner cases in DropUnitDimsFromTransposeOp. #102518

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 9, 2024

Conversation

hanhanW
Copy link
Contributor

@hanhanW hanhanW commented Aug 8, 2024

da8778e breaks the lowering of vector.transpose that all the dimensions are unit dimensions. The revision fixes the issue and adds a test.

llvm@da8778e
breaks the lowering of vector.transpose that all the dimensions are unit
dimensions. The revision fixes the issue and add a test.

Signed-off-by: hanhanW <hanhan0912@gmail.com>
@llvmbot
Copy link
Member

llvmbot commented Aug 8, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)

Changes

da8778e breaks the lowering of vector.transpose that all the dimensions are unit dimensions. The revision fixes the issue and add a test.


Full diff: https://github.com/llvm/llvm-project/pull/102518.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+8)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+12)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index bc0c96b32a80f..b4ae9b319343a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1771,6 +1771,14 @@ struct DropUnitDimsFromTransposeOp final
       newPerm.push_back(idx - droppedDimsBefore[idx]);
     }
 
+    // Fixup for `newPerm`. The `sourceTypeWithoutUnitDims` could be vector<1xT>
+    // type when the dimensions are unit dimensions. In this case, the newPerm
+    // should be [0].
+    if (sourceTypeWithoutUnitDims.getRank() == 1 &&
+        sourceTypeWithoutUnitDims.getShape()[0] == 1 && newPerm.empty()) {
+      newPerm.push_back(0);
+    }
+
     Location loc = op.getLoc();
     // Drop the unit dims via shape_cast.
     auto dropDimsShapeCast = rewriter.create<vector::ShapeCastOp>(
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 937dbf22bb713..0d34d692393fd 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -737,6 +737,18 @@ func.func @transpose_with_scalable_unit_dims(%vec: vector<[1]x1x2x4x1xf32>) -> v
 
 // -----
 
+func.func @transpose_with_all_unit_dims(%vec: vector<1x1x1xf32>) -> vector<1x1x1xf32> {
+  %res = vector.transpose %vec, [0, 2, 1] : vector<1x1x1xf32> to vector<1x1x1xf32>
+  return %res : vector<1x1x1xf32>
+}
+// The `vec` is returned because there are other flattening patterns fold
+// vector.shape_cast ops away.
+// CHECK-LABEL: func.func @transpose_with_all_unit_dims
+// CHECK-SAME:      %[[VEC:.[a-zA-Z0-9]+]]
+// CHECK-NEXT:    return %[[VEC]]
+
+// -----
+
 func.func @negative_transpose_with_no_unit_dims(%vec: vector<4x2x3xf32>) -> vector<4x3x2xf32> {
   %res = vector.transpose %vec, [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32>
   return %res : vector<4x3x2xf32>

Signed-off-by: hanhanW <hanhan0912@gmail.com>
Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

Signed-off-by: hanhanW <hanhan0912@gmail.com>
@hanhanW hanhanW merged commit 201da87 into llvm:main Aug 9, 2024
8 checks passed
@hanhanW hanhanW deleted the fix-vector-transpose-unit-dim-pattern branch August 9, 2024 00:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants