Skip to content

Commit a1ef5a9

Browse files
[mlir][bufferization] Empty tensor elimination based on SubsetOpInterface (#65766)
This commit generalizes empty tensor elimination to operate on subset ops. No new test cases are added because all current subset ops were already supported previously. From this perspective, this change is NFC. A new interface method (and a helper method) are added to `SubsetInsertionOpInterface` to build the subset of the destination tensor.
1 parent 419f90e commit a1ef5a9

File tree

7 files changed

+180
-182
lines changed

7 files changed

+180
-182
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.td

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,42 @@ def SubsetInsertionOpInterface : OpInterface<"SubsetInsertionOpInterface"> {
9999
"::mlir::Value":$candidate,
100100
"::llvm::function_ref<bool(Value, Value)>":$equivalenceFn)
101101
>,
102+
InterfaceMethod<
103+
/*desc=*/[{
104+
Return the subset of the destination tensor that this operation
105+
inserts into.
106+
107+
Example:
108+
```
109+
// SubsetInsertionOpInterface op:
110+
%0 = tensor.insert_slice %t0 into %t1[%pos][5][1]
111+
: tensor<5xf32> into tensor<?xf32>
112+
// Subset (built by this function):
113+
%1 = tensor.extract_slice %t1[%pos][5][1]
114+
: tensor<?xf32> to tensor<5xf32>
115+
```
116+
117+
Note: Implementations do not necessarily have to build new IR. They
118+
may return existing SSA values.
119+
}],
120+
/*retType=*/"::mlir::Value",
121+
/*methodName=*/"buildSubsetExtraction",
122+
/*args=*/(ins "::mlir::OpBuilder &":$builder, "Location":$loc)
123+
>,
124+
InterfaceMethod<
125+
/*desc=*/[{
126+
Return all SSA values that are needed (i.e., must be in scope) at the
127+
insertion point of the builder when calling `buildSubsetExtraction`. Users
128+
of `buildSubsetExtraction` can use this helper method to find a
129+
suitable insertion point.
130+
131+
Example: The SSA values needed to build the subset in the example of
132+
`buildSubsetExtraction` are %t1 and %pos.
133+
}],
134+
/*retType=*/"::llvm::SmallVector<::mlir::Value>",
135+
/*methodName=*/"getValuesNeededToBuildSubsetExtraction",
136+
/*args=*/(ins)
137+
>,
102138
];
103139

