Skip to content

Commit d9f526e

Browse files
add option to control poison padding
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
1 parent 511fa23 commit d9f526e

File tree

5 files changed

+119
-36
lines changed

5 files changed

+119
-36
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,14 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
239239
ArrayRef<int64_t> outerDimsPerm,
240240
ArrayRef<OpFoldResult> innerTiles);
241241

242+
// Same as `requirePaddingValue` above, but conservatively treats dynamic
243+
// dimensions as requiring padding.
244+
static bool requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
245+
ArrayRef<int64_t> innerDimsPos,
246+
ArrayRef<int64_t> outputShape,
247+
ArrayRef<int64_t> outerDimsPerm,
248+
ArrayRef<OpFoldResult> innerTiles);
249+
242250
static Value createDestinationTensor(OpBuilder &b, Location loc,
243251
Value source, ArrayRef<OpFoldResult> innerTileSizes,
244252
ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1914,9 +1914,12 @@ void populateElementwiseOpsFusionPatterns(
19141914
using ControlPropagationFn = std::function<bool(OpOperand *opOperand)>;
19151915

19161916
/// Patterns to bubble up or down data layout ops across other operations.
1917+
/// When `PoisonPaddingOk` is true, the generic-op patterns may introduce
1918+
/// poison padding; otherwise they bail out when padding may be required.
19171919
void populateDataLayoutPropagationPatterns(
19181920
RewritePatternSet &patterns,
1919-
const ControlPropagationFn &controlPackUnPackPropagation);
1921+
const ControlPropagationFn &controlPackUnPackPropagation,
1922+
bool PoisonPaddingOk = false);
19201923

19211924
/// Patterns to sink extract slice across other operations.
19221925
void populateExtractSliceSinkingPatterns(

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5310,6 +5310,35 @@ bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
53105310
return false;
53115311
}
53125312

5313+
/// Conservative variant of `PackOp::requirePaddingValue`.
///
/// Returns true unless it can be proven statically that every tiled source
/// dimension is an exact multiple of its inner tile. Any extent that is not
/// known at compile time (source dimension, packed outer dimension, or inner
/// tile size) is treated as requiring padding, because divisibility cannot be
/// established for it.
///
/// \param inputShape    Shape of the unpacked source.
/// \param innerDimsPos  Source dimensions that are tiled.
/// \param outputShape   Shape of the packed result; only its leading
///                      `inputShape.size()` (outer) dimensions are inspected.
/// \param outerDimsPerm Optional permutation applied to the outer dimensions.
/// \param innerTiles    Tile size for each entry of `innerDimsPos`.
bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
                                       ArrayRef<int64_t> innerDimsPos,
                                       ArrayRef<int64_t> outputShape,
                                       ArrayRef<int64_t> outerDimsPerm,
                                       ArrayRef<OpFoldResult> innerTiles) {
  // Outer dimensions of the packed shape, brought back into source-dimension
  // order so that they can be indexed with entries of `innerDimsPos`.
  SmallVector<int64_t> outputTileSizes(
      outputShape.take_front(inputShape.size()));
  if (!outerDimsPerm.empty()) {
    assert(outerDimsPerm.size() == outputTileSizes.size() &&
           "expected output and outer_dims_perm to have same size");
    applyPermutationToVector(outputTileSizes,
                             invertPermutationVector(outerDimsPerm));
  }
  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
    // A dynamic source or outer dimension cannot be proven tile-aligned.
    if (ShapedType::isDynamic(inputShape[pos]) ||
        ShapedType::isDynamic(outputTileSizes[pos]))
      return true;
    // Likewise for a tile size that is only known at runtime: comparing the
    // source extent against the outer dimension is merely a heuristic (e.g.
    // 16 elements with a runtime tile of 5 yields 4 outer tiles, and
    // 16 % 4 == 0 although padding is needed), so be strict here as well.
    std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
    if (!constantTile)
      return true;
    if (inputShape[pos] % (*constantTile) != 0)
      return true;
  }
  return false;
}
5341+
53135342
LogicalResult PackOp::verify() {
53145343
if (failed(commonVerifierPackAndUnPackOp(*this)))
53155344
return failure();

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 76 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -221,9 +221,10 @@ static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
221221
/// inner_dims_pos = [0]
222222
/// inner_tiles = [8]
223223
/// into %init : tensor<?xf32> -> tensor<?x8xf32>
224-
static std::tuple<Value, AffineMap>
224+
static FailureOr<std::tuple<Value, AffineMap>>
225225
getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
226-
GenericOp genericOp, OpOperand *opOperand) {
226+
GenericOp genericOp, OpOperand *opOperand,
227+
bool poisonPaddingOk) {
227228
int64_t numOrigLoops = genericOp.getNumLoops();
228229
int64_t numInnerLoops = packInfo.getNumTiledLoops();
229230
int64_t numLoops = numOrigLoops + numInnerLoops;
@@ -287,12 +288,24 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
287288
// The operand does not have dimensions that relates to pack op.
288289
if (innerDimsPos.empty() && outerDimsPerm.empty())
289290
return std::make_tuple(opOperand->get(), indexingMap);
290-
291+
auto inputType = cast<RankedTensorType>(opOperand->get().getType());
292+
auto maybeIntInnerTileSizes = getConstantIntValues(innerTileSizes);
293+
if (!maybeIntInnerTileSizes.has_value()) {
294+
return failure();
295+
}
296+
if (!poisonPaddingOk &&
297+
linalg::PackOp::requirePaddingValueStrict(
298+
inputType.getShape(), innerDimsPos,
299+
linalg::PackOp::inferPackedType(inputType, *maybeIntInnerTileSizes,
300+
innerDimsPos, outerDimsPerm)
301+
.getShape(),
302+
outerDimsPerm, innerTileSizes))
303+
return failure();
291304
auto empty = linalg::PackOp::createDestinationTensor(
292305
b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
293306
auto poison = ub::PoisonOp::create(
294307
b, loc, getElementTypeOrSelf(opOperand->get().getType()));
295-
auto packedOperand =
308+
Value packedOperand =
296309
linalg::PackOp::create(b, loc, opOperand->get(), empty, innerDimsPos,
297310
innerTileSizes, poison, outerDimsPerm);
298311
return std::make_tuple(packedOperand, indexingMap);
@@ -304,10 +317,10 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
304317
/// around it. Implicitly this will only work when a packInfo can be obtained.
305318
/// This make sure that we are only using this function on parallel permuted
306319
/// dimensions.
307-
static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
308-
Value dest, AffineMap packedOutIndexingMap,
309-
const PackInfo &packInfo,
310-
bool isFoldableUnpackPack) {
320+
static FailureOr<GenericOp>
321+
packGenericOp(RewriterBase &rewriter, GenericOp genericOp, Value dest,
322+
AffineMap packedOutIndexingMap, const PackInfo &packInfo,
323+
bool isFoldableUnpackPack, bool poisonPaddingOk) {
311324
Location loc = genericOp.getLoc();
312325
SmallVector<Value> inputOperands;
313326
SmallVector<Value> inputOperandsFromUnpackedSource;
@@ -318,8 +331,13 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
318331
llvm::equal(packOp.getMixedTiles(), unPackOp.getMixedTiles());
319332
};
320333
for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
321-
auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
322-
rewriter, loc, packInfo, genericOp, inputOperand);
334+
auto mayBepackedOperandAndIndexing = getOrCreatePackedViewOfOperand(
335+
rewriter, loc, packInfo, genericOp, inputOperand, poisonPaddingOk);
336+
if (failed(mayBepackedOperandAndIndexing)) {
337+
return failure();
338+
}
339+
auto packedOperand = std::get<0>(*mayBepackedOperandAndIndexing);
340+
auto packedIndexingMap = std::get<1>(*mayBepackedOperandAndIndexing);
323341
auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>();
324342
auto packOp = packedOperand.getDefiningOp<linalg::PackOp>();
325343
if (packOp && unpackOp && hasEquivalentTiles(packOp, unpackOp)) {
@@ -410,7 +428,8 @@ static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
410428
/// } -> tensor<?x?x8x2xf32>
411429
static FailureOr<GenericOp>
412430
bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
413-
const ControlPropagationFn &controlFn) {
431+
const ControlPropagationFn &controlFn,
432+
bool poisonPaddingOk) {
414433
auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
415434
if (!genericOp)
416435
return failure();
@@ -473,9 +492,14 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
473492
}
474493

475494
// Rebuild the indexing map for the corresponding init operand.
476-
auto [packedOutOperand, packedOutIndexingMap] =
495+
auto mayBepackedOperandAndIndexing =
477496
getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
478-
genericOp, opOperand);
497+
genericOp, opOperand, poisonPaddingOk);
498+
if (failed(mayBepackedOperandAndIndexing)) {
499+
return failure();
500+
}
501+
auto packedOutOperand = std::get<0>(*mayBepackedOperandAndIndexing);
502+
auto packedOutIndexingMap = std::get<1>(*mayBepackedOperandAndIndexing);
479503

