Skip to content

[mlir][Vector] Remove more special case uses for extractelement/insertelement #130166

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 24, 2025

Conversation

Groverkss
Copy link
Member

@Groverkss Groverkss commented Mar 6, 2025

A number of places in our codebase special case to use extractelement/insertelement for 0D vectors, because extract/insert did not support 0D vectors previously. Since insert/extract support 0D vectors now, use them instead of special casing.

@llvmbot
Copy link
Member

llvmbot commented Mar 6, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Kunwar Grover (Groverkss)

Changes

A number of places in our codebase special case to use extractelement/insertelement for 0D vectors, because extract/insert did not support 0D vectors previously. Since insert/extract support 0D vectors now, use them instead of special casing.


Full diff: https://github.com/llvm/llvm-project/pull/130166.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+2)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+18-14)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp (+1-5)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp (+3-19)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+2-11)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+1-1)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+1-1)
  • (modified) mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir (+2-6)
  • (modified) mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir (+2-2)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index fbbf817ecff98..2f5436f353539 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -718,6 +718,7 @@ def Vector_ExtractOp :
   let results = (outs AnyType:$result);
 
   let builders = [
+    OpBuilder<(ins "Value":$source)>,
     OpBuilder<(ins "Value":$source, "int64_t":$position)>,
     OpBuilder<(ins "Value":$source, "OpFoldResult":$position)>,
     OpBuilder<(ins "Value":$source, "ArrayRef<int64_t>":$position)>,
@@ -913,6 +914,7 @@ def Vector_InsertOp :
   let results = (outs AnyVectorOfAnyRank:$result);
 
   let builders = [
+    OpBuilder<(ins "Value":$source, "Value":$dest)>,
     OpBuilder<(ins "Value":$source, "Value":$dest, "int64_t":$position)>,
     OpBuilder<(ins "Value":$source, "Value":$dest, "OpFoldResult":$position)>,
     OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<int64_t>":$position)>,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8e0e723cf4ed3..860778fc9db38 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -560,11 +560,9 @@ struct ElideUnitDimsInMultiDimReduction
     } else {
       // This means we are reducing all the dimensions, and all reduction
       // dimensions are of size 1. So a simple extraction would do.
-      SmallVector<int64_t> zeroIdx(shape.size(), 0);
       if (mask)
-        mask = rewriter.create<vector::ExtractOp>(loc, mask, zeroIdx);
-      cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource(),
-                                                zeroIdx);
+        mask = rewriter.create<vector::ExtractOp>(loc, mask);
+      cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource());
     }
 
     Value result =
@@ -698,16 +696,9 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
       return failure();
 
     Location loc = reductionOp.getLoc();
-    Value result;
-    if (vectorType.getRank() == 0) {
-      if (mask)
-        mask = rewriter.create<ExtractElementOp>(loc, mask);
-      result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
-    } else {
-      if (mask)
-        mask = rewriter.create<ExtractOp>(loc, mask, 0);
-      result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(), 0);
-    }
+    if (mask)
+      mask = rewriter.create<ExtractOp>(loc, mask);
+    Value result = rewriter.create<ExtractOp>(loc, reductionOp.getVector());
 
     if (Value acc = reductionOp.getAcc())
       result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
@@ -1294,6 +1285,12 @@ void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), argRanges.front());
 }
 
+void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
+                              Value source) {
+  auto vectorTy = cast<VectorType>(source.getType());
+  build(builder, result, source, SmallVector<int64_t>(vectorTy.getRank(), 0));
+}
+
 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                               Value source, int64_t position) {
   build(builder, result, source, ArrayRef<int64_t>{position});
@@ -2916,6 +2913,13 @@ void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
 }
 
