|
10 | 10 |
|
11 | 11 | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
12 | 12 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
| 13 | +#include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h" |
13 | 14 | #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
|
14 | 15 | #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
|
15 | 16 | #include "mlir/Dialect/Tensor/IR/Tensor.h"
|
@@ -99,154 +100,67 @@ findValidInsertionPoint(Operation *emptyTensorOp,
|
99 | 100 | return nullptr;
|
100 | 101 | }
|
101 | 102 |
|
102 |
| -/// Try to eliminate tensor::EmptyOps inside `op`. A tensor::EmptyOp is replaced |
103 |
| -/// with the result of `rewriteFunc` if it is anchored on a matching |
104 |
| -/// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def |
105 |
| -/// chain, starting from the OpOperand and always following the aliasing |
106 |
| -/// OpOperand, that eventually ends at the tensor::EmptyOp. |
107 |
| -/// |
108 |
| -/// E.g.: |
109 |
| -/// %0 = tensor.empty() : tensor<10xf32> |
110 |
| -/// %1 = linalg.fill ... outs(%0 : tensor<10xf32>) |
111 |
| -/// %2 = tensor.insert_slice %0 into %t ... |
112 |
| -/// |
113 |
| -/// In the above example, the anchor is the source operand of the insert_slice |
114 |
| -/// op. When tracing back the reverse use-def chain, we end up at a |
115 |
| -/// tensor.empty op. |
116 | 103 | LogicalResult mlir::bufferization::eliminateEmptyTensors(
|
117 |
| - RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state, |
118 |
| - AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc) { |
| 104 | + RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) { |
119 | 105 | OpBuilder::InsertionGuard g(rewriter);
|
120 | 106 |
|
121 |
| - op->walk([&](Operation *op) { |
122 |
| - for (OpOperand &operand : op->getOpOperands()) { |
123 |
| - // Skip operands that do not bufferize inplace. |
124 |
| - if (!state.isInPlace(operand)) |
125 |
| - continue; |
126 |
| - // All values that are needed to create the replacement op. |
127 |
| - SmallVector<Value> neededValues; |
128 |
| - // Is this an anchor? |
129 |
| - if (!anchorMatchFunc(operand, neededValues)) |
| 107 | + op->walk([&](SubsetInsertionOpInterface op) { |
| 108 | + OpOperand &source = op.getSourceOperand(); |
| 109 | + // Skip operands that do not bufferize inplace. "tensor.empty" could still |
| 110 | + // be replaced, but the transformation may not be beneficial. |
| 111 | + if (!state.isInPlace(source)) |
| 112 | + return WalkResult::skip(); |
| 113 | + // All values that are needed to create the replacement op. |
| 114 | + SmallVector<Value> neededValues = |
| 115 | + op.getValuesNeededToBuildSubsetExtraction(); |
| 116 | + |
| 117 | + // Find tensor.empty ops on the reverse SSA use-def chain. Only follow |
| 118 | + // equivalent tensors. I.e., stop when there are ops such as extract_slice |
| 119 | + // on the path. |
| 120 | + TraversalConfig config; |
| 121 | + config.followEquivalentOnly = true; |
| 122 | + config.alwaysIncludeLeaves = false; |
| 123 | + // Replace only if the types match or are static <-> dynamic casts. We do |
| 124 | + // not support slices or reshapes. |
| 125 | + // TODO: This could be extended to support IR such as: |
| 126 | + // %0 = tensor.empty() : tensor<128xf32> |
| 127 | + // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>) |
| 128 | + // %2 = tensor.expand_shape %1 ... |
| 129 | + // %3 = tensor.insert_slice %2 into ... |
| 130 | + config.followSameTypeOrCastsOnly = true; |
| 131 | + SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain( |
| 132 | + source.get(), /*condition=*/ |
| 133 | + [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, |
| 134 | + config); |
| 135 | + |
| 136 | + for (Value v : emptyTensors) { |
| 137 | + Operation *emptyTensorOp = v.getDefiningOp(); |
| 138 | + |
| 139 | + // Find a suitable insertion point. If no suitable insertion point for |
| 140 | + // the replacement can be found, skip this replacement. |
| 141 | + Operation *insertionPoint = |
| 142 | + findValidInsertionPoint(emptyTensorOp, neededValues); |
| 143 | + if (!insertionPoint) |
130 | 144 | continue;
|
131 | 145 |
|
132 |
| - // Find tensor.empty ops on the reverse SSA use-def chain. Only follow |
133 |
| - // equivalent tensors. I.e., stop when there are ops such as extract_slice |
134 |
| - // on the path. |
135 |
| - TraversalConfig config; |
136 |
| - config.followEquivalentOnly = true; |
137 |
| - config.alwaysIncludeLeaves = false; |
138 |
| - // Replace only if the types match or are static <-> dynamic casts. We do |
139 |
| - // not support slices or reshapes. |
140 |
| - // TODO: This could be extended to support IR such as: |
141 |
| - // %0 = tensor.empty() : tensor<128xf32> |
142 |
| - // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>) |
143 |
| - // %2 = tensor.expand_shape %1 ... |
144 |
| - // %3 = tensor.insert_slice %2 into ... |
145 |
| - config.followSameTypeOrCastsOnly = true; |
146 |
| - SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain( |
147 |
| - operand.get(), /*condition=*/ |
148 |
| - [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, |
149 |
| - config); |
150 |
| - |
151 |
| - for (Value v : emptyTensors) { |
152 |
| - Operation *emptyTensorOp = v.getDefiningOp(); |
153 |
| - |
154 |
| - // Find a suitable insertion point. If no suitable insertion point for |
155 |
| - // the replacement can be found, skip this replacement. |
156 |
| - Operation *insertionPoint = |
157 |
| - findValidInsertionPoint(emptyTensorOp, neededValues); |
158 |
| - if (!insertionPoint) |
159 |
| - continue; |
160 |
| - |
161 |
| - rewriter.setInsertionPoint(insertionPoint); |
162 |
| - Value replacement = |
163 |
| - rewriteFunc(rewriter, emptyTensorOp->getLoc(), operand); |
164 |
| - if (!replacement) |
165 |
| - continue; |
166 |
| - if (replacement.getType() != v.getType()) { |
167 |
| - rewriter.setInsertionPointAfterValue(replacement); |
168 |
| - replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(), |
169 |
| - replacement); |
170 |
| - } |
171 |
| - // Replace the tensor::EmptyOp. |
172 |
| - rewriter.replaceOp(emptyTensorOp, replacement); |
173 |
| - state.resetCache(); |
| 146 | + rewriter.setInsertionPoint(insertionPoint); |
| 147 | + Value replacement = |
| 148 | + op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc()); |
| 149 | + if (!replacement) |
| 150 | + continue; |
| 151 | + if (replacement.getType() != v.getType()) { |
| 152 | + rewriter.setInsertionPointAfterValue(replacement); |
| 153 | + replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(), |
| 154 | + replacement); |
174 | 155 | }
|
| 156 | + // Replace the tensor::EmptyOp. |
| 157 | + rewriter.replaceOp(emptyTensorOp, replacement); |
| 158 | + state.resetCache(); |
175 | 159 | }
|
176 |
| - }); |
177 |
| - |
178 |
| - return success(); |
179 |
| -} |
180 |
| - |
181 |
| -/// Try to eliminate tensor::EmptyOps inside `op`. An tensor::EmptyOp can be |
182 |
| -/// eliminated if it is eventually inserted into another tensor (and some other |
183 |
| -/// conditions are met). |
184 |
| -/// |
185 |
| -/// E.g.: |
186 |
| -/// %0 = tensor.empty() |
187 |
| -/// %1 = linalg.fill(%cst, %0) {inplace = [true]} |
188 |
| -/// %2 = tensor.insert_slice %1 into %t[10][20][1] |
189 |
| -/// |
190 |
| -/// tensor::EmptyOp elimination will try to fill %t inplace instead of filling a |
191 |
| -/// new allocation %0 and inserting it into %t. This is done by replacing the |
192 |
| -/// tensor::EmptyOp with: |
193 |
| -/// |
194 |
| -/// %0 = tensor.extract_slice %t[10][20][1] |
195 |
| -/// |
196 |
| -/// The analysis looks for matching ExtractSliceOp/InsertSliceOp pairs and lets |
197 |
| -/// those bufferize inplace in the absence of other conflicts. |
198 |
| -/// |
199 |
| -/// Starting from an InsertSliceOp, an tensor::EmptyOp at the end of the insert |
200 |
| -/// source's reverse use-def chain is eliminated if: |
201 |
| -/// * On the reverse use-def chain path from the InsertSliceOp to the |
202 |
| -/// tensor::EmptyOp, all ops were decided to bufferize inplace and the buffer |
203 |
| -/// relation is "equivalent" (TODO: can be relaxed if needed). |
204 |
| -/// * The reverse use-def chain has exactly one end, which is the |
205 |
| -/// tensor::EmptyOp. |
206 |
| -template <typename OpTy> |
207 |
| -static LogicalResult insertSliceLikeAnchoredEmptyTensorEliminationStep( |
208 |
| - RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) { |
209 |
| - return eliminateEmptyTensors( |
210 |
| - rewriter, op, state, |
211 |
| - /*anchorMatchFunc=*/ |
212 |
| - [&](OpOperand &operand, SmallVector<Value> &neededValues) { |
213 |
| - auto insertSliceOp = dyn_cast<OpTy>(operand.getOwner()); |
214 |
| - if (!insertSliceOp) |
215 |
| - return false; |
216 |
| - if (&operand != &insertSliceOp->getOpOperand(0) /*source*/) |
217 |
| - return false; |
218 | 160 |
|
219 |
| - // Collect all values that are needed to construct the replacement op. |
220 |
| - neededValues.append(insertSliceOp.getOffsets().begin(), |
221 |
| - insertSliceOp.getOffsets().end()); |
222 |
| - neededValues.append(insertSliceOp.getSizes().begin(), |
223 |
| - insertSliceOp.getSizes().end()); |
224 |
| - neededValues.append(insertSliceOp.getStrides().begin(), |
225 |
| - insertSliceOp.getStrides().end()); |
226 |
| - neededValues.push_back(insertSliceOp.getDest()); |
227 |
| - |
228 |
| - return true; |
229 |
| - }, |
230 |
| - /*rewriteFunc=*/ |
231 |
| - [](OpBuilder &b, Location loc, OpOperand &operand) { |
232 |
| - auto insertOp = cast<OpTy>(operand.getOwner()); |
233 |
| - auto extractOp = b.create<tensor::ExtractSliceOp>( |
234 |
| - loc, insertOp.getSourceType(), insertOp.getDest(), |
235 |
| - insertOp.getMixedOffsets(), insertOp.getMixedSizes(), |
236 |
| - insertOp.getMixedStrides()); |
237 |
| - return extractOp.getResult(); |
238 |
| - }); |
239 |
| -} |
| 161 | + return WalkResult::advance(); |
| 162 | + }); |
240 | 163 |
|
241 |
| -LogicalResult |
242 |
| -mlir::bufferization::insertSliceAnchoredEmptyTensorEliminationStep( |
243 |
| - RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) { |
244 |
| - if (failed(insertSliceLikeAnchoredEmptyTensorEliminationStep< |
245 |
| - tensor::InsertSliceOp>(rewriter, op, state))) |
246 |
| - return failure(); |
247 |
| - if (failed(insertSliceLikeAnchoredEmptyTensorEliminationStep< |
248 |
| - tensor::ParallelInsertSliceOp>(rewriter, op, state))) |
249 |
| - return failure(); |
250 | 164 | return success();
|
251 | 165 | }
|
252 | 166 |
|
@@ -276,8 +190,7 @@ void EmptyTensorElimination::runOnOperation() {
|
276 | 190 | }
|
277 | 191 |
|
278 | 192 | IRRewriter rewriter(op->getContext());
|
279 |
| - if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep( |
280 |
| - rewriter, op, state))) |
| 193 | + if (failed(bufferization::eliminateEmptyTensors(rewriter, op, state))) |
281 | 194 | signalPassFailure();
|
282 | 195 | }
|
283 | 196 |
|
|
0 commit comments