480504
// Forward the new tensor.empty as a destination if it is one of the following
481505
// situations:
@@ -491,21 +515,24 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
491515
// pack(unpack) isn't naively foldable because the unpack op can be from
492516
// an arbitrary domain so we need to keep both.
493517
return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
494-
*packInfo, /*isFoldableUnpackPack=*/false);
518+
*packInfo, /*isFoldableUnpackPack=*/false,
519+
poisonPaddingOk);
495520
}
496521

497522
/// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
498523
struct BubbleUpPackOpThroughGenericOpPattern
499524
: public OpRewritePattern<linalg::PackOp> {
500525
public:
501526
BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
502-
ControlPropagationFn fun)
503-
: OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}
527+
ControlPropagationFn fun,
528+
bool poisonPaddingOk)
529+
: OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)),
530+
poisonPaddingOk(std::move(poisonPaddingOk)) {}
504531

505532
LogicalResult matchAndRewrite(linalg::PackOp packOp,
506533
PatternRewriter &rewriter) const override {
507-
auto genericOp =
508-
bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
534+
auto genericOp = bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn,
535+
poisonPaddingOk);
509536
if (failed(genericOp))
510537
return failure();
511538
rewriter.replaceOp(packOp, genericOp->getResults());
@@ -514,6 +541,7 @@ struct BubbleUpPackOpThroughGenericOpPattern
514541

