@@ -221,10 +221,21 @@ static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
221221// / inner_dims_pos = [0]
222222// / inner_tiles = [8]
223223// / into %init : tensor<?xf32> -> tensor<?x8xf32>
224- static FailureOr<std::tuple<Value, AffineMap>>
225- getOrCreatePackedViewOfOperand (OpBuilder &b, Location loc, PackInfo packInfo,
226- GenericOp genericOp, OpOperand *opOperand,
227- bool poisonPaddingOk) {
224+
// Per-operand packing description. getPackedOperandDetails stores one entry
// per OpOperand in a DenseMap, and getOrCreatePackedViewOfOperand reads it
// back to build the packed value. For operands that need no packing (scalars,
// zero-rank tensors, or operands with no packed dimensions) only `indexingMap`
// is set and the three vectors stay empty.
225+ struct PackedOperandDetails {
226+ SmallVector<OpFoldResult> innerTileSizes; // inner tile sizes handed to linalg::PackOp
227+ SmallVector<int64_t > innerDimsPos; // inner dims positions handed to linalg::PackOp
228+ SmallVector<int64_t > outerDimsPerm; // outer dims permutation handed to linalg::PackOp
229+ AffineMap indexingMap; // indexing map returned for the (possibly packed) operand
230+ };
231+
232+ /// Helper function for getOrCreatePackedViewOfOperand that records, in
233+ /// `packedOperandMap`, the details of the packed operand that needs to be
234+ /// formed, and returns whether the packing would require padding.
235+ static bool getPackedOperandDetails (
236+ OpBuilder &b, PackInfo packInfo, GenericOp genericOp, OpOperand *opOperand,
237+ DenseMap<OpOperand *, PackedOperandDetails> &packedOperandMap) {
238+ PackedOperandDetails currOperandDetails;
228239 int64_t numOrigLoops = genericOp.getNumLoops ();
229240 int64_t numInnerLoops = packInfo.getNumTiledLoops ();
230241 int64_t numLoops = numOrigLoops + numInnerLoops;
@@ -233,9 +244,12 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
233244 SmallVector<AffineExpr> exprs (origIndexingMap.getResults ());
234245
235246 // If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
236- if (genericOp.isScalar (opOperand) || exprs.empty ())
237- return std::make_tuple (opOperand->get (),
238- AffineMap::get (numLoops, 0 , exprs, b.getContext ()));
247+ if (genericOp.isScalar (opOperand) || exprs.empty ()) {
248+ currOperandDetails.indexingMap =
249+ AffineMap::get (numLoops, 0 , exprs, b.getContext ());
250+ packedOperandMap[opOperand] = currOperandDetails;
251+ return false ;
252+ }
239253
240254 // Step 1. Construct the information of packing data dimensions; append inner
241255 // dimensions to the indexing maps for the operand.
@@ -283,32 +297,57 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
283297 exprs = auxVec;
284298 }
285299 }
286- auto indexingMap = AffineMap::get (numLoops, 0 , exprs, b.getContext ());
300+ currOperandDetails.indexingMap =
301+ AffineMap::get (numLoops, 0 , exprs, b.getContext ());
287302
288303 // The operand does not have dimensions that relates to pack op.
289- if (innerDimsPos.empty () && outerDimsPerm.empty ())
290- return std::make_tuple (opOperand->get (), indexingMap);
304+ if (innerDimsPos.empty () && outerDimsPerm.empty ()) {
305+ packedOperandMap[opOperand] = currOperandDetails;
306+ return false ;
307+ }
291308 auto inputType = cast<RankedTensorType>(opOperand->get ().getType ());
292- auto maybeIntInnerTileSizes = getConstantIntValues (innerTileSizes);
293- if (!maybeIntInnerTileSizes.has_value ()) {
294- return failure ();
309+
310+ auto maybeIntInnerTileSizes =
311+ llvm::map_to_vector (innerTileSizes, [](OpFoldResult ofr) -> int64_t {
312+ std::optional<int64_t > maybeCst = getConstantIntValue (ofr);
313+ return maybeCst.value_or (ShapedType::kDynamic );
314+ });
315+ bool requirePadding = linalg::PackOp::requirePaddingValueStrict (
316+ inputType.getShape (), innerDimsPos,
317+ linalg::PackOp::inferPackedType (inputType, maybeIntInnerTileSizes,
318+ innerDimsPos, outerDimsPerm)
319+ .getShape (),
320+ outerDimsPerm, innerTileSizes);
321+ currOperandDetails.innerDimsPos = innerDimsPos;
322+ currOperandDetails.innerTileSizes = innerTileSizes;
323+ currOperandDetails.outerDimsPerm = outerDimsPerm;
324+ packedOperandMap[opOperand] = currOperandDetails;
325+
326+ if (requirePadding)
327+ return true ;
328+ return false ;
329+ }
330+
// Returns the packed view of `opOperand` together with its indexing map,
// using the details previously recorded for it in `packedOperandMap`.
//
// - If both the recorded inner dims positions and the outer dims permutation
//   are empty, the operand is returned unchanged with the recorded map.
// - Otherwise a destination tensor and a poison padding value (of the
//   operand's element type) are created and a linalg::PackOp is built.
//
// This function does not check whether padding is required; the visible
// callers test the result of getPackedOperandDetails against
// `poisonPaddingOk` before calling it.
//
// The map is taken by const reference and read with lookup() so that neither
// the map nor the stored vectors are copied more than once per call.
331+ static std::tuple<Value, AffineMap> getOrCreatePackedViewOfOperand (
332+ OpBuilder &b, Location loc, OpOperand *opOperand,
333+ const DenseMap<OpOperand *, PackedOperandDetails> &packedOperandMap) {
334+ assert (packedOperandMap.contains (opOperand) &&
335+ " packed operand details expected to be populated" );
336+ auto currOperandDetails = packedOperandMap.lookup(opOperand);
337+ const auto &innerDimsPos = currOperandDetails.innerDimsPos ;
338+ const auto &outerDimsPerm = currOperandDetails.outerDimsPerm ;
339+ const auto &innerTileSizes = currOperandDetails.innerTileSizes ;
340+ if (innerDimsPos.empty () && outerDimsPerm.empty ()) {
341+ return std::make_tuple (opOperand->get (), currOperandDetails.indexingMap );
295342 }
296- if (!poisonPaddingOk &&
297- linalg::PackOp::requirePaddingValueStrict (
298- inputType.getShape (), innerDimsPos,
299- linalg::PackOp::inferPackedType (inputType, *maybeIntInnerTileSizes,
300- innerDimsPos, outerDimsPerm)
301- .getShape (),
302- outerDimsPerm, innerTileSizes))
303- return failure ();
304343 auto empty = linalg::PackOp::createDestinationTensor (
305344 b, loc, opOperand->get (), innerTileSizes, innerDimsPos, outerDimsPerm);
306345 auto poison = ub::PoisonOp::create (
307346 b, loc, getElementTypeOrSelf (opOperand->get ().getType ()));
308347 Value packedOperand =
309348 linalg::PackOp::create (b, loc, opOperand->get (), empty, innerDimsPos,
310349 innerTileSizes, poison, outerDimsPerm);
311- return std::make_tuple (packedOperand, indexingMap);
350+ return std::make_tuple (packedOperand, currOperandDetails. indexingMap );
312351}
313352
314353// / This function is a helper subroutine to pack a genericOp and return it. It
@@ -330,14 +369,18 @@ packGenericOp(RewriterBase &rewriter, GenericOp genericOp, Value dest,
330369 packOp.getInnerDimsPos () == unPackOp.getInnerDimsPos () &&
331370 llvm::equal (packOp.getMixedTiles (), unPackOp.getMixedTiles ());
332371 };
372+ DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
373+ bool requiresPadding = false ;
333374 for (OpOperand *inputOperand : genericOp.getDpsInputOperands ()) {
334- auto mayBepackedOperandAndIndexing = getOrCreatePackedViewOfOperand (
335- rewriter, loc, packInfo, genericOp, inputOperand, poisonPaddingOk);
336- if (failed (mayBepackedOperandAndIndexing)) {
337- return failure ();
338- }
339- auto packedOperand = std::get<0 >(*mayBepackedOperandAndIndexing);
340- auto packedIndexingMap = std::get<1 >(*mayBepackedOperandAndIndexing);
375+ requiresPadding |= getPackedOperandDetails (rewriter, packInfo, genericOp,
376+ inputOperand, packedOperandMap);
377+ }
378+ if (requiresPadding && !poisonPaddingOk) {
379+ return failure ();
380+ }
381+ for (OpOperand *inputOperand : genericOp.getDpsInputOperands ()) {
382+ auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand (
383+ rewriter, loc, inputOperand, packedOperandMap);
341384 auto unpackOp = inputOperand->get ().getDefiningOp <linalg::UnPackOp>();
342385 auto packOp = packedOperand.getDefiningOp <linalg::PackOp>();
343386 if (packOp && unpackOp && hasEquivalentTiles (packOp, unpackOp)) {
@@ -492,15 +535,15 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
492535 }
493536
494537 // Rebuild the indexing map for the corresponding init operand.
495- auto mayBepackedOperandAndIndexing =
496- getOrCreatePackedViewOfOperand (rewriter, genericOp. getLoc (), *packInfo,
497- genericOp, opOperand, poisonPaddingOk );
498- if (failed (mayBepackedOperandAndIndexing) ) {
538+ DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
539+ bool requiresPadding = getPackedOperandDetails (rewriter, *packInfo, genericOp ,
540+ opOperand, packedOperandMap );
541+ if (requiresPadding && !poisonPaddingOk ) {
499542 return failure ();
500543 }
501- auto packedOutOperand = std::get< 0 >(*mayBepackedOperandAndIndexing);
502- auto packedOutIndexingMap = std::get< 1 >(*mayBepackedOperandAndIndexing);
503-
544+ auto [ packedOutOperand, packedOutIndexingMap] =
545+ getOrCreatePackedViewOfOperand (rewriter, genericOp. getLoc (), opOperand,
546+ packedOperandMap);
504547 // Forward the new tensor.empty as a destination if it is one of the following
505548 // situations:
506549 // 1) The dps init operand is a tensor.empty.
@@ -1139,14 +1182,17 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
11391182 return failure ();
11401183
11411184 // Rebuild the indexing map for the corresponding init operand.
1142- auto mayBepackedOperandAndIndexing = getOrCreatePackedViewOfOperand (
1143- rewriter, genericOp.getLoc (), *packInfo, genericOp,
1144- genericOp.getDpsInitOperand (0 ), poisonPaddingOk);
1145- if (failed (mayBepackedOperandAndIndexing)) {
1185+ DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
1186+ bool requiresPadding =
1187+ getPackedOperandDetails (rewriter, *packInfo, genericOp,
1188+ genericOp.getDpsInitOperand (0 ), packedOperandMap);
1189+ if (requiresPadding && !poisonPaddingOk) {
11461190 return failure ();
11471191 }
1148- auto packedOutOperand = std::get<0 >(*mayBepackedOperandAndIndexing);
1149- auto packedOutIndexingMap = std::get<1 >(*mayBepackedOperandAndIndexing);
1192+ auto [packedOutOperand, packedOutIndexingMap] =
1193+ getOrCreatePackedViewOfOperand (rewriter, genericOp.getLoc (),
1194+ genericOp.getDpsInitOperand (0 ),
1195+ packedOperandMap);
11501196 auto destPack = packedOutOperand.getDefiningOp <linalg::PackOp>();
11511197
11521198 // Forward the new tensor.empty as a destination if it is one of the following
0 commit comments