[mlir][vector] Improve shape_cast lowering #140800

Merged
newling merged 7 commits into llvm:main from improve_shape_cast_lowering on Jun 5, 2025

Conversation

newling
Contributor

@newling newling commented May 20, 2025

Before this PR, a rank-m -> rank-n vector.shape_cast with m,n > 1 was lowered to extracts/inserts of single elements, so that a shape_cast on a vector with N elements would always require N extracts/inserts. While this is necessary in the worst case, it is sometimes possible to use fewer, larger extracts/inserts. Specifically, the largest common suffix of the source and result shapes can be extracted/inserted. For example:

%0 = vector.shape_cast %arg0 : vector<10x2x3xf32> to vector<2x5x2x3xf32>

has a common suffix of shape 2x3. Before this PR, this would be lowered to 60 extract/insert pairs, with extracts of the form
vector.extract %arg0 [a, b, c] : f32 from vector<10x2x3xf32>. With this PR it is lowered to 10 extract/insert pairs, with extracts of the form vector.extract %arg0 [a] : vector<2x3xf32> from vector<10x2x3xf32>.
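
For illustration (the SSA value names here are hypothetical, and only the first two of the ten pairs are shown), the new lowering produces IR along these lines:

%init = ub.poison : vector<2x5x2x3xf32>
%e0 = vector.extract %arg0[0] : vector<2x3xf32> from vector<10x2x3xf32>
%r0 = vector.insert %e0, %init [0, 0] : vector<2x3xf32> into vector<2x5x2x3xf32>
%e1 = vector.extract %arg0[1] : vector<2x3xf32> from vector<10x2x3xf32>
%r1 = vector.insert %e1, %r0 [0, 1] : vector<2x3xf32> into vector<2x5x2x3xf32>
// ... 8 more extract/insert pairs, ending with an insert at index [1, 4] ...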

This case was first mentioned here: #138777 (comment)

General direction of travel:

  • shape_cast lowering needs to be at least as 'good' as the lowering of broadcast, transpose, and extract for the canonicalizations proposed (#138777 and #140583).
  • Eventually I would like to move shape_cast lowering (i.e. vector -> vector) to the later conversion step (i.e. vector -> llvm/spirv).

@newling newling force-pushed the improve_shape_cast_lowering branch from ff92faa to 6c0bc5d on May 22, 2025 18:54
@newling newling marked this pull request as ready for review May 22, 2025 18:59
@llvmbot
Member

llvmbot commented May 22, 2025

@llvm/pr-subscribers-mlir-vector

Author: James Newling (newling)

Changes

Before this PR, a rank-m -> rank-n vector.shape_cast with m,n > 1 was lowered to extracts/inserts of single elements, so that a shape_cast on a vector with N elements would always require N extracts/inserts. While this is necessary in the worst case, it is sometimes possible to use fewer, larger extracts/inserts. Specifically, the largest common suffix of the source and result shapes can be extracted/inserted. For example:

%0 = vector.shape_cast %arg0 : vector<10x2x3xf32> to vector<2x5x2x3xf32>

has a common suffix of shape 2x3. Before this PR, this would be lowered to 60 extract/insert pairs, with extracts of the form
vector.extract %arg0 [a, b, c] : f32 from vector<10x2x3xf32>. With this PR it is lowered to 10 extract/insert pairs, with extracts of the form vector.extract %arg0 [a] : vector<2x3xf32> from vector<10x2x3xf32>.

This case was first mentioned here: #138777 (comment)


Full diff: https://github.com/llvm/llvm-project/pull/140800.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp (+48-32)
  • (modified) mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir (+53)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 23324a007377e..d0085bffca23c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -28,17 +28,20 @@ using namespace mlir;
 using namespace mlir::vector;
 
 /// Increments n-D `indices` by `step` starting from the innermost dimension.
-static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
+static void incIdx(MutableArrayRef<int64_t> indices, ArrayRef<int64_t> shape,
                    int step = 1) {
   for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
-    assert(indices[dim] < vecType.getDimSize(dim) &&
-           "Indices are out of bound");
+    int64_t dimSize = shape[dim];
+    assert(indices[dim] < dimSize && "Indices are out of bound");
+
     indices[dim] += step;
-    if (indices[dim] < vecType.getDimSize(dim))
+
+    int64_t spill = indices[dim] / dimSize;
+    if (spill == 0)
       break;
 
-    indices[dim] = 0;
-    step = 1;
+    indices[dim] %= dimSize;
+    step = spill;
   }
 }
 
@@ -79,8 +82,8 @@ class ShapeCastOpNDDownCastRewritePattern
     // and destination slice insertion and generate such instructions.
     for (int64_t i = 0; i < numElts; ++i) {
       if (i != 0) {
-        incIdx(srcIdx, sourceVectorType, /*step=*/1);
-        incIdx(resIdx, resultVectorType, /*step=*/extractSize);
+        incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/1);
+        incIdx(resIdx, resultVectorType.getShape(), /*step=*/extractSize);
       }
 
       Value extract =
@@ -131,8 +134,8 @@ class ShapeCastOpNDUpCastRewritePattern
     Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
     for (int64_t i = 0; i < numElts; ++i) {
       if (i != 0) {
-        incIdx(srcIdx, sourceVectorType, /*step=*/extractSize);
-        incIdx(resIdx, resultVectorType, /*step=*/1);
+        incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/extractSize);
+        incIdx(resIdx, resultVectorType.getShape(), /*step=*/1);
       }
 
       Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
@@ -157,41 +160,54 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    auto sourceVectorType = op.getSourceVectorType();
-    auto resultVectorType = op.getResultVectorType();
+    VectorType sourceType = op.getSourceVectorType();
+    VectorType resultType = op.getResultVectorType();
 
-    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
+    if (sourceType.isScalable() || resultType.isScalable())
       return failure();
 
-    // Special case for n-D / 1-D lowerings with better implementations.
-    int64_t srcRank = sourceVectorType.getRank();
-    int64_t resRank = resultVectorType.getRank();
-    if ((srcRank > 1 && resRank == 1) || (srcRank == 1 && resRank > 1))
+    // Special case for n-D / 1-D lowerings with implementations that use
+    // extract_strided_slice / insert_strided_slice.
+    int64_t sourceRank = sourceType.getRank();
+    int64_t resultRank = resultType.getRank();
+    if ((sourceRank > 1 && resultRank == 1) ||
+        (sourceRank == 1 && resultRank > 1))
       return failure();
 
-    // Generic ShapeCast lowering path goes all the way down to unrolled scalar
-    // extract/insert chains.
-    int64_t numElts = 1;
-    for (int64_t r = 0; r < srcRank; r++)
-      numElts *= sourceVectorType.getDimSize(r);
+    int64_t numExtracts = sourceType.getNumElements();
+    int64_t nbCommonInnerDims = 0;
+    while (true) {
+      int64_t sourceDim = sourceRank - 1 - nbCommonInnerDims;
+      int64_t resultDim = resultRank - 1 - nbCommonInnerDims;
+      if (sourceDim < 0 || resultDim < 0)
+        break;
+      int64_t dimSize = sourceType.getDimSize(sourceDim);
+      if (dimSize != resultType.getDimSize(resultDim))
+        break;
+      numExtracts /= dimSize;
+      ++nbCommonInnerDims;
+    }
+
     // Replace with data movement operations:
     //    x[0,0,0] = y[0,0]
     //    x[0,0,1] = y[0,1]
     //    x[0,1,0] = y[0,2]
     // etc., incrementing the two index vectors "row-major"
     // within the source and result shape.
-    SmallVector<int64_t> srcIdx(srcRank, 0);
-    SmallVector<int64_t> resIdx(resRank, 0);
-    Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
-    for (int64_t i = 0; i < numElts; i++) {
+    SmallVector<int64_t> sourceIndex(sourceRank - nbCommonInnerDims, 0);
+    SmallVector<int64_t> resultIndex(resultRank - nbCommonInnerDims, 0);
+    Value result = rewriter.create<ub::PoisonOp>(loc, resultType);
+
+    for (int64_t i = 0; i < numExtracts; i++) {
       if (i != 0) {
-        incIdx(srcIdx, sourceVectorType);
-        incIdx(resIdx, resultVectorType);
+        incIdx(sourceIndex, sourceType.getShape().drop_back(nbCommonInnerDims));
+        incIdx(resultIndex, resultType.getShape().drop_back(nbCommonInnerDims));
       }
 
       Value extract =
-          rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
-      result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
+          rewriter.create<vector::ExtractOp>(loc, op.getSource(), sourceIndex);
+      result =
+          rewriter.create<vector::InsertOp>(loc, extract, result, resultIndex);
     }
     rewriter.replaceOp(op, result);
     return success();
@@ -329,8 +345,8 @@ class ScalableShapeCastOpRewritePattern
 
       // 4. Increment the insert/extract indices, stepping by minExtractionSize
       // for the trailing dimensions.
-      incIdx(srcIdx, sourceVectorType, /*step=*/minExtractionSize);
-      incIdx(resIdx, resultVectorType, /*step=*/minExtractionSize);
+      incIdx(srcIdx, sourceVectorType.getShape(), /*step=*/minExtractionSize);
+      incIdx(resIdx, resultVectorType.getShape(), /*step=*/minExtractionSize);
     }
 
     rewriter.replaceOp(op, result);
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index ef32f8c6a1cdb..2875f159a2df9 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -140,6 +140,59 @@ func.func @shape_cast_1d0d(%arg0 : vector<1xf32>) -> vector<f32> {
   return %s : vector<f32>
 }
 
+
+// The shapes have 2 inner dimension sizes in common, so the extract result is rank-2.
+// CHECK-LABEL:  func.func @squeeze_out_prefix_unit_dim(
+// CHECK-SAME:               %[[ARG0:.*]]: vector<1x2x3xf32>) -> vector<2x3xf32> {
+// CHECK:          %[[EXTRACTED:.*]] = vector.extract %[[ARG0]][0] :
+// CHECK-SAME:               vector<2x3xf32> from vector<1x2x3xf32>
+// CHECK:          return %[[EXTRACTED]] : vector<2x3xf32>
+func.func @squeeze_out_prefix_unit_dim(%arg0 : vector<1x2x3xf32>) -> vector<2x3xf32> {
+  %s = vector.shape_cast %arg0 : vector<1x2x3xf32> to vector<2x3xf32>
+  return %s : vector<2x3xf32>
+}
+
+// The shapes have 1 inner dimension size in common, so the extract results are rank-1.
+// CHECK-LABEL:  func.func @squeeze_out_middle_unit_dim(
+// CHECK-SAME:               %[[ARG0:.*]]: vector<2x1x3xf32>) -> vector<2x3xf32> {
+// CHECK:         %[[UB:.*]] = ub.poison : vector<2x3xf32>
+// CHECK:         %[[E0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<3xf32>
+// CHECK:         %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0] : vector<3xf32>
+// CHECK-SAME:                 into vector<2x3xf32>
+// CHECK:         %[[E1:.*]] = vector.extract %[[ARG0]][1, 0] : vector<3xf32>
+// CHECK:         %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1] : vector<3xf32>
+// CHECK-SAME:                 into vector<2x3xf32>
+// CHECK:         return %[[I1]] : vector<2x3xf32>
+func.func @squeeze_out_middle_unit_dim(%arg0 : vector<2x1x3xf32>) -> vector<2x3xf32> {
+  %s = vector.shape_cast %arg0 : vector<2x1x3xf32> to vector<2x3xf32>
+  return %s : vector<2x3xf32>
+}
+
+// CHECK-LABEL:  func.func @prepend_unit_dim(
+// CHECK-SAME:               %[[ARG0:.*]]: vector<2x3xf32>) -> vector<1x2x3xf32> {
+// CHECK:          %[[UB:.*]] = ub.poison : vector<1x2x3xf32>
+// CHECK:          %[[INSERTED:.*]] = vector.insert %[[ARG0]], %[[UB]] [0]
+// CHECK-SAME:               : vector<2x3xf32> into vector<1x2x3xf32>
+// CHECK:        return %[[INSERTED]] : vector<1x2x3xf32>
+func.func @prepend_unit_dim(%arg0 : vector<2x3xf32>) -> vector<1x2x3xf32> {
+  %s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<1x2x3xf32>
+  return %s : vector<1x2x3xf32>
+}
+
+// CHECK-LABEL:  func.func @insert_middle_unit_dim(
+// CHECK-SAME:               %[[ARG0:.*]]: vector<2x3xf32>) -> vector<2x1x3xf32> {
+// CHECK:           %[[UB:.*]] = ub.poison : vector<2x1x3xf32>
+// CHECK:           %[[E0:.*]] = vector.extract %[[ARG0]][0] : vector<3xf32>
+// CHECK:           %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0, 0] : vector<3xf32>
+// CHECK:           %[[E1:.*]] = vector.extract %[[ARG0]][1] : vector<3xf32>
+// CHECK:           %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1, 0] : vector<3xf32>
+// CHECK:           return %[[I1]] : vector<2x1x3xf32>
+func.func @insert_middle_unit_dim(%arg0 : vector<2x3xf32>) -> vector<2x1x3xf32> {
+  %s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<2x1x3xf32>
+  return %s : vector<2x1x3xf32>
+}
+
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op

@llvmbot
Member

llvmbot commented May 22, 2025

@llvm/pr-subscribers-mlir

@newling newling force-pushed the improve_shape_cast_lowering branch from 6c0bc5d to 1a4b759 on May 23, 2025 16:59
Contributor

@banach-space banach-space left a comment

Nice improvements, thanks!

Please find some minor comments inline :)


if (sourceVectorType.isScalable() || resultVectorType.isScalable())
if (sourceType.isScalable() || resultType.isScalable())
return failure();
Contributor

Since you are refactoring a fair bit, would you mind replacing this (and other instances of failure) with notifyMatchFailure? Thanks!

Comment on lines 169 to 170
// Special case for n-D / 1-D lowerings with implementations that use
// extract_strided_slice / insert_strided_slice.
Contributor

Could we clarify this comment? (the original part is quite confusing). Right now, combined with the code, it reads a bit like:

This is a special case, let's fail!

😅 I assume that it was meant to be:

This special case is handled by other, more optimal patterns.

Or something similar :)

Contributor Author

I didn't much like the logic spread over the 3 patterns (N->N, 1->N, N->1), as there isn't really anything special about the 1->N and N->1 cases. So I've done a fairly major update to the N->N pattern, so that it now handles the 1->N and N->1 cases. As the new change is quite significant, if you'd prefer it to be done in a separate PR I'm happy to postpone this 'unification', backtrack, and just apply the minor changes to this PR that you suggested.

I also unified the tests across the test file. The behavior for the 1->N and N->1 cases is unchanged by this PR though.

Comment on lines 177 to 178
int64_t numExtracts = sourceType.getNumElements();
int64_t nbCommonInnerDims = 0;
Contributor

Do both num and nb stand for number? Could you unify?

@@ -28,17 +28,20 @@ using namespace mlir;
using namespace mlir::vector;

/// Increments n-D `indices` by `step` starting from the innermost dimension.
static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
static void incIdx(MutableArrayRef<int64_t> indices, ArrayRef<int64_t> shape,
Contributor

Perhaps just a slow day for me, but it took me a while to follow the logic in this method. Let me share my observations:

  • step only really defines the step for the trailing dim? If yes, it would be good to update the variable name.
  • spill is either 0 or 1.

is this correct?

Btw, extra documentation would help. My initial interpretation was: "Update every single index by 1", but that's not true, is it?

Contributor Author

I've refactored and documented this significantly in the latest commit; it is hopefully clearer now.

Comment on lines 175 to 176
// CHECK-SAME: : vector<2x3xf32> into vector<1x2x3xf32>
// CHECK: return %[[INSERTED]] : vector<1x2x3xf32>
Contributor

This indentation is off.

Comment on lines 147 to 148
// CHECK: %[[EXTRACTED:.*]] = vector.extract %[[ARG0]][0] :
// CHECK-SAME: vector<2x3xf32> from vector<1x2x3xf32>
Contributor

This indentation is inconsistent with what's used in @squeeze_out_middle_unit_dim.

Contributor Author

I did a significant test file refactor to make the pre-existing and new tests consistent. The pre-existing test logic is unchanged. I'm happy to postpone refactoring the old tests to make that clearer.

Contributor

I'm happy for you to keep all these changes in this PR - this is effectively a proper refactor of the pattern, which was in dire need of some TLC anyway 🙂

Contributor

@banach-space banach-space left a comment

Hey @newling - I’m happy with the expanded scope; the overall direction looks very positive!

I’m slowly running out of cycles this week and will be travelling next week. No need to wait for my approval if another reviewer signs off. I’ll also try to post a few comments while in transit :)

Comment on lines 114 to 116
/// Two special cases are handled separately:
/// (1) A shape_cast that just does leading 1 insertion/removal
/// (2) A shape_cast where the gcd is 1.
Contributor

Where are these cases handled?


@newling
Contributor Author

newling commented Jun 2, 2025

Hey @newling - I’m happy with the expanded scope; the overall direction looks very positive!

I’m slowly running out of cycles this week and will be travelling next week. No need to wait for my approval if another reviewer signs off. I’ll also try to post a few comments while in transit :)

Hi @banach-space - ok no worries, thanks for the heads up. And thanks for the numerous and useful reviews :)

Contributor

@dcaballe dcaballe left a comment

Awesome improvement. LGTM % refactoring

/// +-----------------> gcd(4,6) is 2 | |
/// | | |
/// v v v
/// atomic shape <----- 2x7x11
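
For concreteness, a hypothetical cast that fits this sketch (the shapes are chosen here to match the gcd(4,6) and the 7x11 suffix in the comment, and are not taken from the PR itself):

%0 = vector.shape_cast %arg0 : vector<3x4x7x11xf32> to vector<2x6x7x11xf32>

The common suffix is 7x11 and gcd(4,6) = 2, so under the documented scheme data would move in vector<2x7x11xf32> chunks rather than element by element.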
Contributor

Nice! This makes sense to me...

// IR is generated in this case if we just extract and insert the elements
// directly. In other words, we don't use extract_strided_slice and
// insert_strided_slice.
if (greatestCommonDivisor == 1) {
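
A hypothetical illustration of this fall-back (assuming it matches the element-wise lowering described in the PR summary; %acc stands for the previous partial result, or the initial ub.poison value):

%0 = vector.shape_cast %arg0 : vector<2x3xf32> to vector<3x2xf32>

Here the trailing dims are 3 and 2 (gcd 1, no common suffix), so all 6 elements are moved with scalar extract/insert pairs of the form

%e = vector.extract %arg0[0, 1] : f32 from vector<2x3xf32>
%r = vector.insert %e, %acc [0, 1] : f32 into vector<3x2xf32>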
Contributor

Should we refactor the different cases to functions?

Contributor Author

Yes, looks a bit better now IMO

@newling newling merged commit 7ce315d into llvm:main Jun 5, 2025
9 of 10 checks passed