515542
private:
516543
ControlPropagationFn controlFn;
544+
bool poisonPaddingOk;
517545
};
518546

519547
/// Propagate a linalg.pack operation up through a tensor.pad. The idea is to
@@ -1083,7 +1111,8 @@ static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
10831111
///
10841112
static FailureOr<std::tuple<GenericOp, Value>>
10851113
pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
1086-
ControlPropagationFn controlFn) {
1114+
ControlPropagationFn controlFn,
1115+
bool poisonPaddingOk) {
10871116
if (genericOp.getNumResults() != 1)
10881117
return failure();
10891118

@@ -1110,9 +1139,14 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
11101139
return failure();
11111140

11121141
// Rebuild the indexing map for the corresponding init operand.
1113-
auto [packedOutOperand, packedOutIndexingMap] =
1114-
getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
1115-
genericOp, genericOp.getDpsInitOperand(0));
1142+
auto mayBepackedOperandAndIndexing = getOrCreatePackedViewOfOperand(
1143+
rewriter, genericOp.getLoc(), *packInfo, genericOp,
1144+
genericOp.getDpsInitOperand(0), poisonPaddingOk);
1145+
if (failed(mayBepackedOperandAndIndexing)) {
1146+
return failure();
1147+
}
1148+
auto packedOutOperand = std::get<0>(*mayBepackedOperandAndIndexing);
1149+
auto packedOutIndexingMap = std::get<1>(*mayBepackedOperandAndIndexing);
11161150
auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();
11171151

