Skip to content

Commit 3ad0e30

Browse files
committed
add FileCheck and fix comment
1 parent d0c456f commit 3ad0e30

File tree

4 files changed

+128
-77
lines changed

4 files changed

+128
-77
lines changed

lib/gc/Transforms/IterativeTilingAndFusion.cpp

Lines changed: 28 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,8 @@ static FailureOr<tensor::ExtractSliceOp>
4141
getClosestExtractSliceOfOperand(OpOperand &operand) {
4242
if (auto iterArg = dyn_cast<BlockArgument>(operand.get())) {
4343
if (auto loop =
44-
dyn_cast<LoopLikeOpInterface>(iterArg.getOwner()->getParentOp())) {
44+
dyn_cast<LoopLikeOpInterface>(iterArg.getOwner()->getParentOp()))
4545
return getClosestExtractSliceOfOperand(*loop.getTiedLoopInit(iterArg));
46-
}
4746
}
4847

4948
Operation *defineOp = operand.get().getDefiningOp();
@@ -69,10 +68,9 @@ getClosestInsertSliceOfResult(OpResult result) {
6968
sliceOp =
7069
dyn_cast<OffsetSizeAndStrideOpInterface>(useOfResult.getOwner());
7170
} else if (auto yieldOp = dyn_cast<scf::YieldOp>(useOfResult.getOwner())) {
72-
if (auto loop = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp())) {
71+
if (auto loop = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp()))
7372
return getClosestInsertSliceOfResult(
7473
loop->getResult(useOfResult.getOperandNumber()));
75-
}
7674
}
7775
}
7876

