@@ -221,9 +221,10 @@ static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
 /// inner_dims_pos = [0]
 /// inner_tiles = [8]
 /// into %init : tensor<?xf32> -> tensor<?x8xf32>
-static std::tuple<Value, AffineMap>
+static FailureOr<std::tuple<Value, AffineMap>>
 getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
-                               GenericOp genericOp, OpOperand *opOperand) {
+                               GenericOp genericOp, OpOperand *opOperand,
+                               bool poisonPaddingOk) {
   int64_t numOrigLoops = genericOp.getNumLoops();
   int64_t numInnerLoops = packInfo.getNumTiledLoops();
   int64_t numLoops = numOrigLoops + numInnerLoops;
@@ -287,12 +288,24 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
   // The operand does not have dimensions that relates to pack op.
   if (innerDimsPos.empty() && outerDimsPerm.empty())
     return std::make_tuple(opOperand->get(), indexingMap);
-
+  auto inputType = cast<RankedTensorType>(opOperand->get().getType());
+  auto maybeIntInnerTileSizes = getConstantIntValues(innerTileSizes);
+  if (!maybeIntInnerTileSizes.has_value()) {
+    return failure();
+  }
+  if (!poisonPaddingOk &&
+      linalg::PackOp::requirePaddingValueStrict(
+          inputType.getShape(), innerDimsPos,
+          linalg::PackOp::inferPackedType(inputType, *maybeIntInnerTileSizes,
+                                          innerDimsPos, outerDimsPerm)
+              .getShape(),
+          outerDimsPerm, innerTileSizes))
+    return failure();
   auto empty = linalg::PackOp::createDestinationTensor(
       b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
   auto poison = ub::PoisonOp::create(
       b, loc, getElementTypeOrSelf(opOperand->get().getType()));
-  auto packedOperand =
+  Value packedOperand =
       linalg::PackOp::create(b, loc, opOperand->get(), empty, innerDimsPos,
                              innerTileSizes, poison, outerDimsPerm);
   return std::make_tuple(packedOperand, indexingMap);
@@ -304,10 +317,10 @@
 /// around it. Implicitly this will only work when a packInfo can be obtained.
 /// This make sure that we are only using this function on parallel permuted
 /// dimensions.
-static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
-                               Value dest, AffineMap packedOutIndexingMap,
-                               const PackInfo &packInfo,
-                               bool isFoldableUnpackPack) {
+static FailureOr<GenericOp>
+packGenericOp(RewriterBase &rewriter, GenericOp genericOp, Value dest,
+              AffineMap packedOutIndexingMap, const PackInfo &packInfo,
+              bool isFoldableUnpackPack, bool poisonPaddingOk) {
   Location loc = genericOp.getLoc();
   SmallVector<Value> inputOperands;
   SmallVector<Value> inputOperandsFromUnpackedSource;
@@ -318,8 +331,13 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
            llvm::equal(packOp.getMixedTiles(), unPackOp.getMixedTiles());
   };
   for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
-    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
-        rewriter, loc, packInfo, genericOp, inputOperand);
+    auto mayBepackedOperandAndIndexing = getOrCreatePackedViewOfOperand(
+        rewriter, loc, packInfo, genericOp, inputOperand, poisonPaddingOk);
+    if (failed(mayBepackedOperandAndIndexing)) {
+      return failure();
+    }
+    auto packedOperand = std::get<0>(*mayBepackedOperandAndIndexing);
+    auto packedIndexingMap = std::get<1>(*mayBepackedOperandAndIndexing);
     auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>();
     auto packOp = packedOperand.getDefiningOp<linalg::PackOp>();
     if (packOp && unpackOp && hasEquivalentTiles(packOp, unpackOp)) {
@@ -410,7 +428,8 @@ static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
 /// } -> tensor<?x?x8x2xf32>
 static FailureOr<GenericOp>
 bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
-                               const ControlPropagationFn &controlFn) {
+                               const ControlPropagationFn &controlFn,
+                               bool poisonPaddingOk) {
   auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
   if (!genericOp)
     return failure();
@@ -473,9 +492,14 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
   }
 
   // Rebuild the indexing map for the corresponding init operand.
-  auto [packedOutOperand, packedOutIndexingMap] =
+  auto mayBepackedOperandAndIndexing =
       getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
-                                     genericOp, opOperand);
+                                     genericOp, opOperand, poisonPaddingOk);
+  if (failed(mayBepackedOperandAndIndexing)) {
+    return failure();
+  }
+  auto packedOutOperand = std::get<0>(*mayBepackedOperandAndIndexing);
+  auto packedOutIndexingMap = std::get<1>(*mayBepackedOperandAndIndexing);
 
   // Forward the new tensor.empty as a destination if it is one of the following
   // situations:
@@ -491,21 +515,24 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
   // pack(unpack) isn't naively foldable because the unpack op can be from
   // an arbitrary domain so we need to keep both.
   return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
-                       *packInfo, /*isFoldableUnpackPack=*/false);
+                       *packInfo, /*isFoldableUnpackPack=*/false,
+                       poisonPaddingOk);
 }
 
 /// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
 struct BubbleUpPackOpThroughGenericOpPattern
     : public OpRewritePattern<linalg::PackOp> {
 public:
   BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
-                                        ControlPropagationFn fun)
-      : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}
+                                        ControlPropagationFn fun,
+                                        bool poisonPaddingOk)
+      : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)),
+        poisonPaddingOk(std::move(poisonPaddingOk)) {}
 
   LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                 PatternRewriter &rewriter) const override {
-    auto genericOp =
-        bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
+    auto genericOp = bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn,
+                                                    poisonPaddingOk);
     if (failed(genericOp))
       return failure();
     rewriter.replaceOp(packOp, genericOp->getResults());
@@ -514,6 +541,7 @@ struct BubbleUpPackOpThroughGenericOpPattern
 
 private:
   ControlPropagationFn controlFn;
+  bool poisonPaddingOk;
 };
 
 /// Propagate a linalg.pack operation up through a tensor.pad. The idea is to
@@ -1083,7 +1111,8 @@ static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
 ///
 static FailureOr<std::tuple<GenericOp, Value>>
 pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
-                                 ControlPropagationFn controlFn) {
+                                 ControlPropagationFn controlFn,
+                                 bool poisonPaddingOk) {
   if (genericOp.getNumResults() != 1)
     return failure();
 
@@ -1110,9 +1139,14 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
     return failure();
 
   // Rebuild the indexing map for the corresponding init operand.
-  auto [packedOutOperand, packedOutIndexingMap] =
-      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
-                                     genericOp, genericOp.getDpsInitOperand(0));
+  auto mayBepackedOperandAndIndexing = getOrCreatePackedViewOfOperand(
+      rewriter, genericOp.getLoc(), *packInfo, genericOp,
+      genericOp.getDpsInitOperand(0), poisonPaddingOk);
+  if (failed(mayBepackedOperandAndIndexing)) {
+    return failure();
+  }
+  auto packedOutOperand = std::get<0>(*mayBepackedOperandAndIndexing);
+  auto packedOutIndexingMap = std::get<1>(*mayBepackedOperandAndIndexing);
   auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();
 
   // Forward the new tensor.empty as a destination if it is one of the following
@@ -1132,9 +1166,12 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
   // pack(unpack) is foldable in this case. This is because in pushing down the
   // unpack, by default we will populate an additional pack op after the unpack.
   // This guarantees them to be foldable.
-  GenericOp newGenericOp =
+  auto maybeGenericOp =
       packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo,
-                    /*isFoldableUnpackPack=*/true);
+                    /*isFoldableUnpackPack=*/true, poisonPaddingOk);
+  if (failed(maybeGenericOp))
+    return failure();
+  GenericOp newGenericOp = *maybeGenericOp;
   Value newResult =
       newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));
 
@@ -1160,13 +1197,15 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
 struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
 public:
   PushDownUnPackOpThroughGenericOp(MLIRContext *context,
-                                   ControlPropagationFn fun)
-      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
+                                   ControlPropagationFn fun,
+                                   bool poisonPaddingOk)
+      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)),
+        poisonPaddingOk(std::move(poisonPaddingOk)) {}
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    auto genericAndRepl =
-        pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
+    auto genericAndRepl = pushDownUnPackOpThroughGenericOp(
+        rewriter, genericOp, controlFn, poisonPaddingOk);
     if (failed(genericAndRepl))
       return failure();
     rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
@@ -1175,6 +1214,7 @@ struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
 
 private:
   ControlPropagationFn controlFn;
+  bool poisonPaddingOk;
 };
 
 /// Propagate a linalg.unpack operation through a tensor.pad. The idea is to
@@ -1525,12 +1565,14 @@ class PushDownExtractSliceOpThroughGenericOp final
 
 void mlir::linalg::populateDataLayoutPropagationPatterns(
     RewritePatternSet &patterns,
-    const ControlPropagationFn &controlPackUnPackPropagation) {
-  patterns
-      .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
-              BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
-              PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
-          patterns.getContext(), controlPackUnPackPropagation);
+    const ControlPropagationFn &controlPackUnPackPropagation,
+    bool PoisonPaddingOk) {
+  patterns.insert<BubbleUpPackThroughPadOp, BubbleUpPackOpThroughReshapeOp,
+                  PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
+      patterns.getContext(), controlPackUnPackPropagation);
+  patterns.insert<BubbleUpPackOpThroughGenericOpPattern,
+                  PushDownUnPackOpThroughGenericOp>(
+      patterns.getContext(), controlPackUnPackPropagation, PoisonPaddingOk);
 }
 
 void mlir::linalg::populateExtractSliceSinkingPatterns(
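For callers, the net effect is an extra boolean on populateDataLayoutPropagationPatterns that gates only the two generic-op patterns. A minimal usage sketch, assuming the updated entry point introduced by this change (the surrounding helper name and the allowPoisonPadding flag are hypothetical):

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Registers all data-layout propagation patterns. The generic-op bubbling and
// push-down patterns only create packs that would need a ub.poison padding
// value when allowPoisonPadding is true; otherwise they simply fail to apply.
static void addPropagationPatterns(RewritePatternSet &patterns,
                                    bool allowPoisonPadding) {
  // Propagate through every candidate operand; a real pass would filter here.
  linalg::ControlPropagationFn control = [](OpOperand *) { return true; };
  linalg::populateDataLayoutPropagationPatterns(
      patterns, control, /*PoisonPaddingOk=*/allowPoisonPadding);
}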