[mlir][bufferization] Empty tensor elimination based on SubsetOpInterface #65766

Merged
@@ -99,6 +99,42 @@ def SubsetInsertionOpInterface : OpInterface<"SubsetInsertionOpInterface"> {
"::mlir::Value":$candidate,
"::llvm::function_ref<bool(Value, Value)>":$equivalenceFn)
>,
InterfaceMethod<
/*desc=*/[{
Return the subset of the destination tensor that this operation
inserts into.

Example:
```
// SubsetOpInterface op:
%0 = tensor.insert_slice %t0 into %t1[%pos][5][1]
: tensor<5xf32> into tensor<?xf32>
// Subset (built by this function):
%1 = tensor.extract_slice %t1[%pos][5][1]
: tensor<?xf32> to tensor<5xf32>
```

Note: Implementations do not necessarily have to build new IR. They
may return existing SSA values.
}],
/*retType=*/"::mlir::Value",
/*methodName=*/"buildSubsetExtraction",
/*args=*/(ins "::mlir::OpBuilder &":$builder, "Location":$loc)
>,
InterfaceMethod<
/*desc=*/[{
Return all SSA values that are needed (i.e., must be in scope) at the
insertion of the builder when calling `buildSubsetExtraction`. Users
of `buildSubsetExtraction` can use this helper method to find a
suitable insertion point.

Example: The SSA values needed to build the subset in the example of
`buildSubsetExtraction` are %t1 and %pos.
}],
/*retType=*/"::llvm::SmallVector<::mlir::Value>",
/*methodName=*/"getValuesNeededToBuildSubsetExtraction",
/*args=*/(ins)
>,
];

let extraClassDeclaration = [{
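The two interface methods added above can be pictured with a small Python model of an insert_slice-style op. This is a hedged sketch only: the `Value` and `InsertSliceOp` classes and all names here are illustrative stand-ins, not actual MLIR APIs.

```python
# Toy model of buildSubsetExtraction / getValuesNeededToBuildSubsetExtraction,
# using an insert_slice-style op as the canonical implementation.
# Everything here is an illustrative stand-in, not a real MLIR API.
from dataclasses import dataclass

@dataclass(frozen=True)
class Value:
    name: str

@dataclass
class InsertSliceOp:
    source: Value    # %t0, the tensor being inserted
    dest: Value      # %t1, the destination tensor
    offsets: tuple   # e.g. (%pos,)
    sizes: tuple
    strides: tuple

    def get_values_needed_to_build_subset_extraction(self):
        # The matching extract reads the destination at the same
        # offsets/sizes/strides, so all of those values must be in scope.
        return [self.dest, *self.offsets, *self.sizes, *self.strides]

    def build_subset_extraction(self):
        # Mirror image of the insertion: extract the same subset of dest.
        return ("tensor.extract_slice", self.dest,
                self.offsets, self.sizes, self.strides)

op = InsertSliceOp(Value("%t0"), Value("%t1"),
                   offsets=(Value("%pos"),), sizes=(5,), strides=(1,))
needed = op.get_values_needed_to_build_subset_extraction()
print([v.name if isinstance(v, Value) else v for v in needed])
# -> ['%t1', '%pos', 5, 1]
```

As in the interface example, the values needed to build the extraction are the destination `%t1` and the offset `%pos`; the source `%t0` is not needed.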
@@ -109,19 +109,17 @@ def EliminateEmptyTensorsOp
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let description = [{
Try to eliminate all `tensor.empty` ops within the targeted op by replacing
them with a destination tensor.
them with another destination tensor.

`tensor.empty` ops cannot be bufferizes. They can either be converted to
`bufferization.alloc_tensor` or replaced with another tensor (via this
transform). `tensor.empty` does not specify the contents of the returned
"tensor.empty" ops cannot be bufferized. They can either be converted to
"bufferization.alloc_tensor" or replaced with another tensor (via this
transform). "tensor.empty" does not specify the contents of the returned
Comment on lines +114 to +116

Member: Just wondering why you are replacing back ticks with "?

Member Author: I think backticks are used when referring to function/variable names only.

Member: Interesting, I often do the exact opposite. Is there a written guideline somewhere?

tensor so their results can be replaced with arbitrary tensor values as long
as the dimensions match.

This transform looks for `tensor.empty` ops where the SSA use-def chain of
the result ends in a supported "anchor op" (always following the aliasing
OpOperand/OpResult chain). Currently supported anchor ops are:
- `tensor.insert_slice`
- `bufferization.yield` (inside `bufferization.alloc_tensor`)
This transformation looks for subset ops that insert a tensor that
originates from a "tensor.empty" (as per the reverse use-def chain). Such
"tensor.empty" ops are replaced with the destination subset.

Example:

@@ -138,6 +136,10 @@ def EliminateEmptyTensorsOp
%2 = tensor.insert_slice %1 into %t[1][5][1]
```

In the above example, the subset op is "tensor.insert_slice". When tracing
back the reverse use-def chain of the source, we end up at a
"tensor.empty" op.

The above example can bufferize without an allocation (in the absence of
other conflicts) because there is no longer a `tensor.empty` op.

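The reverse use-def walk described in this op description can be sketched in a few lines of Python. This is a toy model under stated assumptions: the `producers` map and all names are illustrative, and the real traversal (`findValueInReverseUseDefChain`) additionally enforces in-place and same-type constraints.

```python
# Hedged sketch of the reverse use-def chain walk: starting from the source
# of a subset insertion op, step to the equivalent (aliased) operand of each
# defining op until a tensor.empty producer is reached. `producers` maps a
# value name to (op_name, aliased_operand); this is a toy model, not MLIR.
def trace_to_empty(value, producers):
    seen = set()
    while value in producers and value not in seen:
        seen.add(value)
        op_name, aliased = producers[value]
        if op_name == "tensor.empty":
            return value          # found the candidate for replacement
        if aliased is None:       # op is not equivalence-preserving: give up
            return None
        value = aliased
    return None

# Mirrors the example: %0 = tensor.empty(); %1 = linalg.fill outs(%0);
# %2 = tensor.insert_slice %1 into %t -- tracing %1 reaches %0.
producers = {"%0": ("tensor.empty", None),
             "%1": ("linalg.fill", "%0")}
print(trace_to_empty("%1", producers))  # -> %0
```

If the chain passes through an op that is not equivalence-preserving (e.g. an extract_slice), the walk stops and no elimination happens.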
21 changes: 16 additions & 5 deletions mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -402,11 +402,22 @@ def PromoteBuffersToStack : Pass<"promote-buffers-to-stack", "func::FuncOp"> {
def EmptyTensorElimination : Pass<"eliminate-empty-tensors"> {
let summary = "Try to eliminate all tensor.empty ops.";
let description = [{
This pass tries to eliminate all insert_slice op-anchored tensor.empty ops.
I.e., when a value that is equivalent to an tensor.empty op is inserted into
another tensor, this pass tries to rewrite the IR in such a way that the
destination tensor of the insert_slice op is used directly instead of the
tensor.empty result.
Try to eliminate "tensor.empty" ops inside `op`. This transformation looks
for subset ops that insert a tensor that originates from a "tensor.empty"
(as per the reverse use-def chain). Such "tensor.empty" ops are replaced
with the destination subset.

E.g.:
```
%0 = tensor.empty() : tensor<10xf32>
%1 = linalg.fill ... outs(%0 : tensor<10xf32>)
%2 = tensor.insert_slice %0 into %t ...
```

In the above example, the subset op is "tensor.insert_slice". When tracing
back the reverse use-def chain of the source, we end up at a
"tensor.empty" op. The "tensor.empty" op is replaced with a
"tensor.extract_slice" op.
}];
let constructor = "mlir::bufferization::createEmptyTensorEliminationPass()";
}
38 changes: 13 additions & 25 deletions mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
@@ -19,38 +19,26 @@ struct BufferizationStatistics;
class OneShotAnalysisState;
struct OneShotBufferizationOptions;

/// A function that matches anchor OpOperands for tensor::EmptyOp elimination.
/// If an OpOperand is matched, the function should populate the SmallVector
/// with all values that are needed during `RewriteFn` to produce the
/// replacement value.
using AnchorMatchFn = std::function<bool(OpOperand &, SmallVector<Value> &)>;

/// A function that rewrites matched anchors.
using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;

/// Try to eliminate tensor::EmptyOps inside `op`.
/// Try to eliminate "tensor.empty" ops inside `op`. This transformation looks
/// for subset ops that insert a tensor that originates from a "tensor.empty"
/// (as per the reverse use-def chain). Such "tensor.empty" ops are replaced
/// with the destination subset.
///
/// * `rewriteFunc` generates the replacement for the tensor::EmptyOp.
/// * Only tensor::EmptyOps that are anchored on a matching OpOperand as per
/// `anchorMatchFunc` are considered. "Anchored" means that there is a path
/// on the reverse SSA use-def chain, starting from the OpOperand and always
/// following the aliasing OpOperand, that eventually ends at a single
/// tensor::EmptyOp.
/// E.g.:
/// %0 = tensor.empty() : tensor<10xf32>
/// %1 = linalg.fill ... outs(%0 : tensor<10xf32>)
/// %2 = tensor.insert_slice %0 into %t ...
///
/// In the above example, the subset op is "tensor.insert_slice". When tracing
/// back the reverse use-def chain of the source, we end up at a
/// "tensor.empty" op.
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op,
OneShotAnalysisState &state,
AnchorMatchFn anchorMatchFunc,
RewriteFn rewriteFunc);
OneShotAnalysisState &state);

/// Within the given operation, hoist buffers from loops where possible. See
/// "BufferLoopHoistingPass" for more information.
void hoistBuffersFromLoops(Operation *op);

/// Try to eliminate tensor::EmptyOps inside `op` that are anchored on an
/// InsertSliceOp, i.e., if it is eventually inserted into another tensor
/// (and some other conditions are met).
LogicalResult insertSliceAnchoredEmptyTensorEliminationStep(
RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state);

/// Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
/// After applying this transform, the IR can be bufferized without inserting
/// additional buffer allocations.
@@ -121,8 +121,7 @@ DiagnosedSilenceableFailure transform::EliminateEmptyTensorsOp::apply(
if (failed(analyzeOp(target, state)))
return mlir::emitSilenceableFailure(target->getLoc())
<< "failed to analyze op";
if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
rewriter, target, state)))
if (failed(bufferization::eliminateEmptyTensors(rewriter, target, state)))
return mlir::emitSilenceableFailure(target->getLoc())
<< "failed to eliminate insert_slice anchored tensor.empty ops";
}
195 changes: 54 additions & 141 deletions mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -10,6 +10,7 @@

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -99,154 +100,67 @@ findValidInsertionPoint(Operation *emptyTensorOp,
return nullptr;
}
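Under the simplifying assumption of straight-line IR with integer positions, the constraint that a `findValidInsertionPoint`-style helper must guarantee can be sketched as follows. This hypothetical helper is not the actual implementation, which works on dominance in general IR.

```python
# Simplified model of the insertion-point search: the replacement extract op
# must be built after every value it needs is defined, and no later than the
# tensor.empty op it replaces. Straight-line IR, integer positions only.
def find_valid_insertion_point(empty_pos, needed_def_positions):
    if not needed_def_positions:
        return empty_pos                       # no constraints: insert at the empty op
    candidate = max(needed_def_positions) + 1  # just after the last needed def
    return candidate if candidate <= empty_pos else None

# tensor.empty at position 5; %t1 defined at 0 and %pos at 2 -> insert at 3.
print(find_valid_insertion_point(5, [0, 2]))   # -> 3
# A needed value defined only after the empty op -> no valid point, skip.
print(find_valid_insertion_point(5, [7]))      # -> None
```

Returning `None` corresponds to the `if (!insertionPoint) continue;` path below: the candidate `tensor.empty` is simply left in place.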

/// Try to eliminate tensor::EmptyOps inside `op`. A tensor::EmptyOp is replaced
/// with the result of `rewriteFunc` if it is anchored on a matching
/// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
/// chain, starting from the OpOperand and always following the aliasing
/// OpOperand, that eventually ends at the tensor::EmptyOp.
///
/// E.g.:
/// %0 = tensor.empty() : tensor<10xf32>
/// %1 = linalg.fill ... outs(%0 : tensor<10xf32>)
/// %2 = tensor.insert_slice %0 into %t ...
///
/// In the above example, the anchor is the source operand of the insert_slice
/// op. When tracing back the reverse use-def chain, we end up at a
/// tensor.empty op.
LogicalResult mlir::bufferization::eliminateEmptyTensors(
RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc) {
RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
OpBuilder::InsertionGuard g(rewriter);

op->walk([&](Operation *op) {
for (OpOperand &operand : op->getOpOperands()) {
// Skip operands that do not bufferize inplace.
if (!state.isInPlace(operand))
continue;
// All values that are needed to create the replacement op.
SmallVector<Value> neededValues;
// Is this an anchor?
if (!anchorMatchFunc(operand, neededValues))
op->walk([&](SubsetInsertionOpInterface op) {
OpOperand &source = op.getSourceOperand();
// Skip operands that do not bufferize inplace. "tensor.empty" could still
// be replaced, but the transformation may not be beneficial.
if (!state.isInPlace(source))
return WalkResult::skip();
// All values that are needed to create the replacement op.
SmallVector<Value> neededValues =
op.getValuesNeededToBuildSubsetExtraction();

// Find tensor.empty ops on the reverse SSA use-def chain. Only follow
// equivalent tensors. I.e., stop when there are ops such as extract_slice
// on the path.
TraversalConfig config;
config.followEquivalentOnly = true;
config.alwaysIncludeLeaves = false;
// Replace only if the types match or are static <-> dynamic casts. We do
// not support slices or reshapes.
// TODO: This could be extended to support IR such as:
// %0 = tensor.empty() : tensor<128xf32>
// %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
// %2 = tensor.expand_shape %1 ...
// %3 = tensor.insert_slice %2 into ...
config.followSameTypeOrCastsOnly = true;
SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
source.get(), /*condition=*/
[&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
config);

for (Value v : emptyTensors) {
Operation *emptyTensorOp = v.getDefiningOp();

// Find a suitable insertion point. If no suitable insertion point for
// the replacement can be found, skip this replacement.
Operation *insertionPoint =
findValidInsertionPoint(emptyTensorOp, neededValues);
if (!insertionPoint)
continue;

// Find tensor.empty ops on the reverse SSA use-def chain. Only follow
// equivalent tensors. I.e., stop when there are ops such as extract_slice
// on the path.
TraversalConfig config;
config.followEquivalentOnly = true;
config.alwaysIncludeLeaves = false;
// Replace only if the types match or are static <-> dynamic casts. We do
// not support slices or reshapes.
// TODO: This could be extended to support IR such as:
// %0 = tensor.empty() : tensor<128xf32>
// %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
// %2 = tensor.expand_shape %1 ...
// %3 = tensor.insert_slice %2 into ...
config.followSameTypeOrCastsOnly = true;
SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
operand.get(), /*condition=*/
[&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
config);

for (Value v : emptyTensors) {
Operation *emptyTensorOp = v.getDefiningOp();

// Find a suitable insertion point. If no suitable insertion point for
// the replacement can be found, skip this replacement.
Operation *insertionPoint =
findValidInsertionPoint(emptyTensorOp, neededValues);
if (!insertionPoint)
continue;

rewriter.setInsertionPoint(insertionPoint);
Value replacement =
rewriteFunc(rewriter, emptyTensorOp->getLoc(), operand);
if (!replacement)
continue;
if (replacement.getType() != v.getType()) {
rewriter.setInsertionPointAfterValue(replacement);
replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
replacement);
}
// Replace the tensor::EmptyOp.
rewriter.replaceOp(emptyTensorOp, replacement);
state.resetCache();
rewriter.setInsertionPoint(insertionPoint);
Value replacement =
op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
if (!replacement)
continue;
if (replacement.getType() != v.getType()) {
rewriter.setInsertionPointAfterValue(replacement);
replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
replacement);
}
// Replace the tensor::EmptyOp.
rewriter.replaceOp(emptyTensorOp, replacement);
state.resetCache();
}
});

return success();
}
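The rewritten C++ loop above amounts to the following hedged Python sketch: walk subset insertion ops, trace each source back to a `tensor.empty`, check that the destination is available at the replacement point, and record the replacement extract. The data structures, integer positions, and names are toy illustrations, not MLIR APIs.

```python
# End-to-end toy model of the new eliminateEmptyTensors loop. Subset ops are
# (source, dest) dicts, producers maps a value to (op_name, aliased_operand),
# and def_pos gives a definition position per value. Nothing here is a real
# MLIR API; in-place analysis and casts are omitted for brevity.
def eliminate_empty_tensors(subset_ops, producers, def_pos):
    replacements = {}
    for op in subset_ops:
        value = op["source"]
        # Reverse use-def walk, following equivalent values only.
        while value in producers and producers[value][0] != "tensor.empty":
            value = producers[value][1]
        if value is None or producers.get(value, ("",))[0] != "tensor.empty":
            continue
        # The extraction reads the destination, so the destination must be
        # defined before the tensor.empty op that gets replaced.
        if def_pos[op["dest"]] < def_pos[value]:
            replacements[value] = ("tensor.extract_slice", op["dest"])
    return replacements

subset_ops = [{"source": "%1", "dest": "%t"}]
producers = {"%0": ("tensor.empty", None),
             "%1": ("linalg.fill", "%0")}
def_pos = {"%t": 0, "%0": 1, "%1": 2}
print(eliminate_empty_tensors(subset_ops, producers, def_pos))
# -> {'%0': ('tensor.extract_slice', '%t')}
```

The result says: replace the `tensor.empty` result `%0` with an extract of the destination `%t`, which is exactly the rewrite the pass performs via `buildSubsetExtraction`.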

/// Try to eliminate tensor::EmptyOps inside `op`. An tensor::EmptyOp can be
/// eliminated if it is eventually inserted into another tensor (and some other
/// conditions are met).
///
/// E.g.:
/// %0 = tensor.empty()
/// %1 = linalg.fill(%cst, %0) {inplace = [true]}
/// %2 = tensor.insert_slice %1 into %t[10][20][1]
///
/// tensor::EmptyOp elimination will try to fill %t inplace instead of filling a
/// new allocation %0 and inserting it into %t. This is done by replacing the
/// tensor::EmptyOp with:
///
/// %0 = tensor.extract_slice %t[10][20][1]
///
/// The analysis looks for matching ExtractSliceOp/InsertSliceOp pairs and lets
/// those bufferize inplace in the absence of other conflicts.
///
/// Starting from an InsertSliceOp, an tensor::EmptyOp at the end of the insert
/// source's reverse use-def chain is eliminated if:
/// * On the reverse use-def chain path from the InsertSliceOp to the
/// tensor::EmptyOp, all ops were decided to bufferize inplace and the buffer
/// relation is "equivalent" (TODO: can be relaxed if needed).
/// * The reverse use-def chain has exactly one end, which is the
/// tensor::EmptyOp.
template <typename OpTy>
static LogicalResult insertSliceLikeAnchoredEmptyTensorEliminationStep(
RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
return eliminateEmptyTensors(
rewriter, op, state,
/*anchorMatchFunc=*/
[&](OpOperand &operand, SmallVector<Value> &neededValues) {
auto insertSliceOp = dyn_cast<OpTy>(operand.getOwner());
if (!insertSliceOp)
return false;
if (&operand != &insertSliceOp->getOpOperand(0) /*source*/)
return false;

// Collect all values that are needed to construct the replacement op.
neededValues.append(insertSliceOp.getOffsets().begin(),
insertSliceOp.getOffsets().end());
neededValues.append(insertSliceOp.getSizes().begin(),
insertSliceOp.getSizes().end());
neededValues.append(insertSliceOp.getStrides().begin(),
insertSliceOp.getStrides().end());
neededValues.push_back(insertSliceOp.getDest());

return true;
},
/*rewriteFunc=*/
[](OpBuilder &b, Location loc, OpOperand &operand) {
auto insertOp = cast<OpTy>(operand.getOwner());
auto extractOp = b.create<tensor::ExtractSliceOp>(
loc, insertOp.getSourceType(), insertOp.getDest(),
insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
insertOp.getMixedStrides());
return extractOp.getResult();
});
}
return WalkResult::advance();
});

LogicalResult
mlir::bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
if (failed(insertSliceLikeAnchoredEmptyTensorEliminationStep<
tensor::InsertSliceOp>(rewriter, op, state)))
return failure();
if (failed(insertSliceLikeAnchoredEmptyTensorEliminationStep<
tensor::ParallelInsertSliceOp>(rewriter, op, state)))
return failure();
return success();
}

@@ -276,8 +190,7 @@ void EmptyTensorElimination::runOnOperation() {
}

IRRewriter rewriter(op->getContext());
if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
rewriter, op, state)))
if (failed(bufferization::eliminateEmptyTensors(rewriter, op, state)))
signalPassFailure();
}