11181152
// Forward the new tensor.empty as a destination if it is one of the following
@@ -1132,9 +1166,12 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
11321166
// pack(unpack) is foldable in this case. This is because in pushing down the
11331167
// unpack, by default we will populate an additional pack op after the unpack.
11341168
// This guarantees them to be foldable.
1135-
GenericOp newGenericOp =
1169+
auto maybeGenericOp =
11361170
packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo,
1137-
/*isFoldableUnpackPack=*/true);
1171+
/*isFoldableUnpackPack=*/true, poisonPaddingOk);
1172+
if (failed(maybeGenericOp))
1173+
return failure();
1174+
GenericOp newGenericOp = *maybeGenericOp;
11381175
Value newResult =
11391176
newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));
11401177

@@ -1160,13 +1197,15 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
11601197
struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
11611198
public:
11621199
PushDownUnPackOpThroughGenericOp(MLIRContext *context,
1163-
ControlPropagationFn fun)
1164-
: OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
1200+
ControlPropagationFn fun,
1201+
bool poisonPaddingOk)
1202+
: OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)),
1203+
poisonPaddingOk(std::move(poisonPaddingOk)) {}
11651204

11661205
LogicalResult matchAndRewrite(GenericOp genericOp,
11671206
PatternRewriter &rewriter) const override {
1168-
auto genericAndRepl =
1169-
pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
1207+
auto genericAndRepl = pushDownUnPackOpThroughGenericOp(
1208+
rewriter, genericOp, controlFn, poisonPaddingOk);
11701209
if (failed(genericAndRepl))
11711210
return failure();
11721211
rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
@@ -1175,6 +1214,7 @@ struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
11751214

11761215
private:
11771216
ControlPropagationFn controlFn;
1217+
bool poisonPaddingOk;
11781218
};
11791219

11801220
/// Propagate a linalg.unpack operation through a tensor.pad. The idea is to
@@ -1525,12 +1565,14 @@ class PushDownExtractSliceOpThroughGenericOp final
15251565

15261566
void mlir::linalg::populateDataLayoutPropagationPatterns(
15271567
RewritePatternSet &patterns,
1528-
const ControlPropagationFn &controlPackUnPackPropagation) {
1529-
patterns
1530-
.insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
1531-
BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
1532-
PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
1533-
patterns.getContext(), controlPackUnPackPropagation);
1568+
const ControlPropagationFn &controlPackUnPackPropagation,
1569+
bool PoisonPaddingOk) {
1570+
patterns.insert<BubbleUpPackThroughPadOp, BubbleUpPackOpThroughReshapeOp,
1571+
PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
1572+
patterns.getContext(), controlPackUnPackPropagation);
1573+
patterns.insert<BubbleUpPackOpThroughGenericOpPattern,
1574+
PushDownUnPackOpThroughGenericOp>(
1575+
patterns.getContext(), controlPackUnPackPropagation, PoisonPaddingOk);
15341576
}
15351577

15361578
void mlir::linalg::populateExtractSliceSinkingPatterns(

mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ struct TestDataLayoutPropagationPass
3333
MLIRContext *context = &getContext();
3434
RewritePatternSet patterns(context);
3535
linalg::populateDataLayoutPropagationPatterns(
36-
patterns, [](OpOperand *opOperand) { return true; });
36+
patterns, [](OpOperand *opOperand) { return true; },
37+
/*poisonPaddingOk=*/true);
3738
linalg::ControlPropagationFn controlExtract =
3839
[](OpOperand *opOperand) -> bool {
3940
Operation *producer = opOperand->get().getDefiningOp();

0 commit comments

Comments
 (0)