
Commit f3a357b

refactor
Signed-off-by: James Newling <james.newling@gmail.com>
1 parent 93185ea commit f3a357b

3 files changed, +74 -71 lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 16 additions & 12 deletions
@@ -25,7 +25,6 @@
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallBitVector.h"
-#include "llvm/ADT/SmallSet.h"
 
 namespace mlir {
 namespace bufferization {
@@ -621,35 +620,40 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
 /// In the future, more general interfaces can be devised to encode similar
 /// shape evolutions and map between an op and its operands.
 SmallVector<OpFoldResult>
-computePaddedShape(RewriterBase &rewriter, TypedValue<RankedTensorType> v,
+computePaddedShape(OpBuilder &rewriter, TypedValue<RankedTensorType> v,
                    AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
                    const PadTilingInterfaceOptions &options);
 
 using PadSizeComputationFunction =
     std::function<FailureOr<SmallVector<OpFoldResult>>(
-        RewriterBase &, OpOperand &, ArrayRef<Range>,
+        OpBuilder &, OpOperand &, ArrayRef<Range>,
         const PadTilingInterfaceOptions &)>;
 
 /// Specific helper for Linalg ops.
 FailureOr<SmallVector<OpFoldResult>> computeIndexingMapOpInterfacePaddedShape(
-    RewriterBase &rewriter, OpOperand &operandToPad,
+    OpBuilder &rewriter, OpOperand &operandToPad,
     ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options);
 
-/// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`.
-///
+/// Pad the iterator dimensions of `toPad`.
 /// * "options.paddingSizes" indicates that each padding dimension should be
 ///   padded to the specified padding size.
 /// * "options.padToMultipleOf" indicates that the paddingSizes should be
 //    interpreted as the bounding box (dynamic) value to pad to.
 /// * Use "options.paddingValues" to set the padding value of the created
 //    tensor::PadOp.
-/// * The tensor::PadOp is returned on success.
 
-FailureOr<TilingInterface>
-rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
-                  const PadTilingInterfaceOptions &constOptions,
-                  SmallVector<tensor::PadOp> &padOps,
-                  const PadSizeComputationFunction &computePaddingSizeFun =
+struct PadTilingInterfaceResult {
+  /// Padded operands of `toPad`.
+  SmallVector<tensor::PadOp> padOps;
+  /// Slices of the padded op that have the same shapes as `toPad` results.
+  SmallVector<Value> replacements;
+  /// The cloned and padded version of `toPad`.
+  TilingInterface paddedOp;
+};
+FailureOr<PadTilingInterfaceResult>
+rewriteAsPaddedOp(OpBuilder &, TilingInterface toPad,
+                  PadTilingInterfaceOptions options,
+                  const PadSizeComputationFunction & =
                       &computeIndexingMapOpInterfacePaddedShape);
 
 namespace detail {
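Note on the API change above: `rewriteAsPaddedOp` now takes a plain `OpBuilder`, no longer erases `toPad` itself, and returns everything through `PadTilingInterfaceResult`. A minimal caller-side sketch (not part of this commit; `padAndReplace` is a hypothetical helper, and it assumes the caller holds a `RewriterBase` so it can perform the replacement the callee no longer does):

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir;

// Hypothetical helper: pad `op`, then splice the shape-restoring slices in.
static FailureOr<TilingInterface>
padAndReplace(RewriterBase &rewriter, TilingInterface op,
              linalg::PadTilingInterfaceOptions options) {
  FailureOr<linalg::PadTilingInterfaceResult> result =
      linalg::rewriteAsPaddedOp(rewriter, op, options);
  if (failed(result))
    return failure();
  // `replacements` holds one extract_slice per result of `op`, with matching
  // shapes, so replaceOp is type-safe. `result->padOps` is also available if
  // the caller wants to post-process the newly created tensor.pad ops.
  rewriter.replaceOp(op, result->replacements);
  return result->paddedOp;
}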

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 9 additions & 8 deletions
@@ -2457,26 +2457,27 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
     }
 
     // Set options.
-    TilingInterface paddedOp;
     PadTilingInterfaceOptions options;
     options.setPaddingValues(paddingValues)
         .setPaddingSizes(getMixedPaddingSizes())
         .setPadToMultipleOf(getPadToMultipleOf());
 
-    // Apply padding.
-    SmallVector<tensor::PadOp> newPadOps;
-    FailureOr<TilingInterface> maybePaddedOp = rewriteAsPaddedOp(
-        rewriter, cast<TilingInterface>(targetOp.getOperation()), options,
-        newPadOps);
-    if (failed(maybePaddedOp)) {
+    auto maybePadOps = rewriteAsPaddedOp(
+        rewriter, cast<TilingInterface>(targetOp.getOperation()), options);
+    if (failed(maybePadOps)) {
       auto diag = emitSilenceableError() << "failed to pad op";
       diag.attachNote(target->getLoc()) << "target op";
       return diag;
     }
 
+    const auto &[newPadOps, replacementValues, newPaddedOp] = *maybePadOps;
+
     // Set transform results.
-    paddedOps.push_back(cast<TilingInterface>(maybePaddedOp->getOperation()));
+    paddedOps.push_back(newPaddedOp);
     padOps.append(newPadOps.begin(), newPadOps.end());
+
+    // erase targetOp:
+    rewriter.replaceOp(targetOp.getOperation(), replacementValues);
   }
 
   results.set(cast<OpResult>(getPadded()), paddedOps);
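The transform op keeps using the default `computeIndexingMapOpInterfacePaddedShape`, but the trailing `PadSizeComputationFunction` parameter remains a customization point. A hypothetical callback matching the new `OpBuilder`-based signature (a sketch only; the pad-to-multiple-of-16 policy and the static-shape restriction are arbitrary choices for illustration):

static FailureOr<SmallVector<OpFoldResult>>
padToMultipleOf16(OpBuilder &builder, OpOperand &operandToPad,
                  ArrayRef<Range> iterationDomain,
                  const linalg::PadTilingInterfaceOptions &options) {
  auto tensorType = dyn_cast<RankedTensorType>(operandToPad.get().getType());
  if (!tensorType)
    return failure();
  SmallVector<OpFoldResult> paddedShape;
  for (int64_t size : tensorType.getShape()) {
    // Keep the sketch simple: bail out on dynamic dimensions.
    if (ShapedType::isDynamic(size))
      return failure();
    paddedShape.push_back(builder.getIndexAttr(llvm::alignTo(size, 16)));
  }
  return paddedShape;
}

Such a callback would be passed as the last argument of `rewriteAsPaddedOp` in place of the default.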

mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp

Lines changed: 49 additions & 51 deletions
@@ -95,10 +95,11 @@ static int64_t extractConstantMultiplier(AffineExpr expr) {
 ///   - affine_map<(d0, d1) -> (d0 * 3 + d1)>
 /// In the future, more general interfaces can be devised to encode similar
 /// shape evolutions and map between an op and its operands.
-SmallVector<OpFoldResult> linalg::computePaddedShape(
-    RewriterBase &rewriter, TypedValue<RankedTensorType> v,
-    AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
-    const PadTilingInterfaceOptions &options) {
+SmallVector<OpFoldResult>
+linalg::computePaddedShape(OpBuilder &rewriter, TypedValue<RankedTensorType> v,
+                           AffineMap indexingMap,
+                           ArrayRef<OpFoldResult> indexingSizes,
+                           const PadTilingInterfaceOptions &options) {
   Location loc = v.getLoc();
   SmallVector<OpFoldResult> paddedShape;
   auto tensorType = cast<RankedTensorType>(v.getType());
@@ -198,7 +199,7 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
 
 FailureOr<SmallVector<OpFoldResult>>
 linalg::computeIndexingMapOpInterfacePaddedShape(
-    RewriterBase &rewriter, OpOperand &operandToPad,
+    OpBuilder &rewriter, OpOperand &operandToPad,
     ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options) {
   auto transferOp =
       llvm::dyn_cast<IndexingMapOpInterface>(operandToPad.getOwner());
@@ -224,7 +225,7 @@ linalg::computeIndexingMapOpInterfacePaddedShape(
 
 /// Pad a single operand to `paddedShape` using `paddingValueAttr` as padding
 /// Value.
-static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
+static Value padOperand(OpBuilder &rewriter, TilingInterface opToPad,
                         TypedValue<RankedTensorType> v,
                         ArrayRef<OpFoldResult> paddedShape,
                         Attribute paddingValueAttr) {
@@ -263,45 +264,44 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
                               paddingValue, /*nofold=*/false, dynDims);
 }
 
-FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
-    RewriterBase &rewriter, TilingInterface opToPad,
-    const PadTilingInterfaceOptions &constOptions,
-    SmallVector<tensor::PadOp> &padOps,
+FailureOr<PadTilingInterfaceResult> linalg::rewriteAsPaddedOp(
+    OpBuilder &builder, TilingInterface toPad,
+    PadTilingInterfaceOptions options,
     const PadSizeComputationFunction &computePaddingSizeFun) {
-  LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
+  LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << toPad << "\n");
+  SmallVector<tensor::PadOp> padOps;
+  Location loc = toPad.getLoc();
 
-  Location loc = opToPad.getLoc();
-  PadTilingInterfaceOptions options(constOptions);
   // Allow inference of pad values if they are not explicitly specified.
   // TODO: be mindful about the value depending on the actual operation.
   if (options.paddingValues.empty()) {
-    SmallVector<Type> types(opToPad->getOperandTypes());
-    llvm::append_range(types, opToPad->getResultTypes());
+    SmallVector<Type> types(toPad->getOperandTypes());
+    llvm::append_range(types, toPad->getResultTypes());
     for (Type t : types) {
      options.paddingValues.push_back(
-          rewriter.getZeroAttr(getElementTypeOrSelf(t)));
+          builder.getZeroAttr(getElementTypeOrSelf(t)));
     }
   }
 
-  if (llvm::any_of(opToPad->getOperands(),
+  if (llvm::any_of(toPad->getOperands(),
                    [](Value v) { return isa<MemRefType>(v.getType()); })) {
-    return rewriter.notifyMatchFailure(opToPad,
-                                       "expected operation on tensors");
+    LLVM_DEBUG(DBGS() << "Not an operation on tensors: FAIL\n");
+    return failure();
   }
 
-  OpBuilder::InsertionGuard g(rewriter);
-  // Set IP after opToPad because we also take the dims of opToPad's output.
-  rewriter.setInsertionPointAfter(opToPad);
+  OpBuilder::InsertionGuard g(builder);
+  // Set IP after toPad because we also take the dims of toPad's output.
+  builder.setInsertionPointAfter(toPad);
 
   // 1. Get the loopUpperBounds from the TilingInterface.
-  SmallVector<Range> iterationDomain = opToPad.getIterationDomain(rewriter);
+  SmallVector<Range> iterationDomain = toPad.getIterationDomain(builder);
 
   // 2. For each operand.
   SmallVector<Value> newOperands;
-  newOperands.reserve(opToPad->getNumOperands());
-  for (OpOperand &opOperand : opToPad->getOpOperands()) {
+  newOperands.reserve(toPad->getNumOperands());
+  for (OpOperand &opOperand : toPad->getOpOperands()) {
     Value operand = opOperand.get();
-    LLVM_DEBUG(DBGS() << "--start padding oprd: " << operand << "\n");
+    LLVM_DEBUG(DBGS() << "--start padding operand: " << operand << "\n");
 
     // 2.a. Skip scalar-like operands.
     Type operandType = operand.getType();
@@ -311,27 +311,29 @@ FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
       newOperands.push_back(operand);
       continue;
     }
+
     // 2.a. Compute padded shape.
     FailureOr<SmallVector<OpFoldResult>> maybePaddedShape =
-        computePaddingSizeFun(rewriter, opOperand, iterationDomain, options);
+        computePaddingSizeFun(builder, opOperand, iterationDomain, options);
     if (failed(maybePaddedShape)) {
-      return rewriter.notifyMatchFailure(opToPad, "could not pad op");
+      LLVM_DEBUG(DBGS() << "Could not get padded shape of operand: FAIL\n");
+      return failure();
     }
 
     // 2.b. Expect proper `paddingValues`.
     // TODO: we may want to allow garbage padding in the future, in which case
     // we would just not assert.
     if (opOperand.getOperandNumber() >= options.paddingValues.size()) {
-      return rewriter.notifyMatchFailure(opToPad,
-                                         "--no padding value specified");
+      LLVM_DEBUG(DBGS() << "Too few padding values specified: FAIL\n");
+      return failure();
     }
     Attribute paddingValueAttr =
         options.paddingValues[opOperand.getOperandNumber()];
 
     // 2.c. Perform actual padding.
-    Value paddedOperand = padOperand(
-        rewriter, opToPad, cast<TypedValue<RankedTensorType>>(operand),
-        *maybePaddedShape, paddingValueAttr);
+    Value paddedOperand =
+        padOperand(builder, toPad, cast<TypedValue<RankedTensorType>>(operand),
+                   *maybePaddedShape, paddingValueAttr);
     LLVM_DEBUG(DBGS() << "--done padding operand: " << paddedOperand << "\n");
 
     // 2.d. Perform actual padding.
@@ -342,38 +344,34 @@ FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
 
   // 3. Form the resulting tensor::ExtractSliceOp.
   ReifiedRankedShapedTypeDims reifiedResultShapes;
-  if (failed(reifyResultShapes(rewriter, opToPad, reifiedResultShapes))) {
-    LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n");
-    return rewriter.notifyMatchFailure(opToPad,
-                                       "failed to reify result shapes");
+  if (failed(reifyResultShapes(builder, toPad, reifiedResultShapes))) {
+    LLVM_DEBUG(DBGS() << "Failed to reify result shapes: FAIL\n");
+    return failure();
   }
-  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
+  assert(reifiedResultShapes.size() == toPad->getNumResults() &&
          "expected same number of results");
 
-  // Clone `opToPad` to operate on the statically padded shapes.
+  // Clone `toPad` to operate on the statically padded shapes.
   auto resultTensorTypes =
-      ValueRange(newOperands).take_back(opToPad->getNumResults()).getTypes();
-  // clone **should** properly notify the rewriter.
+      ValueRange(newOperands).take_back(toPad->getNumResults()).getTypes();
+  // clone **should** properly notify the builder.
   TilingInterface paddedOp =
-      clone(rewriter, opToPad, resultTensorTypes, newOperands);
+      clone(builder, toPad, resultTensorTypes, newOperands);
   LLVM_DEBUG(DBGS() << "--cloned padded op: " << paddedOp << "\n");
 
-  // Recover the slice out of the new static results. This keeps the original
-  // opToPad around because it uses the dims of the original results.
+  // Recover the slice out of the new static results.
   SmallVector<Value> paddedSubtensorResults;
-  paddedSubtensorResults.reserve(opToPad->getNumResults());
+  paddedSubtensorResults.reserve(toPad->getNumResults());
   for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
     Value paddedResult = en.value();
     int64_t resultNumber = en.index();
     int64_t rank = cast<RankedTensorType>(paddedResult.getType()).getRank();
-    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
-    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> offsets(rank, builder.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(rank, builder.getIndexAttr(1));
     paddedSubtensorResults.push_back(tensor::ExtractSliceOp::create(
-        rewriter, loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
+        builder, loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
         strides));
   }
 
-  rewriter.replaceOp(opToPad, paddedSubtensorResults);
-
-  return paddedOp;
+  return PadTilingInterfaceResult{padOps, paddedSubtensorResults, paddedOp};
 }
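A consequence of dropping `RewriterBase` here: failures are now reported via `LLVM_DEBUG` plus a bare `failure()` instead of `notifyMatchFailure`, and the original op is left in place. A pattern that wants rewriter diagnostics and the replacement can supply both itself; a hypothetical sketch (not part of this commit):

struct PadWithTilingInterface
    : public OpInterfaceRewritePattern<TilingInterface> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(TilingInterface op,
                                PatternRewriter &rewriter) const override {
    linalg::PadTilingInterfaceOptions options;
    // ... populate options.paddingSizes / paddingValues for the op at hand ...
    auto result = linalg::rewriteAsPaddedOp(rewriter, op, options);
    if (failed(result))
      return rewriter.notifyMatchFailure(op, "failed to pad op");
    // Replace the original op with the slices of the padded op's results.
    rewriter.replaceOp(op, result->replacements);
    return success();
  }
};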