104140
let extraClassDeclaration = [{

mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -109,19 +109,17 @@ def EliminateEmptyTensorsOp
109109
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
110110
let description = [{
111111
Try to eliminate all `tensor.empty` ops within the targeted op by replacing
112-
them with a destination tensor.
112+
them with another destination tensor.
113113

114-
`tensor.empty` ops cannot be bufferizes. They can either be converted to
115-
`bufferization.alloc_tensor` or replaced with another tensor (via this
116-
transform). `tensor.empty` does not specify the contents of the returned
114+
"tensor.empty" ops cannot be bufferized. They can either be converted to
115+
"bufferization.alloc_tensor" or replaced with another tensor (via this
116+
transform). "tensor.empty" does not specify the contents of the returned
117117
tensor so their results can be replaced with arbitrary tensor values as long
118118
as the dimensions match.
119119

120-
This transform looks for `tensor.empty` ops where the SSA use-def chain of
121-
the result ends in a supported "anchor op" (always following the aliasing
122-
OpOperand/OpResult chain). Currently supported anchor ops are:
123-
- `tensor.insert_slice`
124-
- `bufferization.yield` (inside `bufferization.alloc_tensor`)
120+
This transformation looks for subset ops that insert a tensor that
121+
originates from a "tensor.empty" (as per the reverse use-def chain). Such
122+
"tensor.empty" ops are replaced with the destination subset.
125123

126124
Example:
127125

@@ -138,6 +136,10 @@ def EliminateEmptyTensorsOp
138136
%2 = tensor.insert_slice %1 into %t[1][5][1]
139137
```
140138

139+
In the above example, the subset op is "tensor.insert_slice". When tracing
140+
back the reverse use-def chain of the source, we end up at a
141+
"tensor.empty" op.
142+
141143
The above example can bufferize without an allocation (in the absence of
142144
other conflicts) because there is no longer a `tensor.empty` op.
143145

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -402,11 +402,22 @@ def PromoteBuffersToStack : Pass<"promote-buffers-to-stack", "func::FuncOp"> {
402402
def EmptyTensorElimination : Pass<"eliminate-empty-tensors"> {
403403
let summary = "Try to eliminate all tensor.empty ops.";
404404
let description = [{
405-
This pass tries to eliminate all insert_slice op-anchored tensor.empty ops.
406-
I.e., when a value that is equivalent to an tensor.empty op is inserted into
407-
another tensor, this pass tries to rewrite the IR in such a way that the
408-
destination tensor of the insert_slice op is used directly instead of the
409-
tensor.empty result.
405+
Try to eliminate "tensor.empty" ops inside `op`. This transformation looks
406+
for subset ops that insert a tensor that originates from a "tensor.empty"
407+
(as per the reverse use-def chain). Such "tensor.empty" ops are replaced
408+
with the destination subset.
409+
410+
E.g.:
411+
```
412+
%0 = tensor.empty() : tensor<10xf32>
413+
%1 = linalg.fill ... outs(%0 : tensor<10xf32>)
414+
%2 = tensor.insert_slice %0 into %t ...
415+
```
416+
417+
In the above example, the subset op is "tensor.insert_slice". When tracing
418+
back the reverse use-def chain of the source, we end up at a
419+
"tensor.empty" op. The "tensor.empty" op is replaced with a
420+
"tensor.extract_slice" op.
410421
}];
411422
let constructor = "mlir::bufferization::createEmptyTensorEliminationPass()";
412423
}

mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,38 +19,26 @@ struct BufferizationStatistics;
1919
class OneShotAnalysisState;
2020
struct OneShotBufferizationOptions;
2121

22-
/// A function that matches anchor OpOperands for tensor::EmptyOp elimination.
23-
/// If an OpOperand is matched, the function should populate the SmallVector
24-
/// with all values that are needed during `RewriteFn` to produce the
25-
/// replacement value.
26-
using AnchorMatchFn = std::function<bool(OpOperand &, SmallVector<Value> &)>;
27-
28-
/// A function that rewrites matched anchors.
29-
using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;
30-
31-
/// Try to eliminate tensor::EmptyOps inside `op`.
22+
/// Try to eliminate "tensor.empty" ops inside `op`. This transformation looks
23+
/// for subset ops that insert a tensor that originates from a "tensor.empty"
24+
/// (as per the reverse use-def chain). Such "tensor.empty" ops are replaced
25+
/// with the destination subset.
3226
///
33-
/// * `rewriteFunc` generates the replacement for the tensor::EmptyOp.
34-
/// * Only tensor::EmptyOps that are anchored on a matching OpOperand as per
35-
/// `anchorMatchFunc` are considered. "Anchored" means that there is a path
36-
/// on the reverse SSA use-def chain, starting from the OpOperand and always
37-
/// following the aliasing OpOperand, that eventually ends at a single
38-
/// tensor::EmptyOp.
27+
/// E.g.:
28+
/// %0 = tensor.empty() : tensor<10xf32>
29+
/// %1 = linalg.fill ... outs(%0 : tensor<10xf32>)
30+
/// %2 = tensor.insert_slice %1 into %t ...
31+
///
32+
/// In the above example, the subset op is "tensor.insert_slice". When tracing
33+
/// back the reverse use-def chain of the source, we end up at a
34+
/// "tensor.empty" op.
3935
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op,
40-
OneShotAnalysisState &state,
41-
AnchorMatchFn anchorMatchFunc,
42-
RewriteFn rewriteFunc);
36+
OneShotAnalysisState &state);
4337

4438
/// Within the given operation, hoist buffers from loops where possible. See
4539
/// "BufferLoopHoistingPass" for more information.
4640
void hoistBuffersFromLoops(Operation *op);
4741

48-
/// Try to eliminate tensor::EmptyOps inside `op` that are anchored on an
49-
/// InsertSliceOp, i.e., if it is eventually inserted into another tensor
50-
/// (and some other conditions are met).
51-
LogicalResult insertSliceAnchoredEmptyTensorEliminationStep(
52-
RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state);
53-
5442
/// Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
5543
/// After applying this transform, the IR can be bufferized without inserting
5644
/// additional buffer allocations.

mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,7 @@ DiagnosedSilenceableFailure transform::EliminateEmptyTensorsOp::apply(
121121
if (failed(analyzeOp(target, state)))
122122
return mlir::emitSilenceableFailure(target->getLoc())
123123
<< "failed to analyze op";
124-
if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
125-
rewriter, target, state)))
124+
if (failed(bufferization::eliminateEmptyTensors(rewriter, target, state)))
126125
return mlir::emitSilenceableFailure(target->getLoc())
127126
<< "failed to eliminate insert_slice anchored tensor.empty ops";
128127
}

mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp

Lines changed: 54 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1212
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13+
#include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h"
1314
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
1415
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
1516
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -99,154 +100,67 @@ findValidInsertionPoint(Operation *emptyTensorOp,
99100
return nullptr;
100101
}
101102

102-
/// Try to eliminate tensor::EmptyOps inside `op`. A tensor::EmptyOp is replaced
103-
/// with the result of `rewriteFunc` if it is anchored on a matching
104-
/// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
105-
/// chain, starting from the OpOperand and always following the aliasing
106-
/// OpOperand, that eventually ends at the tensor::EmptyOp.
107-
///
108-
/// E.g.:
109-
/// %0 = tensor.empty() : tensor<10xf32>
110-
/// %1 = linalg.fill ... outs(%0 : tensor<10xf32>)
111-
/// %2 = tensor.insert_slice %0 into %t ...
112-
///
113-
/// In the above example, the anchor is the source operand of the insert_slice
114-
/// op. When tracing back the reverse use-def chain, we end up at a
115-
/// tensor.empty op.
116103
LogicalResult mlir::bufferization::eliminateEmptyTensors(
117-
RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
118-
AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc) {
104+
RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
119105
OpBuilder::InsertionGuard g(rewriter);
120106

121-
op->walk([&](Operation *op) {
122-
for (OpOperand &operand : op->getOpOperands()) {
123-
// Skip operands that do not bufferize inplace.
124-
if (!state.isInPlace(operand))
125-
continue;
126-
// All values that are needed to create the replacement op.
127-
SmallVector<Value> neededValues;
128-
// Is this an anchor?
129-
if (!anchorMatchFunc(operand, neededValues))
107+
op->walk([&](SubsetInsertionOpInterface op) {
108+
OpOperand &source = op.getSourceOperand();
109+
// Skip operands that do not bufferize inplace. "tensor.empty" could still
110+
// be replaced, but the transformation may not be beneficial.
111+
if (!state.isInPlace(source))
112+
return WalkResult::skip();
113+
// All values that are needed to create the replacement op.
114+
SmallVector<Value> neededValues =
115+
op.getValuesNeededToBuildSubsetExtraction();
116+
117+
// Find tensor.empty ops on the reverse SSA use-def chain. Only follow
118+
// equivalent tensors. I.e., stop when there are ops such as extract_slice
119+
// on the path.
120+
TraversalConfig config;
121+
config.followEquivalentOnly = true;
122+
config.alwaysIncludeLeaves = false;
123+
// Replace only if the types match or are static <-> dynamic casts. We do
124+
// not support slices or reshapes.
125+
// TODO: This could be extended to support IR such as:
126+
// %0 = tensor.empty() : tensor<128xf32>
127+
// %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
128+
// %2 = tensor.expand_shape %1 ...
129+
// %3 = tensor.insert_slice %2 into ...
130+
config.followSameTypeOrCastsOnly = true;
131+
SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
132+
source.get(), /*condition=*/
133+
[&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
134+
config);
135+
136+
for (Value v : emptyTensors) {
137+
Operation *emptyTensorOp = v.getDefiningOp();
138+
139+
// Find a suitable insertion point. If no suitable insertion point for
140+
// the replacement can be found, skip this replacement.
141+
Operation *insertionPoint =
142+
findValidInsertionPoint(emptyTensorOp, neededValues);
143+
if (!insertionPoint)
130144
continue;
131145

132-
// Find tensor.empty ops on the reverse SSA use-def chain. Only follow
133-
// equivalent tensors. I.e., stop when there are ops such as extract_slice
134-
// on the path.
135-
TraversalConfig config;
136-
config.followEquivalentOnly = true;
137-
config.alwaysIncludeLeaves = false;
138-
// Replace only if the types match or are static <-> dynamic casts. We do
139-
// not support slices or reshapes.
140-
// TODO: This could be extended to support IR such as:
141-
// %0 = tensor.empty() : tensor<128xf32>
142-
// %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
143-
// %2 = tensor.expand_shape %1 ...
144-
// %3 = tensor.insert_slice %2 into ...
145-
config.followSameTypeOrCastsOnly = true;
146-
SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
147-
operand.get(), /*condition=*/
148-
[&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
149-
config);
150-
151-
for (Value v : emptyTensors) {
152-
Operation *emptyTensorOp = v.getDefiningOp();
153-
154-
// Find a suitable insertion point. If no suitable insertion point for
155-
// the replacement can be found, skip this replacement.
156-
Operation *insertionPoint =
157-
findValidInsertionPoint(emptyTensorOp, neededValues);
158-
if (!insertionPoint)
159-
continue;
160-
161-
rewriter.setInsertionPoint(insertionPoint);
162-
Value replacement =
163-
rewriteFunc(rewriter, emptyTensorOp->getLoc(), operand);
164-
if (!replacement)
165-
continue;
166-
if (replacement.getType() != v.getType()) {
167-
rewriter.setInsertionPointAfterValue(replacement);
168-
replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
169-
replacement);
170-
}
171-
// Replace the tensor::EmptyOp.
172-
rewriter.replaceOp(emptyTensorOp, replacement);
173-
state.resetCache();
146+
rewriter.setInsertionPoint(insertionPoint);
147+
Value replacement =
148+
op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
149+
if (!replacement)
150+
continue;
151+
if (replacement.getType() != v.getType()) {
152+
rewriter.setInsertionPointAfterValue(replacement);
153+
replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
154+
replacement);
174155
}
156+
// Replace the tensor::EmptyOp.
157+
rewriter.replaceOp(emptyTensorOp, replacement);
158+
state.resetCache();
175159
}
176-
});
177-
178-
return success();
179-
}
180-
181-
/// Try to eliminate tensor::EmptyOps inside `op`. An tensor::EmptyOp can be
182-
/// eliminated if it is eventually inserted into another tensor (and some other
183-
/// conditions are met).
184-
///
185-
/// E.g.:
186-
/// %0 = tensor.empty()
187-
/// %1 = linalg.fill(%cst, %0) {inplace = [true]}
188-
/// %2 = tensor.insert_slice %1 into %t[10][20][1]
189-
///
190-
/// tensor::EmptyOp elimination will try to fill %t inplace instead of filling a
191-
/// new allocation %0 and inserting it into %t. This is done by replacing the
192-
/// tensor::EmptyOp with:
193-
///
194-
/// %0 = tensor.extract_slice %t[10][20][1]
195-
///
196-
/// The analysis looks for matching ExtractSliceOp/InsertSliceOp pairs and lets
197-
/// those bufferize inplace in the absence of other conflicts.
198-
///
199-
/// Starting from an InsertSliceOp, an tensor::EmptyOp at the end of the insert
200-
/// source's reverse use-def chain is eliminated if:
201-
/// * On the reverse use-def chain path from the InsertSliceOp to the
202-
/// tensor::EmptyOp, all ops were decided to bufferize inplace and the buffer
203-
/// relation is "equivalent" (TODO: can be relaxed if needed).
204-
/// * The reverse use-def chain has exactly one end, which is the
205-
/// tensor::EmptyOp.
206-
template <typename OpTy>
207-
static LogicalResult insertSliceLikeAnchoredEmptyTensorEliminationStep(
208-
RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
209-
return eliminateEmptyTensors(
210-
rewriter, op, state,
211-
/*anchorMatchFunc=*/
212-
[&](OpOperand &operand, SmallVector<Value> &neededValues) {
213-
auto insertSliceOp = dyn_cast<OpTy>(operand.getOwner());
214-
if (!insertSliceOp)
215-
return false;
216-
if (&operand != &insertSliceOp->getOpOperand(0) /*source*/)
217-
return false;
218160

219-
// Collect all values that are needed to construct the replacement op.
220-
neededValues.append(insertSliceOp.getOffsets().begin(),
221-
insertSliceOp.getOffsets().end());
222-
neededValues.append(insertSliceOp.getSizes().begin(),
223-
insertSliceOp.getSizes().end());
224-
neededValues.append(insertSliceOp.getStrides().begin(),
225-
insertSliceOp.getStrides().end());
226-
neededValues.push_back(insertSliceOp.getDest());
227-
228-
return true;
229-
},
230-
/*rewriteFunc=*/
231-
[](OpBuilder &b, Location loc, OpOperand &operand) {
232-
auto insertOp = cast<OpTy>(operand.getOwner());
233-
auto extractOp = b.create<tensor::ExtractSliceOp>(
234-
loc, insertOp.getSourceType(), insertOp.getDest(),
235-
insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
236-
insertOp.getMixedStrides());
237-
return extractOp.getResult();
238-
});
239-
}
161+
return WalkResult::advance();
162+
});
240163

241-
LogicalResult
242-
mlir::bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
243-
RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
244-
if (failed(insertSliceLikeAnchoredEmptyTensorEliminationStep<
245-
tensor::InsertSliceOp>(rewriter, op, state)))
246-
return failure();
247-
if (failed(insertSliceLikeAnchoredEmptyTensorEliminationStep<
248-
tensor::ParallelInsertSliceOp>(rewriter, op, state)))
249-
return failure();
250164
return success();
251165
}
252166

@@ -276,8 +190,7 @@ void EmptyTensorElimination::runOnOperation() {
276190
}
277191

278192
IRRewriter rewriter(op->getContext());
279-
if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
280-
rewriter, op, state)))
193+
if (failed(bufferization::eliminateEmptyTensors(rewriter, op, state)))
281194
signalPassFailure();
282195
}
283196

0 commit comments

Comments
 (0)