Skip to content

[mlir][bufferization]-Add ControlBuildSubsetExtractionFn to TensorEmptyElimination #120851

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_TRANSFORMS_H

#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/SubsetOpInterface.h"

namespace mlir {
namespace bufferization {
Expand All @@ -34,13 +36,35 @@ struct OneShotBufferizationOptions;
/// "tensor.empty" op.
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op);

/// A function type that defines a callback to control the construction
/// of the subset extraction of the `SubsetInsertionOpInterface`.
/// The subset extraction value can be used as a replacement for the
/// `emptyTensorOp` value which is being consumed by `user`, failing
/// of building such a value should be indicated with an empty value.
/// This function should guarantee the legality of the replacement,
/// i.e. the replacement should dominate the user of the `emptyTensorOp`
/// being eliminated.
using ControlBuildSubsetExtractionFn =
std::function<Value(RewriterBase &, SubsetInsertionOpInterface,
tensor::EmptyOp emptyTensorOp, Operation *user)>;

/// This method builds and returns a subset extraction value for the
/// destination tensor that the given `op` inserts into.
/// It returns a value which should replace the `emptyTensorOp` use
/// that is being consumed by `user`.
/// If no such a value found it will return an empty Value.
Value buildSubsetExtraction(RewriterBase &rewriter,
SubsetInsertionOpInterface op,
tensor::EmptyOp emptyTensorOp, Operation *user);

/// Try to eliminate "tensor.empty" ops inside `op`.
///
/// This function overload accepts an existing `OneShotAnalysisState`, which
/// contains in-place bufferization decisions. This overload is useful if an
/// existing analysis should be reused for empty tensor elimination.
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op,
OneShotAnalysisState &state);
LogicalResult eliminateEmptyTensors(
RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
ControlBuildSubsetExtractionFn subsetsExtractionFn = buildSubsetExtraction);

/// Within the given operation, hoist buffers from loops where possible. See
/// "BufferLoopHoistingPass" for more information.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,31 @@ findValidInsertionPoint(Operation *emptyTensorOp, Operation *user,
return nullptr;
}

Value mlir::bufferization::buildSubsetExtraction(RewriterBase &rewriter,
SubsetInsertionOpInterface op,
tensor::EmptyOp emptyTensorOp,
Operation *user) {

mlir::OpBuilder::InsertionGuard guard(rewriter);
// All values that are needed to create the replacement op.
SmallVector<Value> neededValues = op.getValuesNeededToBuildSubsetExtraction();
// Find a suitable insertion point. If no suitable insertion point
// for the replacement can be found, return an empty value to skip
// this replacement.
Operation *insertionPoint =
findValidInsertionPoint(emptyTensorOp, user, neededValues);
if (!insertionPoint)
return {};

rewriter.setInsertionPoint(insertionPoint);
Value replacement =
op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
return replacement;
}

LogicalResult mlir::bufferization::eliminateEmptyTensors(
RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
ControlBuildSubsetExtractionFn subsetsExtractionFn) {
OpBuilder::InsertionGuard g(rewriter);
llvm::DenseSet<OpOperand *> visitedOpOperands;
op->walk([&](SubsetInsertionOpInterface op) {
Expand All @@ -105,10 +128,6 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
if (!state.isInPlace(source))
return WalkResult::skip();

// All values that are needed to create the replacement op.
SmallVector<Value> neededValues =
op.getValuesNeededToBuildSubsetExtraction();

// Find tensor.empty ops on the reverse SSA use-def chain. Only follow
// equivalent tensors. I.e., stop when there are ops such as extract_slice
// on the path.
Expand All @@ -129,8 +148,8 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
&visitedOpOperands);

for (Value v : emptyTensors) {
Operation *emptyTensorOp = v.getDefiningOp();

auto emptyTensorOp = v.getDefiningOp<tensor::EmptyOp>();
assert(emptyTensorOp && "expected tensor.empty op");
// Find the use to be replaced from the use-def chain.
auto iter = llvm::find_if(
visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) {
Expand All @@ -142,17 +161,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
continue;
OpOperand *useToBeReplaced = *iter;
Operation *user = useToBeReplaced->getOwner();

// Find a suitable insertion point. If no suitable insertion point for
// the replacement can be found, skip this replacement.
Operation *insertionPoint =
findValidInsertionPoint(emptyTensorOp, user, neededValues);
if (!insertionPoint)
continue;

rewriter.setInsertionPoint(insertionPoint);
Value replacement =
op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user);
if (!replacement)
continue;
if (emptyTensorOp == replacement.getDefiningOp())
Expand Down
Loading