-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[mlir][bufferization]-Add ControlBuildSubsetExtractionFn to TensorEmptyElimination #120851
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][bufferization]-Add ControlBuildSubsetExtractionFn to TensorEmptyElimination #120851
Conversation
@llvm/pr-subscribers-mlir-bufferization Author: Amir Bishara (amirBish) ChangesThis PR Adds a This control function returns the subsets extraction value that will replace the The default control function will stay like today's behavior without any additional changes. Full diff: https://github.com/llvm/llvm-project/pull/120851.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
index 892675954493b9..bd9242e2caccb4 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
@@ -10,7 +10,9 @@
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_TRANSFORMS_H
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/SubsetOpInterface.h"
namespace mlir {
namespace bufferization {
@@ -34,13 +36,39 @@ struct OneShotBufferizationOptions;
/// "tensor.empty" op.
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op);
+/// Find a valid insertion point for a replacement of `emptyTensorOp`'s
+/// use of `user` operation, assuming that the replacement may use any
+/// value from `neededValues`.
+Operation *findValidInsertionPoint(Operation *emptyTensorOp, Operation *user,
+ const SmallVector<Value> &neededValues);
+
+/// A function type that defines a callBack to control the build of the
+/// subsets extraction of the `SubsetInsertionOpInterface`.
+/// The subsets extraction value will replace the `emptyTensorOp` value
+/// which is being consumed by `user`, failing of building such a value
+/// should be indicated with an empty value.
+/// This function should guarantee the legality of the replacement.
+using ControlBuildSubsetExtractionFn =
+ std::function<Value(RewriterBase &, SubsetInsertionOpInterface,
+ tensor::EmptyOp emptyTensorOp, Operation *user)>;
+
+/// This method Builds and returns a subset extraction value for the
+/// destination tensor that the given `op` inserts into.
+/// It returns a value which should replace the `emptyTensorOp` use
+/// that is being consumed by `user`, If no such a value found it
+/// will return an empty Value.
+Value buildSubsetExtraction(RewriterBase &rewriter,
+ SubsetInsertionOpInterface op,
+ tensor::EmptyOp emptyTensorOp, Operation *user);
+
/// Try to eliminate "tensor.empty" ops inside `op`.
///
/// This function overload accepts an existing `OneShotAnalysisState`, which
/// contains in-place bufferization decisions. This overload is useful if an
/// existing analysis should be reused for empty tensor elimination.
-LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op,
- OneShotAnalysisState &state);
+LogicalResult eliminateEmptyTensors(
+ RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
+ ControlBuildSubsetExtractionFn subsetsExtractionFn = buildSubsetExtraction);
/// Within the given operation, hoist buffers from loops where possible. See
/// "BufferLoopHoistingPass" for more information.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index abc0635a2cdff0..dabb44edd32783 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -51,9 +51,9 @@ neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
/// Find a valid insertion point for a replacement of `emptyTensorOp`'s
/// use of `user` operation, assuming that the replacement may use any
/// value from `neededValues`.
-static Operation *
-findValidInsertionPoint(Operation *emptyTensorOp, Operation *user,
- const SmallVector<Value> &neededValues) {
+Operation *mlir::bufferization::findValidInsertionPoint(
+ Operation *emptyTensorOp, Operation *user,
+ const SmallVector<Value> &neededValues) {
DominanceInfo domInfo;
Operation *candidateInsertionPoint = emptyTensorOp;
@@ -93,8 +93,31 @@ findValidInsertionPoint(Operation *emptyTensorOp, Operation *user,
return nullptr;
}
+Value mlir::bufferization::buildSubsetExtraction(RewriterBase &rewriter,
+ SubsetInsertionOpInterface op,
+ tensor::EmptyOp emptyTensorOp,
+ Operation *user) {
+
+ mlir::OpBuilder::InsertionGuard guard(rewriter);
+ // All values that are needed to create the replacement op.
+ SmallVector<Value> neededValues = op.getValuesNeededToBuildSubsetExtraction();
+ // Find a suitable insertion point. If no suitable insertion point
+ // for the replacement can be found, return an empty value to skip
+ // this replacement.
+ Operation *insertionPoint =
+ findValidInsertionPoint(emptyTensorOp, user, neededValues);
+ if (!insertionPoint)
+ return {};
+
+ rewriter.setInsertionPoint(insertionPoint);
+ Value replacement =
+ op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
+ return replacement;
+}
+
LogicalResult mlir::bufferization::eliminateEmptyTensors(
- RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
+ RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
+ ControlBuildSubsetExtractionFn subsetsExtractionFn) {
OpBuilder::InsertionGuard g(rewriter);
llvm::DenseSet<OpOperand *> visitedOpOperands;
op->walk([&](SubsetInsertionOpInterface op) {
@@ -105,10 +128,6 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
if (!state.isInPlace(source))
return WalkResult::skip();
- // All values that are needed to create the replacement op.
- SmallVector<Value> neededValues =
- op.getValuesNeededToBuildSubsetExtraction();
-
// Find tensor.empty ops on the reverse SSA use-def chain. Only follow
// equivalent tensors. I.e., stop when there are ops such as extract_slice
// on the path.
@@ -129,7 +148,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
&visitedOpOperands);
for (Value v : emptyTensors) {
- Operation *emptyTensorOp = v.getDefiningOp();
+ auto emptyTensorOp = v.getDefiningOp<tensor::EmptyOp>();
// Find the use to be replaced from the use-def chain.
auto iter = llvm::find_if(
@@ -142,17 +161,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
continue;
OpOperand *useToBeReplaced = *iter;
Operation *user = useToBeReplaced->getOwner();
-
- // Find a suitable insertion point. If no suitable insertion point for
- // the replacement can be found, skip this replacement.
- Operation *insertionPoint =
- findValidInsertionPoint(emptyTensorOp, user, neededValues);
- if (!insertionPoint)
- continue;
-
- rewriter.setInsertionPoint(insertionPoint);
- Value replacement =
- op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
+ auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user);
if (!replacement)
continue;
if (emptyTensorOp == replacement.getDefiningOp())
|
@llvm/pr-subscribers-mlir Author: Amir Bishara (amirBish) ChangesThis PR Adds a This control function returns the subsets extraction value that will replace the The default control function will stay like today's behavior without any additional changes. Full diff: https://github.com/llvm/llvm-project/pull/120851.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
index 892675954493b9..bd9242e2caccb4 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
@@ -10,7 +10,9 @@
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_TRANSFORMS_H
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/SubsetOpInterface.h"
namespace mlir {
namespace bufferization {
@@ -34,13 +36,39 @@ struct OneShotBufferizationOptions;
/// "tensor.empty" op.
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op);
+/// Find a valid insertion point for a replacement of `emptyTensorOp`'s
+/// use of `user` operation, assuming that the replacement may use any
+/// value from `neededValues`.
+Operation *findValidInsertionPoint(Operation *emptyTensorOp, Operation *user,
+ const SmallVector<Value> &neededValues);
+
+/// A function type that defines a callBack to control the build of the
+/// subsets extraction of the `SubsetInsertionOpInterface`.
+/// The subsets extraction value will replace the `emptyTensorOp` value
+/// which is being consumed by `user`, failing of building such a value
+/// should be indicated with an empty value.
+/// This function should guarantee the legality of the replacement.
+using ControlBuildSubsetExtractionFn =
+ std::function<Value(RewriterBase &, SubsetInsertionOpInterface,
+ tensor::EmptyOp emptyTensorOp, Operation *user)>;
+
+/// This method Builds and returns a subset extraction value for the
+/// destination tensor that the given `op` inserts into.
+/// It returns a value which should replace the `emptyTensorOp` use
+/// that is being consumed by `user`, If no such a value found it
+/// will return an empty Value.
+Value buildSubsetExtraction(RewriterBase &rewriter,
+ SubsetInsertionOpInterface op,
+ tensor::EmptyOp emptyTensorOp, Operation *user);
+
/// Try to eliminate "tensor.empty" ops inside `op`.
///
/// This function overload accepts an existing `OneShotAnalysisState`, which
/// contains in-place bufferization decisions. This overload is useful if an
/// existing analysis should be reused for empty tensor elimination.
-LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op,
- OneShotAnalysisState &state);
+LogicalResult eliminateEmptyTensors(
+ RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
+ ControlBuildSubsetExtractionFn subsetsExtractionFn = buildSubsetExtraction);
/// Within the given operation, hoist buffers from loops where possible. See
/// "BufferLoopHoistingPass" for more information.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index abc0635a2cdff0..dabb44edd32783 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -51,9 +51,9 @@ neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
/// Find a valid insertion point for a replacement of `emptyTensorOp`'s
/// use of `user` operation, assuming that the replacement may use any
/// value from `neededValues`.
-static Operation *
-findValidInsertionPoint(Operation *emptyTensorOp, Operation *user,
- const SmallVector<Value> &neededValues) {
+Operation *mlir::bufferization::findValidInsertionPoint(
+ Operation *emptyTensorOp, Operation *user,
+ const SmallVector<Value> &neededValues) {
DominanceInfo domInfo;
Operation *candidateInsertionPoint = emptyTensorOp;
@@ -93,8 +93,31 @@ findValidInsertionPoint(Operation *emptyTensorOp, Operation *user,
return nullptr;
}
+Value mlir::bufferization::buildSubsetExtraction(RewriterBase &rewriter,
+ SubsetInsertionOpInterface op,
+ tensor::EmptyOp emptyTensorOp,
+ Operation *user) {
+
+ mlir::OpBuilder::InsertionGuard guard(rewriter);
+ // All values that are needed to create the replacement op.
+ SmallVector<Value> neededValues = op.getValuesNeededToBuildSubsetExtraction();
+ // Find a suitable insertion point. If no suitable insertion point
+ // for the replacement can be found, return an empty value to skip
+ // this replacement.
+ Operation *insertionPoint =
+ findValidInsertionPoint(emptyTensorOp, user, neededValues);
+ if (!insertionPoint)
+ return {};
+
+ rewriter.setInsertionPoint(insertionPoint);
+ Value replacement =
+ op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
+ return replacement;
+}
+
LogicalResult mlir::bufferization::eliminateEmptyTensors(
- RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
+ RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
+ ControlBuildSubsetExtractionFn subsetsExtractionFn) {
OpBuilder::InsertionGuard g(rewriter);
llvm::DenseSet<OpOperand *> visitedOpOperands;
op->walk([&](SubsetInsertionOpInterface op) {
@@ -105,10 +128,6 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
if (!state.isInPlace(source))
return WalkResult::skip();
- // All values that are needed to create the replacement op.
- SmallVector<Value> neededValues =
- op.getValuesNeededToBuildSubsetExtraction();
-
// Find tensor.empty ops on the reverse SSA use-def chain. Only follow
// equivalent tensors. I.e., stop when there are ops such as extract_slice
// on the path.
@@ -129,7 +148,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
&visitedOpOperands);
for (Value v : emptyTensors) {
- Operation *emptyTensorOp = v.getDefiningOp();
+ auto emptyTensorOp = v.getDefiningOp<tensor::EmptyOp>();
// Find the use to be replaced from the use-def chain.
auto iter = llvm::find_if(
@@ -142,17 +161,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
continue;
OpOperand *useToBeReplaced = *iter;
Operation *user = useToBeReplaced->getOwner();
-
- // Find a suitable insertion point. If no suitable insertion point for
- // the replacement can be found, skip this replacement.
- Operation *insertionPoint =
- findValidInsertionPoint(emptyTensorOp, user, neededValues);
- if (!insertionPoint)
- continue;
-
- rewriter.setInsertionPoint(insertionPoint);
- Value replacement =
- op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
+ auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user);
if (!replacement)
continue;
if (emptyTensorOp == replacement.getDefiningOp())
|
@matthias-springer This is the PR which adds the control function, as a follow up for this PR #118958 . Would appreciate your review :) |
@matthias-springer ping, can you have a look please |
mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
Outdated
Show resolved
Hide resolved
mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
Outdated
Show resolved
Hide resolved
mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
Outdated
Show resolved
Hide resolved
mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
Outdated
Show resolved
Hide resolved
mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
Outdated
Show resolved
Hide resolved
mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
Outdated
Show resolved
Hide resolved
mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
Outdated
Show resolved
Hide resolved
mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
Outdated
Show resolved
Hide resolved
…tyElimination util This PR Adds a `ControlBuildSubsetExtractionFn` to the tensor empty elimination util, This will control the building of the subsets extraction of the `SubsetInsertionOpInterface`. This control function returns the subsets extraction value that will replace the `emptyTensorOp` use which is being consumed by a specefic user (which the util expects to eliminate it). The default control function will stay like today's behavior without any additional changes.
b2c97a0
to
c94756c
Compare
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/198/builds/635 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/94/builds/3256 Here is the relevant piece of the build log for the reference
|
This PR Adds a
ControlBuildSubsetExtractionFn
to the tensor empty elimination util, This will control the building of the subsets extraction of theSubsetInsertionOpInterface
.This control function returns the subsets extraction value that will replace the
emptyTensorOp
usewhich is being consumed by a specefic user (which the
util expects to eliminate it).
The default control function will stay like today's behavior without any additional changes.