[MLIR][Vector]: Generalize conversion of vector.insert to LLVM in line with vector.extract #128915
Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
@llvm/pr-subscribers-mlir
Author: Benoit Jacob (bjacob)
Changes
This is doing the same as #117731 did for vector.extract, but for vector.insert. It is a bit more complicated as the insertion destination may itself need to be extracted. As the test shows, this fixes two previously unsupported cases:
Full diff: https://github.com/llvm/llvm-project/pull/128915.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index c9d637ce81f93..e1c7547774c3b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -35,13 +35,6 @@
using namespace mlir;
using namespace mlir::vector;
-// Helper to reduce vector type by *all* but one rank at back.
-static VectorType reducedVectorTypeBack(VectorType tp) {
- assert((tp.getRank() > 1) && "unlowerable vector type");
- return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
- tp.getScalableDims().take_back());
-}
-
// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &typeConverter, Location loc,
@@ -1223,7 +1216,6 @@ class VectorInsertOpConversion
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = insertOp->getLoc();
- auto sourceType = insertOp.getSourceType();
auto destVectorType = insertOp.getDestVectorType();
auto llvmResultType = typeConverter->convertType(destVectorType);
// Bail if result type cannot be lowered.
@@ -1233,53 +1225,74 @@ class VectorInsertOpConversion
SmallVector<OpFoldResult> positionVec = getMixedValues(
adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
- // Overwrite entire vector with value. Should be handled by folder, but
- // just to be safe.
- ArrayRef<OpFoldResult> position(positionVec);
- if (position.empty()) {
- rewriter.replaceOp(insertOp, adaptor.getSource());
- return success();
- }
-
- // One-shot insertion of a vector into an array (only requires insertvalue).
- if (isa<VectorType>(sourceType)) {
- if (insertOp.hasDynamicPosition())
- return failure();
-
- Value inserted = rewriter.create<LLVM::InsertValueOp>(
- loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
- rewriter.replaceOp(insertOp, inserted);
- return success();
+ // The logic in this pattern mirrors VectorExtractOpConversion. Refer to
+ // its explanatory comment about how N-D vectors are converted as nested
+ // aggregates (llvm.array's) of 1D vectors.
+ //
+ // There are 3 steps here, vs 2 in VectorExtractOpConversion:
+ // - Extraction of a 1D vector from the nested aggregate: llvm.extractvalue.
+ // - Insertion into the 1D vector: llvm.insertelement.
+ // - Insertion of the 1D vector into the nested aggregate: llvm.insertvalue.
+
+ // Determine if we need to extract/insert a 1D vector out of the aggregate.
+ bool is1DVectorWithinAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
+ // Determine if we need to insert a scalar into the 1D vector.
+ bool isScalarWithin1DVector =
+ static_cast<int64_t>(positionVec.size()) == destVectorType.getRank();
+
+ ArrayRef<OpFoldResult> positionOf1DVectorWithinAggregate(
+ positionVec.begin(),
+ isScalarWithin1DVector ? positionVec.size() - 1 : positionVec.size());
+ OpFoldResult positionOfScalarWithin1DVector;
+ if (destVectorType.getRank() == 0) {
+ // Since the LLVM type converter converts 0D vectors to 1D vectors, we
+ // need to create a 0 here as the position into the 1D vector.
+ Type idxType = typeConverter->convertType(rewriter.getIndexType());
+ positionOfScalarWithin1DVector = rewriter.getZeroAttr(idxType);
+ } else if (isScalarWithin1DVector) {
+ positionOfScalarWithin1DVector = positionVec.back();
}
- // Potential extraction of 1-D vector from array.
- Value extracted = adaptor.getDest();
- auto oneDVectorType = destVectorType;
- if (position.size() > 1) {
- if (insertOp.hasDynamicPosition())
- return failure();
-
- oneDVectorType = reducedVectorTypeBack(destVectorType);
- extracted = rewriter.create<LLVM::ExtractValueOp>(
- loc, extracted, getAsIntegers(position.drop_back()));
+ // We are going to mutate this 1D vector until it is either the final
+ // result (in the non-aggregate case) or the value that needs to be
+ // inserted into the aggregate result.
+ Value vector1d;
+ if (isScalarWithin1DVector) {
+ // Scalar-into-1D-vector case, so we know we will have to create a
+ // InsertElementOp. The question is into what destination.
+ if (is1DVectorWithinAggregate) {
+ // Aggregate case: the destination for the InsertElementOp needs to be
+ // extracted from the aggregate.
+ if (!llvm::all_of(positionOf1DVectorWithinAggregate,
+ llvm::IsaPred<Attribute>)) {
+ // llvm.extractvalue does not support dynamic dimensions.
+ return failure();
+ }
+ vector1d = rewriter.create<LLVM::ExtractValueOp>(
+ loc, adaptor.getDest(),
+ getAsIntegers(positionOf1DVectorWithinAggregate));
+ } else {
+ // No-aggregate case. The destination for the InsertElementOp is just
+ // the insertOp's destination.
+ vector1d = adaptor.getDest();
+ }
+ // Insert the scalar into the 1D vector.
+ vector1d = rewriter.create<LLVM::InsertElementOp>(
+ loc, vector1d.getType(), vector1d, adaptor.getSource(),
+ getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector));
+ } else {
+ // No scalar insertion. The 1D vector is just the source.
+ vector1d = adaptor.getSource();
}
- // Insertion of an element into a 1-D LLVM vector.
- Value inserted = rewriter.create<LLVM::InsertElementOp>(
- loc, typeConverter->convertType(oneDVectorType), extracted,
- adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back()));
-
- // Potential insertion of resulting 1-D vector into array.
- if (position.size() > 1) {
- if (insertOp.hasDynamicPosition())
- return failure();
-
- inserted = rewriter.create<LLVM::InsertValueOp>(
- loc, adaptor.getDest(), inserted,
- getAsIntegers(position.drop_back()));
+ Value result = vector1d;
+ if (is1DVectorWithinAggregate) {
+ result = rewriter.create<LLVM::InsertValueOp>(
+ loc, adaptor.getDest(), vector1d,
+ getAsIntegers(positionOf1DVectorWithinAggregate));
}
- rewriter.replaceOp(insertOp, inserted);
+ rewriter.replaceOp(insertOp, result);
return success();
}
};
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index fa7c030538401..7e60e62363ceb 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -628,6 +628,16 @@ func.func @insertelement_into_vec_1d_f32_scalable_idx_as_index_scalable(%arg0: f
// vector.insert
//===----------------------------------------------------------------------===//
+func.func @insert_scalar_into_vec_0d(%src: f32, %dst: vector<f32>) -> vector<f32> {
+ %0 = vector.insert %src, %dst[] : f32 into vector<f32>
+ return %0 : vector<f32>
+}
+
+// CHECK-LABEL: @insert_scalar_into_vec_0d
+// CHECK: llvm.insertelement {{.*}} : vector<1xf32>
+
+// -----
+
func.func @insert_scalar_into_vec_1d_f32(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
%0 = vector.insert %arg0, %arg1[3] : f32 into vector<4xf32>
return %0 : vector<4xf32>
@@ -780,10 +790,10 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %a
return %0 : vector<1x16xf32>
}
-// Multi-dim vectors are not supported but this test shouldn't crash.
-
// CHECK-LABEL: @insert_scalar_into_vec_2d_f32_dynamic_idx(
-// CHECK: vector.insert
+// CHECK: llvm.extractvalue {{.*}} : !llvm.array<1 x vector<16xf32>>
+// CHECK: llvm.insertelement {{.*}} : vector<16xf32>
+// CHECK: llvm.insertvalue {{.*}} : !llvm.array<1 x vector<16xf32>>
// -----
@@ -793,10 +803,10 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[1
return %0 : vector<1x[16]xf32>
}
-// Multi-dim vectors are not supported but this test shouldn't crash.
-
// CHECK-LABEL: @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(
-// CHECK: vector.insert
+// CHECK: llvm.extractvalue {{.*}} : !llvm.array<1 x vector<[16]xf32>>
+// CHECK: llvm.insertelement {{.*}} : vector<[16]xf32>
+// CHECK: llvm.insertvalue {{.*}} : !llvm.array<1 x vector<[16]xf32>>
// -----
+ // There are 3 steps here, vs 2 in VectorExtractOpConversion:
+ // - Extraction of a 1D vector from the nested aggregate: llvm.extractvalue.
+ // - Insertion into the 1D vector: llvm.insertelement.
+ // - Insertion of the 1D vector into the nested aggregate: llvm.insertvalue.
I think I would reword this to:
The innermost dimension of the destination vector, when converted to a nested aggregate form, will always be a 1D vector.
- If the insertion is happening into the innermost dimension of the destination vector:
  - If the destination is a nested aggregate, extract a 1D vector out of the aggregate. This can be done using llvm.extractvalue. The destination is now guaranteed to be a 1D vector, into which we are inserting.
  - Do the insertion on the 1D destination vector, and make the result the new source nested aggregate. This can be done using llvm.insertelement.
- Insert the source nested aggregate into the destination nested aggregate.
It makes it easier to understand the possible cases.
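To make the cases in this comment concrete, here is a minimal standalone C++ sketch of the three steps for a 2-D destination. It is a plain value-semantics model, not MLIR or LLVM API: `Vector1D`, `Aggregate2D`, `insertScalar` and `insertVector` are hypothetical names introduced only for illustration, and each step is annotated with the LLVM dialect op it is assumed to stand in for.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical stand-ins for the lowered types: a 1D vector, and an array
// of 1D vectors playing the role of the nested aggregate of a 2-D vector.
using Vector1D = std::vector<float>;
using Aggregate2D = std::vector<Vector1D>;

// Models inserting a scalar at [row, col]. Arguments are taken and returned
// by value, so the caller's destination is never mutated (as with SSA values).
Aggregate2D insertScalar(Aggregate2D dest, std::size_t row, std::size_t col,
                         float scalar) {
  assert(row < dest.size() && col < dest[row].size());
  // Step 1 (stands in for llvm.extractvalue): take the 1D vector out of
  // the aggregate.
  Vector1D vector1d = dest[row];
  // Step 2 (stands in for llvm.insertelement): insert the scalar into it.
  vector1d[col] = scalar;
  // Step 3 (stands in for llvm.insertvalue): put the updated 1D vector back.
  dest[row] = std::move(vector1d);
  return dest;
}

// Models inserting a whole 1D vector at [row]. No scalar insertion is
// needed, so only step 3 remains and the source is the 1D vector itself.
Aggregate2D insertVector(Aggregate2D dest, std::size_t row, Vector1D source) {
  assert(row < dest.size() && source.size() == dest[row].size());
  dest[row] = std::move(source);
  return dest;
}
```

For example, calling `insertScalar(d, 1, 2, 5.0f)` on a 2x3 aggregate `d` yields a new aggregate whose element at [1][2] is 5.0f, while `d` itself is unchanged.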
OK
Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
LGTM!
-// CHECK: vector.insert
+// CHECK: llvm.extractvalue {{.*}} : !llvm.array<1 x vector<[16]xf32>>
+// CHECK: llvm.insertelement {{.*}} : vector<[16]xf32>
+// CHECK: llvm.insertvalue {{.*}} : !llvm.array<1 x vector<[16]xf32>>
Can we have a negative test for when dynamic indices fail?
Added.
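As a rough illustration of which dynamic indices are expected to fail, the sketch below models how the position is split between the aggregate part and the 1D-vector part. It is a standalone model, not the pattern's actual code: `DynamicIndex`, `Position`, `SplitPosition` and `splitPosition` are hypothetical names, and rejecting every dynamic index in the aggregate part is a simplifying assumption based on the diff's remark that llvm.extractvalue does not support dynamic positions.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <variant>
#include <vector>

// Hypothetical model of one position entry: a compile-time constant, or a
// runtime value represented here only by a tag type.
struct DynamicIndex {};
using Position = std::variant<std::int64_t, DynamicIndex>;

struct SplitPosition {
  // Indices into the nested aggregate (extractvalue/insertvalue): static only.
  std::vector<std::int64_t> aggregatePosition;
  // Index into the 1D vector (insertelement), if a scalar is being inserted.
  std::optional<Position> elementPosition;
};

// Returns std::nullopt where this model assumes the lowering has to bail out.
std::optional<SplitPosition> splitPosition(const std::vector<Position> &position,
                                           std::size_t destRank) {
  assert(position.size() <= destRank);
  SplitPosition result;
  std::size_t aggregateCount = position.size();
  if (destRank == 0) {
    // 0-D vectors are converted to 1D vectors, so the scalar goes to index 0.
    result.elementPosition = Position{std::int64_t{0}};
  } else if (position.size() == destRank) {
    // Scalar insertion: the last index addresses the 1D vector and may be
    // dynamic.
    result.elementPosition = position.back();
    aggregateCount = position.size() - 1;
  }
  for (std::size_t i = 0; i < aggregateCount; ++i) {
    // A dynamic index into the aggregate cannot be expressed in this model.
    if (!std::holds_alternative<std::int64_t>(position[i]))
      return std::nullopt;
    result.aggregatePosition.push_back(std::get<std::int64_t>(position[i]));
  }
  return result;
}
```

Under these assumptions, a position like [0, dynamic] on a rank-2 destination splits successfully, since only the innermost index is dynamic, whereas [dynamic, 3] returns `std::nullopt`; the latter is the kind of case a negative test would cover.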
Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
This is doing the same as #117731 did for vector.extract, but for vector.insert.
It is a bit more complicated as the insertion destination may itself need to be extracted.
As the test shows, this fixes two previously unsupported cases: