@@ -41,9 +41,8 @@ static FailureOr<tensor::ExtractSliceOp>
41
41
getClosestExtractSliceOfOperand (OpOperand &operand) {
42
42
if (auto iterArg = dyn_cast<BlockArgument>(operand.get ())) {
43
43
if (auto loop =
44
- dyn_cast<LoopLikeOpInterface>(iterArg.getOwner ()->getParentOp ())) {
44
+ dyn_cast<LoopLikeOpInterface>(iterArg.getOwner ()->getParentOp ()))
45
45
return getClosestExtractSliceOfOperand (*loop.getTiedLoopInit (iterArg));
46
- }
47
46
}
48
47
49
48
Operation *defineOp = operand.get ().getDefiningOp ();
@@ -69,10 +68,9 @@ getClosestInsertSliceOfResult(OpResult result) {
69
68
sliceOp =
70
69
dyn_cast<OffsetSizeAndStrideOpInterface>(useOfResult.getOwner ());
71
70
} else if (auto yieldOp = dyn_cast<scf::YieldOp>(useOfResult.getOwner ())) {
72
- if (auto loop = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp ())) {
71
+ if (auto loop = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp ()))
73
72
return getClosestInsertSliceOfResult (
74
73
loop->getResult (useOfResult.getOperandNumber ()));
75
- }
76
74
}
77
75
}
78
76
@@ -138,9 +136,8 @@ noTilingOnReductionFilter(RewriterBase &rewriter,
138
136
presburger::BoundType::UB, tileSizes[resultExpr.index ()], nullptr ,
139
137
true );
140
138
if (!cstIterDomain || failed (cstTileSizes) ||
141
- cstIterDomain != cstTileSizes) {
139
+ cstIterDomain != cstTileSizes)
142
140
return failure ();
143
- }
144
141
}
145
142
}
146
143
return success ();
@@ -246,9 +243,8 @@ SingleCandidateInBlockFilter(RewriterBase &rewriter,
246
243
scfX::getRealProducerOfExtractSliceOp (otherCandidate,
247
244
backwardSlice);
248
245
if (succeeded (realProducer) &&
249
- realProducer->getDefiningOp () == defOrUse.ownerOp ) {
246
+ realProducer->getDefiningOp () == defOrUse.ownerOp )
250
247
return failure ();
251
- }
252
248
} else {
253
249
SmallVector<OffsetSizeAndStrideOpInterface> forwardSlice;
254
250
FailureOr<SmallVector<OpOperand *>> realConsumers =
@@ -257,9 +253,8 @@ SingleCandidateInBlockFilter(RewriterBase &rewriter,
257
253
if (succeeded (realConsumers) &&
258
254
llvm::any_of (*realConsumers, [&defOrUse](OpOperand *use) {
259
255
return use->getOwner () == defOrUse.ownerOp ;
260
- })) {
256
+ }))
261
257
return failure ();
262
- }
263
258
}
264
259
}
265
260
}
@@ -338,9 +333,8 @@ static int TilingSizeComparer(RewriterBase &rewriter,
338
333
computeTileSizeProductOfCandidate (candidateA),
339
334
sizeProductB =
340
335
computeTileSizeProductOfCandidate (candidateB);
341
- if (failed (sizeProductA) || failed (sizeProductB)) {
336
+ if (failed (sizeProductA) || failed (sizeProductB))
342
337
return 0 ;
343
- }
344
338
// deal with equality
345
339
if (*sizeProductA == *sizeProductB) {
346
340
return 0 ;
@@ -401,17 +395,17 @@ tileAndFuseProducerOfOpOperand(RewriterBase &rewriter, OpOperand &operand,
401
395
// a. Find the closest sliceOp
402
396
FailureOr<tensor::ExtractSliceOp> closestSliceOp =
403
397
getClosestExtractSliceOfOperand (operand);
404
- if (failed (closestSliceOp)) {
398
+ if (failed (closestSliceOp))
405
399
return std::nullopt;
406
- }
400
+
407
401
// b. Find the real producer and collect the sliceOp chain during backward
408
402
// stage, sorted from inner to outer.
409
403
SmallVector<tensor::ExtractSliceOp> backwardSlice;
410
404
FailureOr<OpResult> realProducer =
411
405
scfX::getRealProducerOfExtractSliceOp (*closestSliceOp, backwardSlice);
412
- if (failed (realProducer)) {
406
+ if (failed (realProducer))
413
407
return std::nullopt;
414
- }
408
+
415
409
// c. Check whether the producer of the root source is tilable.
416
410
Operation *producer = realProducer->getDefiningOp <TilingInterface>();
417
411
if (!producer)
@@ -451,17 +445,16 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
451
445
// a. Find the closest sliceOp
452
446
FailureOr<tensor::ExtractSliceOp> closestSliceOp =
453
447
getClosestInsertSliceOfResult (result);
454
- if (failed (closestSliceOp)) {
448
+ if (failed (closestSliceOp))
455
449
return std::nullopt;
456
- }
450
+
457
451
// b. Find the real consumers and collect the sliceOp chain during forward
458
452
// stage, sorted from inner to outer.
459
453
SmallVector<OffsetSizeAndStrideOpInterface> forwardSlice;
460
454
FailureOr<SmallVector<OpOperand *>> realConsumers =
461
455
scfX::getRealConsumersFromInsertSliceOp (*closestSliceOp, forwardSlice);
462
- if (failed (realConsumers)) {
456
+ if (failed (realConsumers))
463
457
return std::nullopt;
464
- }
465
458
466
459
SmallVector<scf::SCFFuseConsumerOfSliceResult> fusedResultList;
467
460
for (auto useOperand : *realConsumers) {
@@ -543,18 +536,16 @@ LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp(
543
536
// fuse producer
544
537
for (OpOperand &operand : tiledOp->getOpOperands ()) {
545
538
if (std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResult =
546
- tileAndFuseProducerOfOpOperand (rewriter, operand, options)) {
539
+ tileAndFuseProducerOfOpOperand (rewriter, operand, options))
547
540
tiledOpList.push_back (fuseProducerResult.value ().tiledOps [0 ]);
548
- }
549
541
}
550
542
// fuse consumer(s)
551
543
for (OpResult result : tiledOp->getResults ()) {
552
544
if (std::optional<SmallVector<scf::SCFFuseConsumerOfSliceResult>>
553
545
fuseConsumerResults =
554
546
tileAndFuseConsumerOfOpResult (rewriter, result, options)) {
555
- for (auto &fuseConsumerResult : *fuseConsumerResults) {
547
+ for (auto &fuseConsumerResult : *fuseConsumerResults)
556
548
tiledOpList.push_back (fuseConsumerResult.tiledOps [0 ]);
557
- }
558
549
}
559
550
}
560
551
}
@@ -573,30 +564,26 @@ LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp(
573
564
// / }
574
565
static LogicalResult isSingleTiledOpInLoop (Operation *targetOp) {
575
566
// 0. check tilable
576
- if (!isa<TilingInterface>(targetOp)) {
567
+ if (!isa<TilingInterface>(targetOp))
577
568
return failure ();
578
- }
579
569
// 1. check parentOp
580
570
auto forOp = targetOp->getParentOfType <LoopLikeOpInterface>();
581
- if (!forOp) {
571
+ if (!forOp)
582
572
return failure ();
583
- }
584
573
// 2. check single one tiling interface in loop body
585
574
auto walkResult = forOp->walk ([&targetOp](TilingInterface op) {
586
575
// some special ops may already be dealt with in the template
587
576
if (isa<linalg::FillOp, linalg::CopyOp>(op))
588
577
return WalkResult::skip ();
589
578
return op != targetOp ? WalkResult::interrupt () : WalkResult::advance ();
590
579
});
591
- if (walkResult.wasInterrupted ()) {
580
+ if (walkResult.wasInterrupted ())
592
581
return failure ();
593
- }
594
582
// 3. check whether the loop has either an extract or an insert slice op
595
583
walkResult = forOp->walk (
596
584
[](tensor::ExtractSliceOp) { return WalkResult::interrupt (); });
597
- if (walkResult.wasInterrupted ()) {
585
+ if (walkResult.wasInterrupted ())
598
586
return success ();
599
- }
600
587
walkResult = forOp->walk (
601
588
[](tensor::InsertSliceOp) { return WalkResult::interrupt (); });
602
589
return success (walkResult.wasInterrupted ());
@@ -690,9 +677,8 @@ static bool defaultTilingOfType(RewriterBase &rewriter, Operation *op,
690
677
// word, all reduction dimensions should not be tiled.
691
678
if (iterType == utils::IteratorType::parallel &&
692
679
(en != iteratorTypes.size () - 1 ||
693
- llvm::count (iteratorTypes, utils::IteratorType::reduction))) {
680
+ llvm::count (iteratorTypes, utils::IteratorType::reduction)))
694
681
defaultTileSize[en] = rewriter.getIndexAttr (1 );
695
- }
696
682
}
697
683
}
698
684
// If the tile sizes are all zero, no tiling would happen.
@@ -724,14 +710,13 @@ void iterativeTilingAndFusionUntilExhaustion(
724
710
unTiledOps.clear ();
725
711
// Pre-order walk through funcOp
726
712
f->walk <WalkOrder::PreOrder>([&unTiledOps](Operation *op) {
727
- if (isa<LoopLikeOpInterface>(op)) {
713
+ if (isa<LoopLikeOpInterface>(op))
728
714
return WalkResult::skip ();
729
- }
715
+
730
716
if (isa<TilingInterface>(op) && !op->use_empty ()) {
731
717
auto parentLoop = op->getParentOfType <LoopLikeOpInterface>();
732
- if (!parentLoop.getOperation ()) {
718
+ if (!parentLoop.getOperation ())
733
719
unTiledOps.insert (op);
734
- }
735
720
}
736
721
return WalkResult::advance ();
737
722
});
@@ -767,9 +752,8 @@ void iterativeTilingAndFusionUntilExhaustion(
767
752
changed |= succeeded (iterativelyFuseProducerAndConsumerOfTiledOp (
768
753
rewriter, tiledOp, sliceOptions));
769
754
});
770
- if (changed) {
755
+ if (changed)
771
756
(void )mlir::simplifyRegions (rewriter, {f.getRegion ()});
772
- }
773
757
} else {
774
758
// Auto tiling with default tile size if no tiled op found. Follow tiling
775
759
// priority based on OpTy: `Contraction`->`Reduction`->`Elementwise`.
@@ -803,15 +787,15 @@ static OpTileSizeMap defaultTileSizeParser(ArrayRef<std::string> strArgs) {
803
787
for (auto str : strArgs) {
804
788
str.erase (llvm::remove_if (str, llvm::isSpace), str.end ());
805
789
size_t pos = str.find (" :" );
806
- if (pos == std::string::npos) {
790
+ if (pos == std::string::npos)
807
791
llvm_unreachable (warning);
808
- }
792
+
809
793
std::string opType = str.substr (0 , pos);
810
794
std::string strTileSize = str.erase (0 , pos + 1 );
811
795
if (strTileSize.size () <= 2 || strTileSize.front () != ' {' ||
812
- strTileSize.back () != ' }' ) {
796
+ strTileSize.back () != ' }' )
813
797
llvm_unreachable (warning);
814
- }
798
+
815
799
strTileSize = strTileSize.substr (1 , strTileSize.size () - 2 );
816
800
SmallVector<int64_t > intTileSize;
817
801
while ((pos = strTileSize.find (" ," )) != std::string::npos) {
0 commit comments