+void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
+                             Value source, Value dest) {
+  auto vectorTy = cast<VectorType>(dest.getType());
+  build(builder, result, source, dest,
+        SmallVector<int64_t>(vectorTy.getRank(), 0));
+}
+
 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                              Value source, Value dest, int64_t position) {
   build(builder, result, source, dest, ArrayRef<int64_t>{position});
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index fec3c6c52e5e4..11dcfe421e0c4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -52,11 +52,7 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
 
     // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
     if (srcRank <= 1 && dstRank == 1) {
-      Value ext;
-      if (srcRank == 0)
-        ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
-      else
-        ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
+      Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource());
       rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
       return success();
     }
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 9c1e5fcee91de..23324a007377e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -189,25 +189,9 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
         incIdx(resIdx, resultVectorType);
       }
 
-      Value extract;
-      if (srcRank == 0) {
-        // 0-D vector special case
-        assert(srcIdx.empty() && "Unexpected indices for 0-D vector");
-        extract = rewriter.create<vector::ExtractElementOp>(
-            loc, op.getSourceVectorType().getElementType(), op.getSource());
-      } else {
-        extract =
-            rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
-      }
-
-      if (resRank == 0) {
-        // 0-D vector special case
-        assert(resIdx.empty() && "Unexpected indices for 0-D vector");
-        result = rewriter.create<vector::InsertElementOp>(loc, extract, result);
-      } else {
-        result =
-            rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
-      }
+      Value extract =
+          rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
+      result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
     }
     rewriter.replaceOp(op, result);
     return success();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 2413a4126f3f7..074c2d5664f64 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -920,17 +920,8 @@ class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
     if (!xferOp.getPermutationMap().isMinorIdentity())
       return failure();
     // Only float and integer element types are supported.
