[mlir][Vector] Remove more special case uses for extractelement/insertelement #130166
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir

Author: Kunwar Grover (Groverkss)

Changes

A number of places in our codebase special-case the use of extractelement/insertelement for 0-D vectors, because extract/insert did not support 0-D vectors previously. Since insert/extract now support 0-D vectors, use them instead of special-casing.

Full diff: https://github.com/llvm/llvm-project/pull/130166.diff

9 Files Affected:
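For context before the file-by-file diff, here is a minimal C++ sketch (not part of this patch) of the two new position-free convenience builders; the helper name and its parameters are hypothetical, but the create<...> calls mirror the ones used throughout the changes below:

// Sketch only: exercises the new convenience builders that synthesize an
// all-zero position list (empty for a rank-0 vector). The helper name and
// its arguments are hypothetical; builder, loc, src, and dst are assumed
// to be provided by the surrounding pattern or pass.
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

static Value roundTrip0D(OpBuilder &builder, Location loc, Value src,
                         Value dst) {
  // Produces, e.g.: %e = vector.extract %src[] : f32 from vector<f32>
  Value scalar = builder.create<vector::ExtractOp>(loc, src);
  // Produces, e.g.: %r = vector.insert %e, %dst [] : f32 into vector<f32>
  return builder.create<vector::InsertOp>(loc, scalar, dst);
}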
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index fbbf817ecff98..2f5436f353539 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -718,6 +718,7 @@ def Vector_ExtractOp :
let results = (outs AnyType:$result);
let builders = [
+ OpBuilder<(ins "Value":$source)>,
OpBuilder<(ins "Value":$source, "int64_t":$position)>,
OpBuilder<(ins "Value":$source, "OpFoldResult":$position)>,
OpBuilder<(ins "Value":$source, "ArrayRef<int64_t>":$position)>,
@@ -913,6 +914,7 @@ def Vector_InsertOp :
let results = (outs AnyVectorOfAnyRank:$result);
let builders = [
+ OpBuilder<(ins "Value":$source, "Value":$dest)>,
OpBuilder<(ins "Value":$source, "Value":$dest, "int64_t":$position)>,
OpBuilder<(ins "Value":$source, "Value":$dest, "OpFoldResult":$position)>,
OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<int64_t>":$position)>,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8e0e723cf4ed3..860778fc9db38 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -560,11 +560,9 @@ struct ElideUnitDimsInMultiDimReduction
} else {
// This means we are reducing all the dimensions, and all reduction
// dimensions are of size 1. So a simple extraction would do.
- SmallVector<int64_t> zeroIdx(shape.size(), 0);
if (mask)
- mask = rewriter.create<vector::ExtractOp>(loc, mask, zeroIdx);
- cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource(),
- zeroIdx);
+ mask = rewriter.create<vector::ExtractOp>(loc, mask);
+ cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource());
}
Value result =
@@ -698,16 +696,9 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
return failure();
Location loc = reductionOp.getLoc();
- Value result;
- if (vectorType.getRank() == 0) {
- if (mask)
- mask = rewriter.create<ExtractElementOp>(loc, mask);
- result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
- } else {
- if (mask)
- mask = rewriter.create<ExtractOp>(loc, mask, 0);
- result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(), 0);
- }
+ if (mask)
+ mask = rewriter.create<ExtractOp>(loc, mask);
+ Value result = rewriter.create<ExtractOp>(loc, reductionOp.getVector());
if (Value acc = reductionOp.getAcc())
result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
@@ -1294,6 +1285,12 @@ void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), argRanges.front());
}
+void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
+ Value source) {
+ auto vectorTy = cast<VectorType>(source.getType());
+ build(builder, result, source, SmallVector<int64_t>(vectorTy.getRank(), 0));
+}
+
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
Value source, int64_t position) {
build(builder, result, source, ArrayRef<int64_t>{position});
@@ -2916,6 +2913,13 @@ void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
}
+void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
+ Value source, Value dest) {
+ auto vectorTy = cast<VectorType>(dest.getType());
+ build(builder, result, source, dest,
+ SmallVector<int64_t>(vectorTy.getRank(), 0));
+}
+
void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
Value source, Value dest, int64_t position) {
build(builder, result, source, dest, ArrayRef<int64_t>{position});
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index fec3c6c52e5e4..11dcfe421e0c4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -52,11 +52,7 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
// Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
if (srcRank <= 1 && dstRank == 1) {
- Value ext;
- if (srcRank == 0)
- ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
- else
- ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
+ Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource());
rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 9c1e5fcee91de..23324a007377e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -189,25 +189,9 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
incIdx(resIdx, resultVectorType);
}
- Value extract;
- if (srcRank == 0) {
- // 0-D vector special case
- assert(srcIdx.empty() && "Unexpected indices for 0-D vector");
- extract = rewriter.create<vector::ExtractElementOp>(
- loc, op.getSourceVectorType().getElementType(), op.getSource());
- } else {
- extract =
- rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
- }
-
- if (resRank == 0) {
- // 0-D vector special case
- assert(resIdx.empty() && "Unexpected indices for 0-D vector");
- result = rewriter.create<vector::InsertElementOp>(loc, extract, result);
- } else {
- result =
- rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
- }
+ Value extract =
+ rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
+ result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
}
rewriter.replaceOp(op, result);
return success();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 2413a4126f3f7..074c2d5664f64 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -920,17 +920,8 @@ class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
if (!xferOp.getPermutationMap().isMinorIdentity())
return failure();
// Only float and integer element types are supported.
- Value scalar;
- if (vecType.getRank() == 0) {
- // vector.extract does not support vector<f32> etc., so use
- // vector.extractelement instead.
- scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
- xferOp.getVector());
- } else {
- SmallVector<int64_t> pos(vecType.getRank(), 0);
- scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
- xferOp.getVector(), pos);
- }
+ Value scalar =
+ rewriter.create<vector::ExtractOp>(xferOp.getLoc(), xferOp.getVector());
// Construct a scalar store.
if (isa<MemRefType>(xferOp.getSource().getType())) {
rewriter.replaceOpWithNewOp<memref::StoreOp>(
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 1ab28b9df2d19..52a2224b963f2 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -187,7 +187,7 @@ func.func @broadcast_vec2d_from_vec0d(%arg0: vector<f32>) -> vector<3x2xf32> {
// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
// CHECK: %[[T1:.*]] = ub.poison : vector<3x2xf32>
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
-// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[T5:.*]] = llvm.extractelement %[[T0]][%[[T4]] : i64] : vector<1xf32>
// CHECK: %[[T6Insert:.*]] = llvm.insertelement %[[T5]]
// CHECK: %[[T6:.*]] = llvm.shufflevector %[[T6Insert]]
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index bf755b466c7eb..8bb6593d99058 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2658,7 +2658,7 @@ func.func @fold_extractelement_of_broadcast(%f: f32) -> f32 {
// CHECK-LABEL: func.func @fold_0d_vector_reduction
func.func @fold_0d_vector_reduction(%arg0: vector<f32>) -> f32 {
- // CHECK-NEXT: %[[RES:.*]] = vector.extractelement %arg{{.*}}[] : vector<f32>
+ // CHECK-NEXT: %[[RES:.*]] = vector.extract %arg{{.*}}[] : f32 from vector<f32>
// CHECK-NEXT: return %[[RES]] : f32
%0 = vector.reduction <add>, %arg0 : vector<f32> into f32
return %0 : f32
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
index b4ebb14b8829e..52b0fdee184f6 100644
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -45,9 +45,7 @@ func.func @tensor_transfer_read_0d(%t: tensor<?x?x?xf32>, %idx: index) -> f32 {
// CHECK-LABEL: func @transfer_write_0d(
// CHECK-SAME: %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
-// CHECK: %[[bc:.*]] = vector.broadcast %[[f]] : f32 to vector<f32>
-// CHECK: %[[extract:.*]] = vector.extractelement %[[bc]][] : vector<f32>
-// CHECK: memref.store %[[extract]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+// CHECK: memref.store %[[f]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
func.func @transfer_write_0d(%m: memref<?x?x?xf32>, %idx: index, %f: f32) {
%0 = vector.broadcast %f : f32 to vector<f32>
vector.transfer_write %0, %m[%idx, %idx, %idx] : vector<f32>, memref<?x?x?xf32>
@@ -69,9 +67,7 @@ func.func @transfer_write_1d(%m: memref<?x?x?xf32>, %idx: index, %f: f32) {
// CHECK-LABEL: func @tensor_transfer_write_0d(
// CHECK-SAME: %[[t:.*]]: tensor<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
-// CHECK: %[[bc:.*]] = vector.broadcast %[[f]] : f32 to vector<f32>
-// CHECK: %[[extract:.*]] = vector.extractelement %[[bc]][] : vector<f32>
-// CHECK: %[[r:.*]] = tensor.insert %[[extract]] into %[[t]][%[[idx]], %[[idx]], %[[idx]]]
+// CHECK: %[[r:.*]] = tensor.insert %[[f]] into %[[t]][%[[idx]], %[[idx]], %[[idx]]]
// CHECK: return %[[r]]
func.func @tensor_transfer_write_0d(%t: tensor<?x?x?xf32>, %idx: index, %f: f32) -> tensor<?x?x?xf32> {
%0 = vector.broadcast %f : f32 to vector<f32>
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index ab30acf68b30b..ef32f8c6a1cdb 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -117,7 +117,7 @@ func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
// CHECK-LABEL: func.func @shape_cast_0d1d(
// CHECK-SAME: %[[ARG0:.*]]: vector<f32>) -> vector<1xf32> {
// CHECK: %[[UB:.*]] = ub.poison : vector<1xf32>
-// CHECK: %[[EXTRACT0:.*]] = vector.extractelement %[[ARG0]][] : vector<f32>
+// CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][] : f32 from vector<f32>
// CHECK: %[[RES:.*]] = vector.insert %[[EXTRACT0]], %[[UB]] [0] : f32 into vector<1xf32>
// CHECK: return %[[RES]] : vector<1xf32>
// CHECK: }
@@ -131,7 +131,7 @@ func.func @shape_cast_0d1d(%arg0 : vector<f32>) -> vector<1xf32> {
// CHECK-SAME: %[[ARG0:.*]]: vector<1xf32>) -> vector<f32> {
// CHECK: %[[UB:.*]] = ub.poison : vector<f32>
// CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
-// CHECK: %[[RES:.*]] = vector.insertelement %[[EXTRACT0]], %[[UB]][] : vector<f32>
+// CHECK: %[[RES:.*]] = vector.insert %[[EXTRACT0]], %[[UB]] [] : f32 into vector<f32>
// CHECK: return %[[RES]] : vector<f32>
// CHECK: }
I think this change is straightforward and does not involve any dynamic indices, so it should have no impact on performance (and is probably an improvement). That said, I'm happy to benchmark this for IREE if asked.
LGTM, thanks!
Lovely to see progress on this, thank you @Groverkss ! 🙏🏻
LGTM % request for comments :)
No regressions in IREE: iree-org/iree#20268
Force-pushed from 58a09e2 to c85a2bd