2020#include " mlir/Dialect/Utils/IndexingUtils.h"
2121#include " mlir/Dialect/Utils/StaticValueUtils.h"
2222#include " mlir/Dialect/Utils/StructuredOpsUtils.h"
23+ #include " mlir/IR/BuiltinTypeInterfaces.h"
2324#include " mlir/Interfaces/TilingInterface.h"
2425#include " mlir/Interfaces/ValueBoundsOpInterface.h"
2526#include " llvm/Support/Debug.h"
@@ -887,26 +888,55 @@ struct PackOpTiling
887888
888889 ArrayRef<OpFoldResult> offsets (allOffsets[0 ]);
889890 ArrayRef<OpFoldResult> sizes (allSizes[0 ]);
890-
891891 auto packOp = cast<PackOp>(op);
892- // It is not trivial to infer dest tile from source tile if `packOp` has
893- // padding semantic.
894- if (packOp.getPaddingValue ())
895- return failure ();
896-
897892 Location loc = packOp.getLoc ();
898-
899893 SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
900894 DenseMap<int64_t , OpFoldResult> dimAndTileMapping =
901895 packOp.getDimAndTileMapping ();
902896 for (auto dim : llvm::seq<int64_t >(packOp.getSourceRank ())) {
903897 if (dimAndTileMapping.count (dim)) {
904- FailureOr<int64_t > cstSize =
898+ FailureOr<int64_t > cstTileSize =
905899 ValueBoundsConstraintSet::computeConstantBound (
906900 presburger::BoundType::UB, sizes[dim],
907901 /* stopCondition=*/ nullptr , /* closedUB=*/ true );
908902 std::optional<int64_t > cstInnerSize =
909903 getConstantIntValue (dimAndTileMapping[dim]);
904+
905+ // If a dimension is not tiled, it is always valid to fuse the pack op,
906+ // even if the op has padding semantics. Because it always generates a
907+ // full slice along the dimension.
908+ // TODO: It could be untiled if the `srcDimSize` is dynamic. It is a
909+ // hard check to determine if a dimension is tiled or not.
910+ int64_t srcDimSize = packOp.getSourceType ().getDimSize (dim);
911+ int64_t destDimSize = packOp.getDestType ().getDimSize (dim);
912+ bool isTiled = failed (cstTileSize) ||
913+ ShapedType::isDynamic (srcDimSize) ||
914+ cstTileSize.value () != srcDimSize;
915+ if (!isTiled) {
916+ outerDimOffsets.push_back (offsets[dim]);
917+ if (ShapedType::isStatic (destDimSize)) {
918+ outerDimSizes.push_back (b.getIndexAttr (destDimSize));
919+ } else {
920+ outerDimSizes.push_back (
921+ b.createOrFold <tensor::DimOp>(loc, packOp.getDest (), dim));
922+ }
923+ continue ;
924+ }
925+
926+ // If the dimension needs padding, it is not supported because there are
927+ // iterations that only write padding values to the whole tile. The
928+ // consumer fusion is driven by the source, so it is not possible to map
929+ // an empty slice to the tile.
930+ bool needExtraPadding =
931+ ShapedType::isDynamic (destDimSize) || !cstInnerSize ||
932+ destDimSize * cstInnerSize.value () != srcDimSize;
933+ // Prioritize the case that the op already says that it does not need
934+ // padding.
935+ if (!packOp.getPaddingValue ())
936+ needExtraPadding = false ;
937+ if (needExtraPadding)
938+ return failure ();
939+
910940 // Currently fusing `packOp` as consumer only expects perfect tiling
911941 // scenario because even if without padding semantic, the `packOp` may
912942 // also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
@@ -921,9 +951,9 @@ struct PackOpTiling
921951 // another word, we can only support tiling with consumer if the tile
922952 // size for the producer is a multiple of the inner tile size for the
923953 // packed dimensions at this moment.
924- if (failed (cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0 ) {
954+ if ((failed (cstTileSize) || !cstInnerSize ||
955+ *cstTileSize % *cstInnerSize != 0 ))
925956 return failure ();
926- }
927957
928958 using AV = affine::AffineValueExpr;
929959 affine::AffineBuilder ab (b, loc);
@@ -988,7 +1018,8 @@ struct PackOpTiling
9881018 loc, packOp.getDest (), outputOffsets, outputSizes, strides);
9891019 tiledOperands.push_back (outSlice);
9901020
991- assert (!packOp.getPaddingValue () && " Expect no padding semantic" );
1021+ if (auto val = packOp.getPaddingValue ())
1022+ tiledOperands.push_back (val);
9921023 for (auto tile : packOp.getInnerTiles ())
9931024 tiledOperands.push_back (tile);
9941025
0 commit comments