@@ -138,9 +136,8 @@ noTilingOnReductionFilter(RewriterBase &rewriter,
138136
presburger::BoundType::UB, tileSizes[resultExpr.index()], nullptr,
139137
true);
140138
if (!cstIterDomain || failed(cstTileSizes) ||
141-
cstIterDomain != cstTileSizes) {
139+
cstIterDomain != cstTileSizes)
142140
return failure();
143-
}
144141
}
145142
}
146143
return success();
@@ -246,9 +243,8 @@ SingleCandidateInBlockFilter(RewriterBase &rewriter,
246243
scfX::getRealProducerOfExtractSliceOp(otherCandidate,
247244
backwardSlice);
248245
if (succeeded(realProducer) &&
249-
realProducer->getDefiningOp() == defOrUse.ownerOp) {
246+
realProducer->getDefiningOp() == defOrUse.ownerOp)
250247
return failure();
251-
}
252248
} else {
253249
SmallVector<OffsetSizeAndStrideOpInterface> forwardSlice;
254250
FailureOr<SmallVector<OpOperand *>> realConsumers =
@@ -257,9 +253,8 @@ SingleCandidateInBlockFilter(RewriterBase &rewriter,
257253
if (succeeded(realConsumers) &&
258254
llvm::any_of(*realConsumers, [&defOrUse](OpOperand *use) {
259255
return use->getOwner() == defOrUse.ownerOp;
260-
})) {
256+
}))
261257
return failure();
262-
}
263258
}
264259
}
265260
}
@@ -338,9 +333,8 @@ static int TilingSizeComparer(RewriterBase &rewriter,
338333
computeTileSizeProductOfCandidate(candidateA),
339334
sizeProductB =
340335
computeTileSizeProductOfCandidate(candidateB);
341-
if (failed(sizeProductA) || failed(sizeProductB)) {
336+
if (failed(sizeProductA) || failed(sizeProductB))
342337
return 0;
343-
}
344338
// deal with equality
345339
if (*sizeProductA == *sizeProductB) {
346340
return 0;
@@ -401,17 +395,17 @@ tileAndFuseProducerOfOpOperand(RewriterBase &rewriter, OpOperand &operand,
401395
// a. Find the closest sliceOp
402396
FailureOr<tensor::ExtractSliceOp> closestSliceOp =
403397
getClosestExtractSliceOfOperand(operand);
404-
if (failed(closestSliceOp)) {
398+
if (failed(closestSliceOp))
405399
return std::nullopt;
406-
}
400+
407401
// b. Find the real producer and collect the sliceOp chain during backward
408402
// stage, sorted from inner to outer.
409403
SmallVector<tensor::ExtractSliceOp> backwardSlice;
410404
FailureOr<OpResult> realProducer =
411405
scfX::getRealProducerOfExtractSliceOp(*closestSliceOp, backwardSlice);
412-
if (failed(realProducer)) {
406+
if (failed(realProducer))
413407
return std::nullopt;
414-
}
408+
415409
// c. Check whether the producer of the root source is tilable.
416410
Operation *producer = realProducer->getDefiningOp<TilingInterface>();
417411
if (!producer)
@@ -451,17 +445,16 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
451445
// a. Find the closest sliceOp
452446
FailureOr<tensor::ExtractSliceOp> closestSliceOp =
453447
getClosestInsertSliceOfResult(result);
454-
if (failed(closestSliceOp)) {
448+
if (failed(closestSliceOp))
455449
return std::nullopt;
456-
}
450+
457451
// b. Find the real consumers and collect the sliceOp chain during forward
458452
// stage, sorted from inner to outer.
459453
SmallVector<OffsetSizeAndStrideOpInterface> forwardSlice;
460454
FailureOr<SmallVector<OpOperand *>> realConsumers =
461455
scfX::getRealConsumersFromInsertSliceOp(*closestSliceOp, forwardSlice);
462-
if (failed(realConsumers)) {
456+
if (failed(realConsumers))
463457
return std::nullopt;
464-
}
465458

466459
SmallVector<scf::SCFFuseConsumerOfSliceResult> fusedResultList;
467460
for (auto useOperand : *realConsumers) {
@@ -543,18 +536,16 @@ LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp(
543536
// fuse producer
544537
for (OpOperand &operand : tiledOp->getOpOperands()) {
545538
if (std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResult =
546-
tileAndFuseProducerOfOpOperand(rewriter, operand, options)) {
539+
tileAndFuseProducerOfOpOperand(rewriter, operand, options))
547540
tiledOpList.push_back(fuseProducerResult.value().tiledOps[0]);
548-
}
549541
}
550542
// fuse consumer(s)
551543
for (OpResult result : tiledOp->getResults()) {
552544
if (std::optional<SmallVector<scf::SCFFuseConsumerOfSliceResult>>
553545
fuseConsumerResults =
554546
tileAndFuseConsumerOfOpResult(rewriter, result, options)) {
555-
for (auto &fuseConsumerResult : *fuseConsumerResults) {
547+
for (auto &fuseConsumerResult : *fuseConsumerResults)
556548
tiledOpList.push_back(fuseConsumerResult.tiledOps[0]);
557-
}
558549
}
559550
}
560551
}
@@ -573,30 +564,26 @@ LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp(
573564
/// }
574565
static LogicalResult isSingleTiledOpInLoop(Operation *targetOp) {
575566
// 0. check tilable
576-
if (!isa<TilingInterface>(targetOp)) {
567+
if (!isa<TilingInterface>(targetOp))
577568
return failure();
578-
}
579569
// 1. check parentOp
580570
auto forOp = targetOp->getParentOfType<LoopLikeOpInterface>();
581-
if (!forOp) {
571+
if (!forOp)
582572
return failure();
583-
}
584573
// 2. check single one tiling interface in loop body
585574
auto walkResult = forOp->walk([&targetOp](TilingInterface op) {
586575
// some special ops may already be dealt with in the template
587576
if (isa<linalg::FillOp, linalg::CopyOp>(op))
588577
return WalkResult::skip();
589578
return op != targetOp ? WalkResult::interrupt() : WalkResult::advance();
590579
});
591-
if (walkResult.wasInterrupted()) {
580+
if (walkResult.wasInterrupted())
592581
return failure();
593-
}
594582
// 3. check whether the loop has either an extract or an insert slice op
595583
walkResult = forOp->walk(
596584
[](tensor::ExtractSliceOp) { return WalkResult::interrupt(); });
597-
if (walkResult.wasInterrupted()) {
585+
if (walkResult.wasInterrupted())
598586
return success();
599-
}
600587
walkResult = forOp->walk(
601588
[](tensor::InsertSliceOp) { return WalkResult::interrupt(); });
602589
return success(walkResult.wasInterrupted());
@@ -690,9 +677,8 @@ static bool defaultTilingOfType(RewriterBase &rewriter, Operation *op,
690677
// word, all reduction dimensions should not be tiled.
691678
if (iterType == utils::IteratorType::parallel &&
692679
(en != iteratorTypes.size() - 1 ||
693-
llvm::count(iteratorTypes, utils::IteratorType::reduction))) {
680+
llvm::count(iteratorTypes, utils::IteratorType::reduction)))
694681
defaultTileSize[en] = rewriter.getIndexAttr(1);
695-
}
696682
}
697683
}
698684
// If the tile sizes are all zero, no tiling would happen.
@@ -724,14 +710,13 @@ void iterativeTilingAndFusionUntilExhaustion(
724710
unTiledOps.clear();
725711
// Pre-order walk through funcOp
726712
f->walk<WalkOrder::PreOrder>([&unTiledOps](Operation *op) {
727-
if (isa<LoopLikeOpInterface>(op)) {
713+
if (isa<LoopLikeOpInterface>(op))
728714
return WalkResult::skip();
729-
}
715+
730716
if (isa<TilingInterface>(op) && !op->use_empty()) {
731717
auto parentLoop = op->getParentOfType<LoopLikeOpInterface>();
732-
if (!parentLoop.getOperation()) {
718+
if (!parentLoop.getOperation())
733719
unTiledOps.insert(op);
734-
}
735720
}
736721
return WalkResult::advance();
737722
});
@@ -767,9 +752,8 @@ void iterativeTilingAndFusionUntilExhaustion(
767752
changed |= succeeded(iterativelyFuseProducerAndConsumerOfTiledOp(
768753
rewriter, tiledOp, sliceOptions));
769754
});
770-
if (changed) {
755+
if (changed)
771756
(void)mlir::simplifyRegions(rewriter, {f.getRegion()});
772-
}
773757
} else {
774758
// Auto tiling with default tile size if no tiled op found. Follow tiling
775759
// priority based on OpTy: `Contraction`->`Reduction`->`Elementwise`.
@@ -803,15 +787,15 @@ static OpTileSizeMap defaultTileSizeParser(ArrayRef<std::string> strArgs) {
803787
for (auto str : strArgs) {
804788
str.erase(llvm::remove_if(str, llvm::isSpace), str.end());
805789
size_t pos = str.find(":");
806-
if (pos == std::string::npos) {
790+
if (pos == std::string::npos)
807791
llvm_unreachable(warning);
808-
}
792+
809793
std::string opType = str.substr(0, pos);
810794
std::string strTileSize = str.erase(0, pos + 1);
811795
if (strTileSize.size() <= 2 || strTileSize.front() != '{' ||
812-
strTileSize.back() != '}') {
796+
strTileSize.back() != '}')
813797
llvm_unreachable(warning);
814-
}
798+
815799
strTileSize = strTileSize.substr(1, strTileSize.size() - 2);
816800
SmallVector<int64_t> intTileSize;
817801
while ((pos = strTileSize.find(",")) != std::string::npos) {

lib/gc/Transforms/TilingUsingInterfaceX.cpp

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,18 @@ static std::tuple<OpResult, std::optional<OpOperand *>>
4646
getUntiledProducerFromSliceSource(OpOperand *source,
4747
ArrayRef<LoopLikeOpInterface> loops) {
4848
std::optional<OpOperand *> destinationIterArg;
49-
auto loopIt = loops.rbegin();
50-
while (auto iterArg = dyn_cast<BlockArgument>(source->get())) {
51-
auto loop = *loopIt;
52-
if (iterArg.getOwner()->getParentOp() != loop)
53-
break;
54-
source = loop.getTiedLoopInit(iterArg);
55-
loopIt++;
49+
if (!loops.empty()) {
50+
auto loopIt = loops.rbegin();
51+
while (auto iterArg = dyn_cast<BlockArgument>(source->get())) {
52+
auto loop = *loopIt;
53+
if (iterArg.getOwner()->getParentOp() != loop)
54+
break;
55+
source = loop.getTiedLoopInit(iterArg);
56+
loopIt++;
57+
}
58+
if (loopIt == loops.rend())
59+
destinationIterArg = source;
5660
}
57-
if (loopIt == loops.rend())
58-
destinationIterArg = source;
5961
return {dyn_cast<OpResult>(source->get()), destinationIterArg};
6062
}
6163

@@ -190,8 +192,8 @@ tileAndFuseProducerOfSliceImpl(RewriterBase &rewriter,
190192
/// @return OpResult Producer : %0 = producer
191193
FailureOr<OpResult> mlir::scfX::getRealProducerOfExtractSliceOp(
192194
Operation *candidateSliceOp,
193-
SmallVector<tensor::ExtractSliceOp> &backwardSlice, int curDepth,
194-
int maxDepth) {
195+
SmallVector<tensor::ExtractSliceOp> &backwardSlice, unsigned curDepth,
196+
unsigned maxDepth) {
195197
if (!isa<tensor::ExtractSliceOp>(candidateSliceOp))
196198
return failure();
197199
// limit the recursion depth to avoid stack overflow
@@ -322,8 +324,8 @@ mlir::scfX::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
322324
FailureOr<SmallVector<OpOperand *>>
323325
mlir::scfX::getRealConsumersFromInsertSliceOp(
324326
Operation *candidateSliceOp,
325-
SmallVector<OffsetSizeAndStrideOpInterface> &forwardSlice, int curDepth,
326-
int maxDepth) {
327+
SmallVector<OffsetSizeAndStrideOpInterface> &forwardSlice,
328+
unsigned curDepth, unsigned maxDepth) {
327329
if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
328330
candidateSliceOp))
329331
return failure();
@@ -410,7 +412,8 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
410412
}
411413

412414
/// Fetches the FIRST OpOperand of the tilable user (and use) of the value `val`
413-
/// which implements `TilingInterface` and `DestinationStyleOpInterface`.
415+
/// within the same block, which implements `TilingInterface` and
416+
/// `DestinationStyleOpInterface` and has a non-empty user list.
414417
/// Returns failure otherwise.
415418
static FailureOr<OpOperand *> getConsumerFromUses(Value val,
416419
Block *containingOpBlock) {
@@ -458,9 +461,8 @@ static LogicalResult isForOpYieldResultOfInnerLoop(LoopLikeOpInterface loop) {
458461
}
459462

460463
/// Fetch the untiled consumer of a scf.for's result which is yielded by a
461-
/// tensor.insert_slice. This function makes the following assumptions :
462-
/// 1. tensor.insert_slice has scf.yield as its only user.
463-
/// 2. scf.for's corresponding result has only one use.
464+
/// tensor.insert_slice. This function assumes that the
465+
/// tensor.insert_slice has scf.yield as its only user.
464466
static FailureOr<OpOperand *>
465467
getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
466468
if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
@@ -564,15 +566,14 @@ static LogicalResult checkAssumptionForLoop(Operation *loopOp,
564566
// `ParallelInsertSlice` located inside `InParallelOp` does not share a parent
565567
// block with any other type of operation. Thus, just redirect to its
566568
// parent `InParallelOp`.
567-
if (isa<tensor::ParallelInsertSliceOp>(userOp)) {
569+
if (isa<tensor::ParallelInsertSliceOp>(userOp))
568570
userOp = userOp->getParentOfType<scf::InParallelOp>();
569-
}
570-
if (parentBlock != userOp->getBlock()) {
571+
572+
if (parentBlock != userOp->getBlock())
571573
return failure();
572-
}
573-
if (userOp->isBeforeInBlock(firstUserOfLoop)) {
574+
575+
if (userOp->isBeforeInBlock(firstUserOfLoop))
574576
firstUserOfLoop = userOp;
575-
}
576577
}
577578
// Find the last define of consumer
578579
for (Value operand : consumerOp->getOperands()) {
@@ -582,12 +583,10 @@ static LogicalResult checkAssumptionForLoop(Operation *loopOp,
582583
auto defineOp = operand.getDefiningOp();
583584
if (defineOp == loopOp)
584585
continue;
585-
if (!defineOp || parentBlock != defineOp->getBlock()) {
586+
if (!defineOp || parentBlock != defineOp->getBlock())
586587
return failure();
587-
}
588-
if (lastDefOfConsumer->isBeforeInBlock(defineOp)) {
588+
if (lastDefOfConsumer->isBeforeInBlock(defineOp))
589589
lastDefOfConsumer = defineOp;
590-
}
591590
}
592591
if (firstUserOfLoop->isBeforeInBlock(lastDefOfConsumer)) {
593592
// Try to move if possible

lib/gc/Transforms/TilingUsingInterfaceX.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@ SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile(
2020

2121
FailureOr<OpResult> getRealProducerOfExtractSliceOp(
2222
Operation *candidateSliceOp,
23-
SmallVector<tensor::ExtractSliceOp> &backwardSlice, int curDepth = 0,
24-
int maxDepth = 5);
23+
SmallVector<tensor::ExtractSliceOp> &backwardSlice, unsigned curDepth = 0,
24+
unsigned maxDepth = 5);
2525

2626
FailureOr<SmallVector<OpOperand *>> getRealConsumersFromInsertSliceOp(
2727
Operation *candidateSliceOp,
28-
SmallVector<OffsetSizeAndStrideOpInterface> &forwardSlice, int curDepth = 0,
29-
int maxDepth = 5);
28+
SmallVector<OffsetSizeAndStrideOpInterface> &forwardSlice, unsigned curDepth = 0,
29+
unsigned maxDepth = 5);
3030

3131
// Extension for upstream `tileAndFuseProducerOfSlice`
3232
std::optional<scf::SCFFuseProducerOfSliceResult>

0 commit comments

Comments
 (0)