[mlir][vector] Canonicalize/fold 'order preserving' transposes #135841

Open: wants to merge 1 commit into base: main

newling (Contributor) commented Apr 15, 2025

Handles the special case where a transpose doesn't permute any non-1 dimensions (and so is effectively a shape_cast) and is adjacent to a shape_cast that it can fold into. For example,

%1 = vector.transpose %0, [1, 0, 3, 2] : vector<4x1x1x6xf32> to vector<1x4x6x1xf32>

can be folded into an adjacent shape_cast. An alternative to this PR would be to canonicalize such transposes to shape_casts directly, but I think it'll be difficult getting consensus that shape_cast is 'more canonical' than transpose, so this PR compromises with the less opinionated claim that

  1. shape_cast is more canonical than shape_cast(transpose)
  2. shape_cast is more canonical than transpose(shape_cast)

The pattern ConvertIllegalShapeCastOpsToTransposes, which is specific to transposes with scalable dimensions, reverses the canonicalization added here, so I've disabled this canonicalization for scalable vectors.
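To make the "doesn't permute any non-1 dimensions" condition concrete, here is a minimal standalone sketch of the check. It mirrors the logic of the `isOrderPreserving` helper in the diff, but works on a plain shape and permutation instead of a `TransposeOp`; the name `isOrderPreservingSketch` is hypothetical and not part of the patch.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical standalone analogue of the patch's isOrderPreserving helper.
// Result dimension i of the transpose is source dimension perm[i].
// Returns true if the source dimensions of size != 1 appear in `perm` in
// increasing order, i.e. only unit dimensions are moved around.
bool isOrderPreservingSketch(const std::vector<int64_t> &inShape,
                             const std::vector<int64_t> &perm) {
  assert(inShape.size() == perm.size() && "rank mismatch");
  int64_t current = 0;
  for (int64_t p : perm) {
    if (inShape[p] != 1) {
      // A non-unit dimension that appears after a later non-unit dimension
      // means real data movement, so this is not a shape_cast.
      if (p < current)
        return false;
      current = p;
    }
  }
  return true;
}
```

With the example above, shape `4x1x1x6` and permutation `[1, 0, 3, 2]`, the non-unit source dimensions 0 and 3 are visited in increasing order, so the check passes. Swapping the two dimensions of a `3x2` vector fails it.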

llvmbot (Member) commented Apr 16, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: James Newling (newling)

Full diff: https://github.com/llvm/llvm-project/pull/135841.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+79-9)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+64)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index bee5c1fd6ed58..5da0ef0af032f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5621,6 +5621,29 @@ LogicalResult ShapeCastOp::verify() {
   return success();
 }
 
+namespace {
+
+/// Return true if `transpose` does not permute a pair of dimensions that are
+/// both not of size 1. By `order preserving` we mean that the flattened
+/// versions of the input and output vectors are (numerically) identical.
+/// In other words `transpose` is effectively a shape cast.
+bool isOrderPreserving(TransposeOp transpose) {
+  ArrayRef<int64_t> permutation = transpose.getPermutation();
+  ArrayRef<int64_t> inShape = transpose.getSourceVectorType().getShape();
+  int64_t current = 0;
+  for (auto p : permutation) {
+    if (inShape[p] != 1) {
+      if (p < current) {
+        return false;
+      }
+      current = p;
+    }
+  }
+  return true;
+}
+
+} // namespace
+
 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   // No-op shape cast.
@@ -5629,13 +5652,15 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   VectorType resultType = getType();
 
-  // Canceling shape casts.
-  if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
-
-    // Only allows valid transitive folding (expand/collapse dimensions).
-    VectorType srcType = otherOp.getSource().getType();
+  // shape_cast(something(x)) -> x, or
+  //                          -> shape_cast(x).
+  //
+  // Confirms that a new shape_cast will have valid semantics (expands OR
+  // collapses dimensions).
+  auto maybeFold = [&](TypedValue<VectorType> source) -> OpFoldResult {
+    VectorType srcType = source.getType();
     if (resultType == srcType)
-      return otherOp.getSource();
+      return source;
     if (srcType.getRank() < resultType.getRank()) {
       if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
         return {};
@@ -5645,8 +5670,25 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
     } else {
       return {};
     }
-    setOperand(otherOp.getSource());
+    setOperand(source);
     return getResult();
+  };
+
+  // Canceling shape casts.
+  if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
+    TypedValue<VectorType> source = otherOp.getSource();
+    return maybeFold(source);
+  }
+
+  // shape_cast(transpose(x)) -> shape_cast(x)
+  if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
+    if (transpose.getType().isScalable())
+      return {};
+    if (isOrderPreserving(transpose)) {
+      TypedValue<VectorType> source = transpose.getVector();
+      return maybeFold(source);
+    }
+    return {};
   }
 
   // Cancelling broadcast and shape cast ops.
@@ -5675,7 +5717,7 @@ namespace {
 /// Helper function that computes a new vector type based on the input vector
 /// type by removing the trailing one dims:
 ///
-///   vector<4x1x1xi1> --> vector<4x1>
+///   vector<4x1x1xi1> --> vector<4x1xi1>
 ///
 static VectorType trimTrailingOneDims(VectorType oldType) {
   ArrayRef<int64_t> oldShape = oldType.getShape();
@@ -6161,12 +6203,40 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
   }
 };
 
+/// Folds transpose(shape_cast) into a new shape_cast.
+class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto shapeCastOp =
+        transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
+    if (!shapeCastOp)
+      return failure();
+    if (!isOrderPreserving(transposeOp))
+      return failure();
+    if (transposeOp.getType().isScalable())
+      return failure();
+
+    VectorType resultType = transposeOp.getType();
+
+    // We don't need to check isValidShapeCast at this point, because it is
+    // guaranteed that merging the transpose into the shape_cast is a valid
+    // shape_cast, because the transpose just inserts/removes ones.
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
+                                                     shapeCastOp.getSource());
+    return success();
+  }
+};
 } // namespace
 
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
-              TransposeFolder, FoldTransposeSplat>(context);
+              FoldTransposeShapeCast, TransposeFolder, FoldTransposeSplat>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 78b0ea78849e8..10144cb9034e4 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3295,3 +3295,67 @@ func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi3
   %res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
   return %res : vector<4x1xi32>
 }
+
+// -----
+
+// In this test, the permutation maps the non-one dimensions (1 and 2) as follows:
+// 1 -> 0
+// 2 -> 4
+// Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
+// CHECK-LABEL: @transpose_shape_cast
+//  CHECK-SAME:   %[[ARG:.*]]: vector<1x4x4x1x1xi8>) -> vector<4x4xi8> {
+//       CHECK:   %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+//  CHECK-SAME:   vector<1x4x4x1x1xi8> to vector<4x4xi8>
+//       CHECK:   return %[[SHAPE_CAST]] : vector<4x4xi8>
+func.func @transpose_shape_cast(%arg : vector<1x4x4x1x1xi8>) -> vector<4x4xi8> {
+  %0 = vector.transpose %arg, [1, 0, 3, 4, 2]
+     : vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8>
+  %1 = vector.shape_cast %0 : vector<4x1x1x1x4xi8> to vector<4x4xi8>
+  return %1 : vector<4x4xi8>
+}
+
+// -----
+
+// In this test, the mapping of non-one indices (1 and 2) is as follows:
+// 1 -> 2
+// 2 -> 1
+// As this is not increasing (2 > 1), this transpose is not order
+// preserving and cannot be treated as a shape_cast.
+// CHECK-LABEL: @negative_transpose_shape_cast
+//  CHECK-SAME:   %[[ARG:.*]]: vector<1x4x4x1xi8>) -> vector<4x4xi8> {
+//       CHECK:   %[[TRANSPOSE:.*]] = vector.transpose %[[ARG]]
+//       CHECK:   %[[SHAPE_CAST:.*]] = vector.shape_cast %[[TRANSPOSE]]
+//       CHECK:   return %[[SHAPE_CAST]] : vector<4x4xi8>
+func.func @negative_transpose_shape_cast(%arg : vector<1x4x4x1xi8>) -> vector<4x4xi8> {
+  %0 = vector.transpose %arg, [0, 2, 1, 3]
+     : vector<1x4x4x1xi8> to vector<1x4x4x1xi8>
+  %1 = vector.shape_cast %0 : vector<1x4x4x1xi8> to vector<4x4xi8>
+  return %1 : vector<4x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @shape_cast_transpose
+//  CHECK-SAME:   %[[ARG:.*]]: vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
+//       CHECK:   %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+//  CHECK-SAME:   vector<2x3x1x1xi8> to vector<6x1x1xi8>
+//       CHECK:   return %[[SHAPE_CAST]] : vector<6x1x1xi8>
+func.func @shape_cast_transpose(%arg : vector<2x3x1x1xi8>) ->  vector<6x1x1xi8> {
+  %0 = vector.shape_cast %arg : vector<2x3x1x1xi8> to vector<6x1x1xi8>
+  %1 = vector.transpose %0, [0, 2, 1]
+     : vector<6x1x1xi8> to vector<6x1x1xi8>
+  return %1 : vector<6x1x1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_shape_cast_transpose
+//  CHECK-SAME:   %[[ARG:.*]]: vector<6xi8>) -> vector<2x3xi8> {
+//       CHECK:   %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+//       CHECK:   %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]]
+//       CHECK:   return %[[TRANSPOSE]] : vector<2x3xi8>
+func.func @negative_shape_cast_transpose(%arg : vector<6xi8>) -> vector<2x3xi8> {
+  %0 = vector.shape_cast %arg : vector<6xi8> to vector<3x2xi8>
+  %1 = vector.transpose %0, [1, 0] : vector<3x2xi8> to vector<2x3xi8>
+  return %1 : vector<2x3xi8>
+}

@newling force-pushed the fold_transpose_shape_cast branch from b21a4a6 to f4ae206 on April 21, 2025 17:29