@@ -246,6 +246,120 @@ struct PackOpTiling
246246 return failure ();
247247 return tilingResult.value ();
248248 }
249+
250+ // / Method to return the position of iteration domain tile computed by the
251+ // / tiled operation. In current `tensor.pack` context, the `resultOffsets` and
252+ // / `resultSizes` only cover outer dimensions.
253+ LogicalResult getIterationDomainTileFromOperandTile (
254+ Operation *op, OpBuilder &b, unsigned operandNumber,
255+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
256+ SmallVectorImpl<OpFoldResult> &resultOffsets,
257+ SmallVectorImpl<OpFoldResult> &resultSizes) const {
258+ if (operandNumber != 0 )
259+ return failure ();
260+
261+ auto packOp = cast<PackOp>(op);
262+ // It is not trivial to infer dest tile from source tile if `packOp` has
263+ // padding semantic.
264+ if (packOp.getPaddingValue ())
265+ return failure ();
266+
267+ Location loc = packOp.getLoc ();
268+
269+ SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
270+ DenseMap<int64_t , OpFoldResult> dimAndTileMapping =
271+ packOp.getDimAndTileMapping ();
272+ for (auto dim : packOp.getOuterDimsPerm ()) {
273+ if (dimAndTileMapping.count (dim)) {
274+ FailureOr<int64_t > cstSize =
275+ ValueBoundsConstraintSet::computeConstantBound (
276+ presburger::BoundType::UB, sizes[dim],
277+ /* stopCondition=*/ nullptr , /* closedUB=*/ true );
278+ std::optional<int64_t > cstInnerSize =
279+ getConstantIntValue (dimAndTileMapping[dim]);
280+ // Currently fusing `packOp` as consumer only expects perfect tiling
281+ // scenario because even if without padding semantic, the `packOp` may
282+ // also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
283+ // where the `tileSize` from operand of `packOp` is 5, which is not
284+ // exactly divided by `innerTile`(=6) of `packOp`. As the result:
285+ // 1. the first slice is extracted from (0) to (4) and inserted into
286+ // (0,0)~(0,4) at first row.
287+ // 2. the second slice is extracted from (5) to (9) and SHOULD BE
288+ // respectively inserted into two rows with different length, including
289+ // first row: (0,5) and second row (1,0)~(1,3). It is hard to coordinate
290+ // them, thus adding below constraint to bypass them temporarily. In
291+ // another word, we can only support tiling with consumer if the tile
292+ // size for the producer is a multiple of the inner tile size for the
293+ // packed dimensions at this moment.
294+ if (failed (cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0 ) {
295+ return failure ();
296+ }
297+
298+ using AV = affine::AffineValueExpr;
299+ affine::AffineBuilder ab (b, loc);
300+ AffineExpr dim0, sym;
301+ bindDims (b.getContext (), dim0);
302+ bindSymbols (b.getContext (), sym);
303+ auto avOffset = AV (dim0).bind (offsets[dim]);
304+ auto avSize = AV (dim0).bind (sizes[dim]);
305+ auto avTileSize = AV (sym).bind (dimAndTileMapping[dim]);
306+ outerDimOffsets.push_back (ab.floor (avOffset, avTileSize));
307+ outerDimSizes.push_back (ab.ceil (avSize, avTileSize));
308+ } else {
309+ outerDimOffsets.push_back (offsets[dim]);
310+ outerDimSizes.push_back (sizes[dim]);
311+ }
312+ }
313+
314+ resultOffsets = outerDimOffsets;
315+ resultSizes = outerDimSizes;
316+ return success ();
317+ }
318+
319+ // / Method to return the tiled implementation of tensor.pack as a consumer.
320+ FailureOr<TilingResult> getTiledImplementationFromOperandTile (
321+ Operation *op, OpBuilder &b, unsigned operandNumber,
322+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
323+ if (operandNumber != 0 )
324+ return failure ();
325+
326+ auto packOp = cast<PackOp>(op);
327+ Location loc = packOp.getLoc ();
328+
329+ int64_t inputRank = packOp.getSourceRank ();
330+ auto oneAttr = b.getI64IntegerAttr (1 );
331+ SmallVector<OpFoldResult> strides (inputRank, oneAttr);
332+
333+ SmallVector<Value> tiledOperands;
334+ tiledOperands.push_back (b.create <ExtractSliceOp>(loc, packOp.getSource (),
335+ offsets, sizes, strides));
336+
337+ SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
338+ if (failed (getIterationDomainTileFromOperandTile (
339+ op, b, /* operandNumber=*/ 0 , offsets, sizes, outerDimOffsets,
340+ outerDimSizes)))
341+ return failure ();
342+
343+ SmallVector<OpFoldResult> outputOffsets, outputSizes;
344+ if (failed (getResultTilePosition (op, b, 0 , outerDimOffsets, outerDimSizes,
345+ outputOffsets, outputSizes)))
346+ return failure ();
347+
348+ strides.append (packOp.getDestRank () - inputRank, oneAttr);
349+ auto extractSlice = b.create <ExtractSliceOp>(
350+ loc, packOp.getDest (), outputOffsets, outputSizes, strides);
351+ tiledOperands.push_back (extractSlice);
352+
353+ assert (!packOp.getPaddingValue () && " Expect no padding semantic" );
354+ for (auto tile : packOp.getInnerTiles ())
355+ tiledOperands.push_back (tile);
356+
357+ Operation *tiledPackOp = b.create <PackOp>(
358+ loc, TypeRange{extractSlice.getType ()}, tiledOperands, op->getAttrs ());
359+
360+ return TilingResult{{tiledPackOp},
361+ SmallVector<Value>(tiledPackOp->getResults ())};
362+ }
249363};
250364
251365struct UnpackTileDimInfo {
0 commit comments