@@ -246,6 +246,97 @@ struct PackOpTiling
246246 return failure ();
247247 return tilingResult.value ();
248248 }
249+
250+ // / Method to return the position of iteration domain tile computed by the
251+ // / tiled operation. In current `tensor.pack` context, the `resultOffsets` and
252+ // / `resultSizes` only cover outer dimensions.
253+ LogicalResult getIterationDomainTileFromOperandTile (
254+ Operation *op, OpBuilder &b, unsigned operandNumber,
255+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
256+ SmallVectorImpl<OpFoldResult> &resultOffsets,
257+ SmallVectorImpl<OpFoldResult> &resultSizes) const {
258+ auto packOp = cast<PackOp>(op);
259+ Location loc = packOp.getLoc ();
260+
261+ SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
262+ DenseMap<int64_t , OpFoldResult> dimAndTileMapping =
263+ packOp.getDimAndTileMapping ();
264+ for (auto dim : packOp.getOuterDimsPerm ()) {
265+ if (dimAndTileMapping.count (dim)) {
266+ FailureOr<int64_t > cstSize =
267+ ValueBoundsConstraintSet::computeConstantBound (
268+ presburger::BoundType::UB, sizes[dim],
269+ /* stopCondition=*/ nullptr , /* closedUB=*/ true );
270+ std::optional<int64_t > cstInnerSize =
271+ getConstantIntValue (dimAndTileMapping[dim]);
272+ // Currently only expect perfect tiling cases.
273+ if (failed (cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0 ) {
274+ return failure ();
275+ }
276+
277+ using AV = affine::AffineValueExpr;
278+ affine::AffineBuilder ab (b, loc);
279+ AffineExpr dim0, sym;
280+ bindDims (b.getContext (), dim0);
281+ bindSymbols (b.getContext (), sym);
282+ auto avOffset = AV (dim0).bind (offsets[dim]);
283+ auto avSize = AV (dim0).bind (sizes[dim]);
284+ auto avTileSize = AV (sym).bind (dimAndTileMapping[dim]);
285+ outerDimOffsets.push_back (ab.floor (avOffset, avTileSize));
286+ outerDimSizes.push_back (ab.ceil (avSize, avTileSize));
287+ } else {
288+ outerDimOffsets.push_back (offsets[dim]);
289+ outerDimSizes.push_back (sizes[dim]);
290+ }
291+ }
292+
293+ resultOffsets = outerDimOffsets;
294+ resultSizes = outerDimSizes;
295+ return success ();
296+ }
297+
298+ // / Method to return the tiled implementation of tensor.pack as a consumer.
299+ FailureOr<TilingResult> getTiledImplementationFromOperandTile (
300+ Operation *op, OpBuilder &b, unsigned operandNumber,
301+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
302+ auto packOp = cast<PackOp>(op);
303+ Location loc = packOp.getLoc ();
304+
305+ int64_t inputRank = packOp.getSourceRank ();
306+ auto oneAttr = b.getI64IntegerAttr (1 );
307+ SmallVector<OpFoldResult> strides (inputRank, oneAttr);
308+
309+ SmallVector<Value> tiledOperands;
310+ tiledOperands.push_back (b.create <ExtractSliceOp>(loc, packOp.getSource (),
311+ offsets, sizes, strides));
312+
313+ SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
314+ if (failed (getIterationDomainTileFromOperandTile (
315+ op, b, /* operandNumber=*/ 0 , offsets, sizes, outerDimOffsets,
316+ outerDimSizes)))
317+ return failure ();
318+
319+ SmallVector<OpFoldResult> outputOffsets, outputSizes;
320+ if (failed (getResultTilePosition (op, b, 0 , outerDimOffsets, outerDimSizes,
321+ outputOffsets, outputSizes)))
322+ return failure ();
323+
324+ strides.append (packOp.getDestRank () - inputRank, oneAttr);
325+ auto extractSlice = b.create <ExtractSliceOp>(
326+ loc, packOp.getDest (), outputOffsets, outputSizes, strides);
327+ tiledOperands.push_back (extractSlice);
328+
329+ if (auto val = packOp.getPaddingValue ())
330+ tiledOperands.push_back (val);
331+ for (auto tile : packOp.getInnerTiles ())
332+ tiledOperands.push_back (tile);
333+
334+ Operation *tiledPackOp = b.create <PackOp>(
335+ loc, TypeRange{extractSlice.getType ()}, tiledOperands, op->getAttrs ());
336+
337+ return TilingResult{{tiledPackOp},
338+ SmallVector<Value>(tiledPackOp->getResults ())};
339+ }
249340};
250341
251342struct UnpackTileDimInfo {
0 commit comments