[mlir][vector] Support complete folding in single pass for vector.insert/vector.extract #142124
…ert/vector.extract
After successfully converting dynamic indices to static indices, continue folding instead of returning early, allowing subsequent fold operations to be executed.
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir
Author: Yang Bai (yangtetris)
Full diff: https://github.com/llvm/llvm-project/pull/142124.diff
1 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 890a5e9e5c9b4..2e0c917b2139d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2062,6 +2062,7 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
   if (opChange) {
     op.setStaticPosition(staticPosition);
     op.getOperation()->setOperands(operands);
+    // Return the original result to indicate an in-place folding happened.
     return op.getResult();
   }
   return {};
@@ -2148,8 +2149,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
   // Fold `arith.constant` indices into the `vector.extract` operation. Make
   // sure that patterns requiring constant indices are added after this fold.
   SmallVector<Value> operands = {getVector()};
-  if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
-    return val;
+  auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
+
   if (auto res = foldPoisonIndexInsertExtractOp(
           getContext(), adaptor.getStaticPosition(), kPoisonIndex))
     return res;
@@ -2171,7 +2172,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
     return val;
   if (auto val = foldScalarExtractFromFromElements(*this))
     return val;
-  return OpFoldResult();
+
+  return inplaceFolded;
 }

 namespace {
@@ -3150,8 +3152,8 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
   // Fold `arith.constant` indices into the `vector.insert` operation. Make
   // sure that patterns requiring constant indices are added after this fold.
   SmallVector<Value> operands = {getValueToStore(), getDest()};
-  if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
-    return val;
+  auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
+
   if (auto res = foldPoisonIndexInsertExtractOp(
           getContext(), adaptor.getStaticPosition(), kPoisonIndex))
     return res;
@@ -3161,7 +3163,7 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
     return res;
   }

-  return {};
+  return inplaceFolded;
 }

 //===----------------------------------------------------------------------===//
The existing |
Can you please provide a test for this?
Thanks for fixing this problem! For testing, would it make sense to add a pass similar to
I just added two tests in
I happened to see an existing pass that implements what you said. Do you think it is a good choice to use TestConstantFold for folding tests?
I was not aware of that pass. It looks pretty focused on folding constants only. I'm also surprised that we don't have a pass to test the op folder in isolation. @joker-eph, @River707, @jpienaar, do you know?
I don't see a focus on constants in this pass: as far as I can tell, it only has some special handling to clean up constants after it's done. Otherwise it is a single traversal applying the folder.
Great, LGTM!
Very nice! LGTM.
I've confirmed that this does as expected by running
mlir-opt input.mlir
mlir-opt input.mlir -test-constant-fold
mlir-opt input.mlir -test-constant-fold -test-constant-fold
before and after.
The 'constant' in the pass and test file names is indeed a bit confusing, but probably not worth the churn of changing at this point.
Description
This patch improves the folding efficiency of `vector.insert` and `vector.extract` operations by not returning early after successfully converting dynamic indices to static indices.

Motivation
Since the `OpBuilder::createOrFold` function only calls `fold` once, the current `fold` methods of `vector.insert` and `vector.extract` may leave the op in a state that can be folded further. For example, consider the following un-folded IR:

If we use `createOrFold` to create the `vector.extract` op, then the result will be:

But this is not the optimal result. `createOrFold` should have returned `%e1`.

The reason is that the execution of `fold` returns immediately after `extractInsertFoldConstantOp`, causing subsequent folding logic to be skipped.
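
The control-flow difference described above can be illustrated with a small standalone model. This is only a sketch under stated assumptions: `ToyExtract`, `foldConstantIndex`, `foldExtractFromInsert`, `foldEarlyReturn`, `foldComplete`, and `makeExample` are hypothetical names, not MLIR APIs, and the value names `%v1` and `%e2` are made up for illustration (only `%e1` comes from the description). An empty string stands in for "nothing folded".

```cpp
// Toy model only: these types and functions are hypothetical, NOT the MLIR API.
#include <cassert>
#include <optional>
#include <string>

// Stand-in for a vector.extract that reads one position of a vector value.
struct ToyExtract {
  std::string source;               // vector being read (hypothetical name)
  std::optional<int> staticIndex;   // set once the index is static
  std::optional<int> constantIndex; // constant hiding behind a dynamic index
  std::optional<int> insertedAt;    // position written by the defining insert
  std::string insertedValue;        // scalar written by the defining insert
  std::string result;               // this op's own result (hypothetical name)
};

// Step 1: turn a dynamic-but-constant index into a static one, in place.
// Returns true when the op was modified.
bool foldConstantIndex(ToyExtract &op) {
  if (op.staticIndex || !op.constantIndex)
    return false;
  op.staticIndex = op.constantIndex;
  op.constantIndex.reset();
  return true;
}

// Step 2: extract(insert(v, e, i), i) -> e. Only fires on a static index.
std::optional<std::string> foldExtractFromInsert(const ToyExtract &op) {
  if (op.staticIndex && op.insertedAt && *op.staticIndex == *op.insertedAt)
    return op.insertedValue;
  return std::nullopt;
}

// Behavior before the patch: return right after the in-place index update.
std::string foldEarlyReturn(ToyExtract &op) {
  if (foldConstantIndex(op))
    return op.result; // in-place fold reported, later folds never run
  if (auto v = foldExtractFromInsert(op))
    return *v;
  return ""; // nothing folded
}

// Behavior after the patch: remember the in-place change, keep folding.
std::string foldComplete(ToyExtract &op) {
  bool changed = foldConstantIndex(op);
  if (auto v = foldExtractFromInsert(op))
    return *v;
  return changed ? op.result : "";
}

// An extract at constant index 0 from a vector whose position 0 holds %e1.
ToyExtract makeExample() {
  return ToyExtract{"%v1", std::nullopt, 0, 0, "%e1", "%e2"};
}
```

In this model a single call to `foldEarlyReturn` only makes the index static and hands back the op's own result, so a caller that invokes the folder once never reaches the second step. `foldComplete` reaches `%e1` in one call, which mirrors the intent of returning `inplaceFolded` at the end of `fold` rather than immediately.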