@@ -167,10 +167,11 @@ exactTilingOnPackUnPackFilter(RewriterBase &rewriter,
167
167
tileSizesOnInnerDims =
168
168
llvm::to_vector (ArrayRef (tileSizes).take_back (innerTiles.size ()));
169
169
} else {
170
- // Upstream doesn't implement `getTiledImplementationFromOperandTile`
171
- // interface of `packOp` so far. In another word, `packOp` could not be
172
- // fused as consumer. As a result, just return failure currently.
173
- return failure ();
170
+ // tileSize comes from OpOperand
171
+ ArrayRef<int64_t > innerDimPos = packOp.getInnerDimsPos ();
172
+ for (auto &pos : innerDimPos) {
173
+ tileSizesOnInnerDims.push_back (tileSizes[pos]);
174
+ }
174
175
}
175
176
} else if (auto unPackOp = dyn_cast<tensor::UnPackOp>(defOrUse.ownerOp )) {
176
177
innerTiles = unPackOp.getMixedTiles ();
@@ -478,8 +479,8 @@ tileAndFuseProducerOfOpOperand(RewriterBase &rewriter, OpOperand &operand,
478
479
return std::nullopt;
479
480
480
481
// c. Check the producer of root source if is tilable.
481
- Operation *producer = realProducer->getDefiningOp <TilingInterface>();
482
- if (!producer )
482
+ Operation *producerOp = realProducer->getDefiningOp <TilingInterface>();
483
+ if (!producerOp )
483
484
return std::nullopt;
484
485
485
486
CandidateDefOrUse defOrUse{*realProducer};
@@ -536,8 +537,8 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
536
537
SmallVector<scf::SCFFuseConsumerOfSliceResult> fusedResultList;
537
538
for (auto useOperand : *realConsumers) {
538
539
// c. Check the consumer of top level result if is tilable.
539
- Operation *consumer = dyn_cast<TilingInterface>(useOperand->getOwner ());
540
- if (!consumer )
540
+ Operation *consumerOp = dyn_cast<TilingInterface>(useOperand->getOwner ());
541
+ if (!consumerOp )
541
542
continue ;
542
543
543
544
CandidateDefOrUse defOrUse{useOperand};
@@ -559,7 +560,7 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
559
560
// f. Manually run cse on region which contains original consumer op in
560
561
// avoid of conflict with subsequent `tileAndFuseConsumerOfSlice` get nest
561
562
// loops between next candidate sliceOp and tiled producer.
562
- (void )mlir::simplifyRegions (rewriter, {*consumer ->getParentRegion ()});
563
+ (void )mlir::simplifyRegions (rewriter, {*consumerOp ->getParentRegion ()});
563
564
}
564
565
}
565
566
if (fusedResultList.empty ())
@@ -647,11 +648,18 @@ static LogicalResult isTiledOpInLoop(Operation *targetOp) {
647
648
648
649
using OpTileSizeMap = std::unordered_map<std::string, SmallVector<int64_t >>;
649
650
651
+ struct defaultTileConfig {
652
+ // OpTy-to-TileSize mapping
653
+ OpTileSizeMap tsMap;
654
+ // ND-tile size
655
+ unsigned ndTile;
656
+ };
657
+
650
658
// / Default Tiling function only effective for certain `OpTy` operation
651
659
static FailureOr<scf::SCFTilingResult>
652
660
defaultTilingOfType (RewriterBase &rewriter, Operation *op,
653
661
function_ref<bool (Operation *)> isaOpTy,
654
- const OpTileSizeMap &tsMap ) {
662
+ const defaultTileConfig &cfg ) {
655
663
// a. Check <OpTy>
656
664
if (!isa<TilingInterface>(op) || !isaOpTy (op))
657
665
return failure ();
@@ -672,18 +680,20 @@ defaultTilingOfType(RewriterBase &rewriter, Operation *op,
672
680
// Erase dialect name, such as Linalg or Tensor.
673
681
opName.erase (0 , opName.find (" ." ) + 1 );
674
682
675
- if (tsMap.count (opName)) {
676
- SmallVector<int64_t > userDefaultTileSize = tsMap.find (opName)->second ;
683
+ if (cfg. tsMap .count (opName)) {
684
+ SmallVector<int64_t > userDefaultTileSize = cfg. tsMap .find (opName)->second ;
677
685
defaultTileSize =
678
686
getAsOpFoldResult (rewriter.getI64ArrayAttr (userDefaultTileSize));
679
687
} else {
680
688
defaultTileSize.resize (iteratorTypes.size (), rewriter.getIndexAttr (0 ));
681
689
// Try tileSize from `32` to `16`.
682
690
SmallVector<int64_t > tsOrder = {32 , 16 };
683
- // Only 2D tile is expected.
684
- int tileDims = (isa<mlir::linalg::LinalgOp>(op) && !linalgx::isMatmulOp (op))
685
- ? cast<mlir::linalg::LinalgOp>(op).getNumReductionLoops ()
686
- : 0 ;
691
+ // Record how many dims have been tiled, including fully tiled, i.e.
692
+ // tileSize == dimSize.
693
+ unsigned nonOneTileDims =
694
+ (isa<mlir::linalg::LinalgOp>(op) && !linalgx::isMatmulOp (op))
695
+ ? cast<mlir::linalg::LinalgOp>(op).getNumReductionLoops ()
696
+ : 0 ;
687
697
// Reverse both of iteration type and domain from inner to outer.
688
698
std::reverse (iteratorTypes.begin (), iteratorTypes.end ());
689
699
std::reverse (iterationDomain.begin (), iterationDomain.end ());
@@ -692,21 +702,29 @@ defaultTilingOfType(RewriterBase &rewriter, Operation *op,
692
702
// All parallel iterator will be tiled by `32` or `16`. If need
693
703
// specified, please set option `defaultTileSize`, like `matmul:{64,64}`.
694
704
if (iterType == utils::IteratorType::parallel) {
695
- Range curDomain = iterationDomain[en];
696
- std::optional<int64_t > tripCount = mlir::constantTripCount (
697
- curDomain.offset , curDomain.size , curDomain.stride );
698
- if (tileDims >= 2 && en > 0 ) {
705
+ if (nonOneTileDims >= cfg.ndTile && en > 0 ) {
699
706
defaultTileSize[en] = rewriter.getIndexAttr (1 );
700
707
continue ;
701
- } else if (tripCount) {
708
+ }
709
+ Range curDomain = iterationDomain[en];
710
+ if (std::optional<int64_t > tripCount = mlir::constantTripCount (
711
+ curDomain.offset , curDomain.size , curDomain.stride )) {
712
+ // skip dummy tiling.
713
+ if (tripCount == 1 )
714
+ continue ;
702
715
for (auto &ts : tsOrder) {
703
- if (*tripCount % ts == 0 && *tripCount > ts) {
716
+ // If `tripCount` equals to `tileSize`, Do NOT explicitly tile it in
717
+ // avoid of non-zero offset.
718
+ if (*tripCount == ts)
719
+ break ;
720
+ if (*tripCount % ts == 0 ) {
704
721
defaultTileSize[en] = rewriter.getIndexAttr (ts);
705
722
break ;
706
723
}
707
724
}
708
725
}
709
- tileDims++;
726
+ // Fallback to fully tiled.
727
+ nonOneTileDims++;
710
728
}
711
729
}
712
730
}
@@ -731,7 +749,7 @@ defaultTilingOfType(RewriterBase &rewriter, Operation *op,
731
749
732
750
void iterativeTilingAndFusionUntilExhaustion (
733
751
RewriterBase &rewriter, func::FuncOp &f,
734
- const CandidateSliceOptions &sliceOptions, const OpTileSizeMap &tsMap ) {
752
+ const CandidateSliceOptions &sliceOptions, const defaultTileConfig &cfg ) {
735
753
// Collect untiled and tiled ops respectively
736
754
llvm::SetVector<Operation *> tiledOps, unTiledOps;
737
755
@@ -799,7 +817,7 @@ void iterativeTilingAndFusionUntilExhaustion(
799
817
for (auto &isaOpTy : priorityOpTypeOrder) {
800
818
for (auto &op : unTiledOps) {
801
819
FailureOr<scf::SCFTilingResult> tilingResult =
802
- defaultTilingOfType (rewriter, op, isaOpTy, tsMap );
820
+ defaultTilingOfType (rewriter, op, isaOpTy, cfg );
803
821
if (succeeded (tilingResult)) {
804
822
tiledOps.insert (tilingResult->tiledOps [0 ]);
805
823
rewriter.replaceOp (op, tilingResult->replacements );
@@ -881,8 +899,8 @@ struct IterativeTilingAndFusion
881
899
// Get rewriter
882
900
IRRewriter rewriter (&ctx);
883
901
// Run iterative fusion
884
- iterativeTilingAndFusionUntilExhaustion (rewriter, func, sliceOptions,
885
- tsMap);
902
+ iterativeTilingAndFusionUntilExhaustion (
903
+ rewriter, func, sliceOptions, defaultTileConfig{ tsMap, defaultNDTile} );
886
904
}
887
905
};
888
906
0 commit comments