-    Value scalar;
-    if (vecType.getRank() == 0) {
-      // vector.extract does not support vector<f32> etc., so use
-      // vector.extractelement instead.
-      scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
-                                                         xferOp.getVector());
-    } else {
-      SmallVector<int64_t> pos(vecType.getRank(), 0);
-      scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
-                                                  xferOp.getVector(), pos);
-    }
+    Value scalar =
+        rewriter.create<vector::ExtractOp>(xferOp.getLoc(), xferOp.getVector());
     // Construct a scalar store.
     if (isa<MemRefType>(xferOp.getSource().getType())) {
       rewriter.replaceOpWithNewOp<memref::StoreOp>(
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 1ab28b9df2d19..52a2224b963f2 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -187,7 +187,7 @@ func.func @broadcast_vec2d_from_vec0d(%arg0: vector<f32>) -> vector<3x2xf32> {
 //       CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
 //       CHECK: %[[T1:.*]] = ub.poison : vector<3x2xf32>
 //       CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
-//       CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : index) : i64
+//       CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : i64) : i64
 //       CHECK: %[[T5:.*]] = llvm.extractelement %[[T0]][%[[T4]] : i64] : vector<1xf32>
 //       CHECK: %[[T6Insert:.*]] = llvm.insertelement %[[T5]]
 //       CHECK: %[[T6:.*]] = llvm.shufflevector %[[T6Insert]]
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index bf755b466c7eb..8bb6593d99058 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2658,7 +2658,7 @@ func.func @fold_extractelement_of_broadcast(%f: f32) -> f32 {
 
 // CHECK-LABEL: func.func @fold_0d_vector_reduction
 func.func @fold_0d_vector_reduction(%arg0: vector<f32>) -> f32 {
-  // CHECK-NEXT: %[[RES:.*]] = vector.extractelement %arg{{.*}}[] : vector<f32>
+  // CHECK-NEXT: %[[RES:.*]] = vector.extract %arg{{.*}}[] : f32 from vector<f32>
   // CHECK-NEXT: return %[[RES]] : f32
   %0 = vector.reduction <add>, %arg0 : vector<f32> into f32
   return %0 : f32
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
index b4ebb14b8829e..52b0fdee184f6 100644
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -45,9 +45,7 @@ func.func @tensor_transfer_read_0d(%t: tensor<?x?x?xf32>, %idx: index) -> f32 {
 
 // CHECK-LABEL: func @transfer_write_0d(
 //  CHECK-SAME:     %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
-//       CHECK:   %[[bc:.*]] = vector.broadcast %[[f]] : f32 to vector<f32>
-//       CHECK:   %[[extract:.*]] = vector.extractelement %[[bc]][] : vector<f32>
-//       CHECK:   memref.store %[[extract]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+//       CHECK:   memref.store %[[f]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
 func.func @transfer_write_0d(%m: memref<?x?x?xf32>, %idx: index, %f: f32) {
   %0 = vector.broadcast %f : f32 to vector<f32>
   vector.transfer_write %0, %m[%idx, %idx, %idx] : vector<f32>, memref<?x?x?xf32>
@@ -69,9 +67,7 @@ func.func @transfer_write_1d(%m: memref<?x?x?xf32>, %idx: index, %f: f32) {
 
 // CHECK-LABEL: func @tensor_transfer_write_0d(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
-//       CHECK:   %[[bc:.*]] = vector.broadcast %[[f]] : f32 to vector<f32>
-//       CHECK:   %[[extract:.*]] = vector.extractelement %[[bc]][] : vector<f32>
-//       CHECK:   %[[r:.*]] = tensor.insert %[[extract]] into %[[t]][%[[idx]], %[[idx]], %[[idx]]]
+//       CHECK:   %[[r:.*]] = tensor.insert %[[f]] into %[[t]][%[[idx]], %[[idx]], %[[idx]]]
 //       CHECK:   return %[[r]]
 func.func @tensor_transfer_write_0d(%t: tensor<?x?x?xf32>, %idx: index, %f: f32) -> tensor<?x?x?xf32> {
   %0 = vector.broadcast %f : f32 to vector<f32>
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index ab30acf68b30b..ef32f8c6a1cdb 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -117,7 +117,7 @@ func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
 // CHECK-LABEL:   func.func @shape_cast_0d1d(
 // CHECK-SAME:                               %[[ARG0:.*]]: vector<f32>) -> vector<1xf32> {
 // CHECK:           %[[UB:.*]] = ub.poison : vector<1xf32>
-// CHECK:           %[[EXTRACT0:.*]] = vector.extractelement %[[ARG0]][] : vector<f32>
+// CHECK:           %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][] : f32 from vector<f32>
 // CHECK:           %[[RES:.*]] = vector.insert %[[EXTRACT0]], %[[UB]] [0] : f32 into vector<1xf32>
 // CHECK:           return %[[RES]] : vector<1xf32>
 // CHECK:         }
@@ -131,7 +131,7 @@ func.func @shape_cast_0d1d(%arg0 : vector<f32>) -> vector<1xf32> {
 // CHECK-SAME:                               %[[ARG0:.*]]: vector<1xf32>) -> vector<f32> {
 // CHECK:           %[[UB:.*]] = ub.poison : vector<f32>
 // CHECK:           %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
-// CHECK:           %[[RES:.*]] = vector.insertelement %[[EXTRACT0]], %[[UB]][] : vector<f32>
+// CHECK:           %[[RES:.*]] = vector.insert %[[EXTRACT0]], %[[UB]] [] : f32 into vector<f32>
 // CHECK:           return %[[RES]] : vector<f32>
 // CHECK:         }
 

@Groverkss
Copy link
Member Author

While I think this change is straightforward, and does not handle any dynamic indices, so it should have no difference in performance (and probably an improvement), I'm happy to benchmark this for IREE if asked for.

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lovely to see progress on this, thank you @Groverkss ! 🙏🏻

LGTM % request for comments :)

@Groverkss
Copy link
Member Author

No regressions in IREE: iree-org/iree#20268

@Groverkss Groverkss force-pushed the remove-extractelement-insertelement branch from 58a09e2 to c85a2bd Compare March 24, 2025 12:53
@Groverkss Groverkss merged commit dc28e0d into llvm:main Mar 24, 2025
8 of 10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants