Skip to content

Commit 4dd5214

Browse files
committed
fix second portion comment
1 parent 3ad0e30 commit 4dd5214

File tree

4 files changed

+127
-128
lines changed

4 files changed

+127
-128
lines changed

lib/gc/Transforms/IterativeTilingAndFusion.cpp

Lines changed: 87 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,14 @@ getClosestExtractSliceOfOperand(OpOperand &operand) {
4646
}
4747

4848
Operation *defineOp = operand.get().getDefiningOp();
49-
if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(defineOp)) {
49+
if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(defineOp))
5050
return sliceOp;
51-
} else if (isa<linalg::FillOp, tensor::ExpandShapeOp,
52-
tensor::CollapseShapeOp>(defineOp)) {
53-
// For downstream cases
51+
// For downstream cases
52+
if (isa<linalg::FillOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp>(
53+
defineOp))
5454
return getClosestExtractSliceOfOperand(defineOp->getOpOperand(0));
55-
} else {
56-
return failure();
57-
}
55+
56+
return failure();
5857
}
5958

6059
static FailureOr<OffsetSizeAndStrideOpInterface>
@@ -104,9 +103,8 @@ struct CandidateDefOrUse {
104103
using CandidateSliceFilter = std::function<LogicalResult(
105104
RewriterBase &, OffsetSizeAndStrideOpInterface, CandidateDefOrUse)>;
106105

107-
using CandidateSliceComparer =
108-
std::function<int(RewriterBase &, OffsetSizeAndStrideOpInterface,
109-
OffsetSizeAndStrideOpInterface, CandidateDefOrUse)>;
106+
using CandidateSliceComparer = std::function<int(
107+
OffsetSizeAndStrideOpInterface, OffsetSizeAndStrideOpInterface)>;
110108

111109
static LogicalResult
112110
noTilingOnReductionFilter(RewriterBase &rewriter,
@@ -205,10 +203,9 @@ exactTilingOnPackUnPackFilter(RewriterBase &rewriter,
205203
return success();
206204
}
207205

208-
static LogicalResult
209-
alreadyTiledOpFilter(RewriterBase &rewriter,
210-
OffsetSizeAndStrideOpInterface candidate,
211-
CandidateDefOrUse defOrUse) {
206+
static LogicalResult unTiledOpFilter(RewriterBase &rewriter,
207+
OffsetSizeAndStrideOpInterface candidate,
208+
CandidateDefOrUse defOrUse) {
212209
// In general tiledOp would not have uses any more.
213210
return failure(defOrUse.ownerOp->use_empty());
214211
}
@@ -294,7 +291,7 @@ struct CandidateSliceFilterPipeLine
294291

295292
SmallVector<CandidateSliceFilter> getDefaultPipeLine() {
296293
return SmallVector<CandidateSliceFilter>{
297-
alreadyTiledOpFilter, NonContractionOpFilter, noTilingOnReductionFilter,
294+
unTiledOpFilter, NonContractionOpFilter, noTilingOnReductionFilter,
298295
exactTilingOnPackUnPackFilter, SingleCandidateInBlockFilter};
299296
}
300297

@@ -325,22 +322,19 @@ computeTileSizeProductOfCandidate(OffsetSizeAndStrideOpInterface candidate) {
325322
return totalSize;
326323
}
327324

328-
static int TilingSizeComparer(RewriterBase &rewriter,
329-
OffsetSizeAndStrideOpInterface candidateA,
330-
OffsetSizeAndStrideOpInterface candidateB,
331-
CandidateDefOrUse defOrUse) {
325+
static int TilingSizeComparer(OffsetSizeAndStrideOpInterface candidateA,
326+
OffsetSizeAndStrideOpInterface candidateB) {
332327
FailureOr<int64_t> sizeProductA =
333328
computeTileSizeProductOfCandidate(candidateA),
334329
sizeProductB =
335330
computeTileSizeProductOfCandidate(candidateB);
336331
if (failed(sizeProductA) || failed(sizeProductB))
337332
return 0;
338333
// deal with equality
339-
if (*sizeProductA == *sizeProductB) {
334+
if (*sizeProductA == *sizeProductB)
340335
return 0;
341-
} else {
342-
return *sizeProductA < *sizeProductB ? -1 : 1;
343-
}
336+
337+
return *sizeProductA < *sizeProductB ? -1 : 1;
344338
}
345339

346340
struct CandidateSliceComparerPipeLine
@@ -352,17 +346,15 @@ struct CandidateSliceComparerPipeLine
352346
return SmallVector<CandidateSliceComparer>{TilingSizeComparer};
353347
}
354348

355-
bool compare(RewriterBase &rewriter,
356-
OffsetSizeAndStrideOpInterface candidateA,
357-
OffsetSizeAndStrideOpInterface candidateB,
358-
CandidateDefOrUse defOrUse) const {
349+
bool compare(OffsetSizeAndStrideOpInterface candidateA,
350+
OffsetSizeAndStrideOpInterface candidateB) const {
359351
// deal with weak order
360352
int cmpResult = -1;
361-
for (auto &fn : candidateProcessFn) {
362-
cmpResult = fn(rewriter, candidateA, candidateB, defOrUse);
363-
if (cmpResult != 0)
364-
break;
365-
}
353+
llvm::any_of(candidateProcessFn, [&cmpResult, &candidateA, &candidateB](
354+
const CandidateSliceComparer &fn) {
355+
cmpResult = fn(candidateA, candidateB);
356+
return cmpResult != 0;
357+
});
366358
return cmpResult == -1;
367359
}
368360
};
@@ -389,6 +381,29 @@ struct CandidateSliceOptions {
389381
}
390382
};
391383

384+
static FailureOr<OffsetSizeAndStrideOpInterface> filterAndSelectCandidate(
385+
RewriterBase &rewriter,
386+
ArrayRef<OffsetSizeAndStrideOpInterface> candidateSliceList,
387+
const CandidateDefOrUse &defOrUse, const CandidateSliceOptions &options) {
388+
SmallVector<OffsetSizeAndStrideOpInterface> validCandidates =
389+
llvm::to_vector(llvm::make_filter_range(
390+
candidateSliceList,
391+
[&rewriter, &options,
392+
&defOrUse](const OffsetSizeAndStrideOpInterface &candidate) {
393+
return succeeded(
394+
options.filterPipeLine.filter(rewriter, candidate, defOrUse));
395+
}));
396+
if (validCandidates.empty())
397+
return failure();
398+
399+
OffsetSizeAndStrideOpInterface bestCandidate = *llvm::min_element(
400+
validCandidates, [&options](OffsetSizeAndStrideOpInterface &candidateA,
401+
OffsetSizeAndStrideOpInterface &candidateB) {
402+
return options.comparerPipeLine.compare(candidateA, candidateB);
403+
});
404+
return bestCandidate;
405+
}
406+
392407
std::optional<scf::SCFFuseProducerOfSliceResult>
393408
tileAndFuseProducerOfOpOperand(RewriterBase &rewriter, OpOperand &operand,
394409
const CandidateSliceOptions &options) {
@@ -412,31 +427,20 @@ tileAndFuseProducerOfOpOperand(RewriterBase &rewriter, OpOperand &operand,
412427
return std::nullopt;
413428

414429
CandidateDefOrUse defOrUse{*realProducer};
415-
// d. Filter out invalid candidates
416-
SmallVector<tensor::ExtractSliceOp> validCandidates =
417-
llvm::to_vector(llvm::make_filter_range(
418-
backwardSlice,
419-
[&rewriter, &options, &defOrUse](tensor::ExtractSliceOp &candidate) {
420-
return succeeded(options.filterPipeLine.filter(
421-
rewriter,
422-
cast<OffsetSizeAndStrideOpInterface>(candidate.getOperation()),
423-
defOrUse));
424-
}));
425-
if (validCandidates.empty())
430+
// d. Filter out invalid candidates and select the best candidate
431+
SmallVector<OffsetSizeAndStrideOpInterface> ossBackwardSlice =
432+
llvm::map_to_vector(backwardSlice,
433+
[](tensor::ExtractSliceOp &extractSlice) {
434+
return cast<OffsetSizeAndStrideOpInterface>(
435+
extractSlice.getOperation());
436+
});
437+
FailureOr<OffsetSizeAndStrideOpInterface> bestCandidate =
438+
filterAndSelectCandidate(rewriter, ossBackwardSlice, defOrUse, options);
439+
if (failed(bestCandidate))
426440
return std::nullopt;
427-
// e. Select best candidates by Cost Model
428-
tensor::ExtractSliceOp bestCandidate = *llvm::min_element(
429-
validCandidates,
430-
[&rewriter, &options, &defOrUse](tensor::ExtractSliceOp &candidateA,
431-
tensor::ExtractSliceOp &candidateB) {
432-
return options.comparerPipeLine.compare(
433-
rewriter,
434-
cast<OffsetSizeAndStrideOpInterface>(candidateA.getOperation()),
435-
cast<OffsetSizeAndStrideOpInterface>(candidateB.getOperation()),
436-
defOrUse);
437-
});
438-
// f. call tiling interface
439-
return scfX::tileAndFuseProducerOfSlice(rewriter, bestCandidate);
441+
442+
// e. call tiling interface
443+
return scfX::tileAndFuseProducerOfSlice(rewriter, *bestCandidate);
440444
}
441445

442446
std::optional<SmallVector<scf::SCFFuseConsumerOfSliceResult>>
@@ -464,28 +468,15 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
464468
continue;
465469

466470
CandidateDefOrUse defOrUse{useOperand};
467-
// d. Filter out invalid candidates
468-
SmallVector<OffsetSizeAndStrideOpInterface> validCandidates =
469-
llvm::to_vector(llvm::make_filter_range(
470-
forwardSlice, [&rewriter, &options, &defOrUse](
471-
const OffsetSizeAndStrideOpInterface &candidate) {
472-
return succeeded(
473-
options.filterPipeLine.filter(rewriter, candidate, defOrUse));
474-
}));
475-
if (validCandidates.empty())
471+
// d. Filter out invalid candidates and select the best candidate
472+
FailureOr<OffsetSizeAndStrideOpInterface> bestCandidate =
473+
filterAndSelectCandidate(rewriter, forwardSlice, defOrUse, options);
474+
if (failed(bestCandidate))
476475
continue;
477476

478-
// e. Select best candidates by Cost Model
479-
OffsetSizeAndStrideOpInterface bestCandidate = *llvm::min_element(
480-
validCandidates, [&rewriter, &options, &defOrUse](
481-
const OffsetSizeAndStrideOpInterface &candidateA,
482-
const OffsetSizeAndStrideOpInterface &candidateB) {
483-
return options.comparerPipeLine.compare(rewriter, candidateA,
484-
candidateB, defOrUse);
485-
});
486-
// f. call tiling interface
477+
// e. call tiling interface
487478
FailureOr<scf::SCFFuseConsumerOfSliceResult> fusedResult =
488-
scfX::tileAndFuseConsumerOfSlice(rewriter, bestCandidate);
479+
scfX::tileAndFuseConsumerOfSlice(rewriter, *bestCandidate);
489480

490481
if (succeeded(fusedResult)) {
491482
fusedResultList.push_back(*fusedResult);
@@ -496,7 +487,7 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
496487
};
497488
SmallVector<LoopLikeOpInterface> outerLoops =
498489
scfX::getOuterNestLoopsWhile(
499-
bestCandidate->getParentOfType<LoopLikeOpInterface>(),
490+
(*bestCandidate)->getParentOfType<LoopLikeOpInterface>(),
500491
whileProducerOutOfLoopBlock);
501492
// g. Manually run cse on region which contains top-level loop of
502493
// candidate slice to avoid conflict with subsequent
@@ -506,11 +497,10 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
506497
{*outerLoops.front()->getParentRegion()});
507498
}
508499
}
509-
if (fusedResultList.empty()) {
500+
if (fusedResultList.empty())
510501
return std::nullopt;
511-
} else {
512-
return fusedResultList;
513-
}
502+
503+
return fusedResultList;
514504
}
515505

516506
/// Target at following general topology:
@@ -527,7 +517,7 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
527517
LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp(
528518
RewriterBase &rewriter, Operation *tiledOp,
529519
const CandidateSliceOptions &options) {
530-
int numTiledOps = 0;
520+
unsigned numTiledOps = 0;
531521
std::deque<Operation *> tiledOpList = {tiledOp};
532522
while (!tiledOpList.empty()) {
533523
tiledOp = tiledOpList.front();
@@ -552,7 +542,7 @@ LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp(
552542
return success(numTiledOps > 1);
553543
}
554544

555-
/// What is single tiled op in loop?
545+
/// What is a self-tiled op, as opposed to other fused ops?
556546
/// E.g.
557547
/// %1 = scf.for(){
558548
/// %2 = scf.for(){
@@ -562,7 +552,7 @@ LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp(
562552
/// yield %5
563553
/// }
564554
/// }
565-
static LogicalResult isSingleTiledOpInLoop(Operation *targetOp) {
555+
static LogicalResult isSelfTiledOp(Operation *targetOp) {
566556
// 0. check tilable
567557
if (!isa<TilingInterface>(targetOp))
568558
return failure();
@@ -694,16 +684,15 @@ static bool defaultTilingOfType(RewriterBase &rewriter, Operation *op,
694684
if (succeeded(tilingResult)) {
695685
rewriter.replaceOp(op, tilingResult->replacements);
696686
return true;
697-
} else {
698-
return false;
699687
}
688+
return false;
700689
}
701690

702691
void iterativeTilingAndFusionUntilExhaustion(
703692
RewriterBase &rewriter, func::FuncOp &f,
704693
const CandidateSliceOptions &sliceOptions, const OpTileSizeMap &tsMap) {
705694
// Collect untiled and tiled ops respectively
706-
llvm::SetVector<Operation *> singleTiledOpInLoop, unTiledOps;
695+
llvm::SetVector<Operation *> selfTiledOp, unTiledOps;
707696

708697
auto collectUnTiledOps = [&f, &unTiledOps]() -> bool {
709698
// Reset
@@ -712,8 +701,7 @@ void iterativeTilingAndFusionUntilExhaustion(
712701
f->walk<WalkOrder::PreOrder>([&unTiledOps](Operation *op) {
713702
if (isa<LoopLikeOpInterface>(op))
714703
return WalkResult::skip();
715-
716-
if (isa<TilingInterface>(op) && !op->use_empty()) {
704+
if (isa<TilingInterface>(op)) {
717705
auto parentLoop = op->getParentOfType<LoopLikeOpInterface>();
718706
if (!parentLoop.getOperation())
719707
unTiledOps.insert(op);
@@ -723,32 +711,32 @@ void iterativeTilingAndFusionUntilExhaustion(
723711
return !unTiledOps.empty();
724712
};
725713

726-
auto collectSingleTiledOpInLoop = [&f, &singleTiledOpInLoop]() -> bool {
714+
auto collectSelfTiledOp = [&f, &selfTiledOp]() -> bool {
727715
// Reset
728-
singleTiledOpInLoop.clear();
716+
selfTiledOp.clear();
729717
// Walk through funcOp
730-
f->walk([&singleTiledOpInLoop](Operation *op) {
718+
f->walk([&selfTiledOp](Operation *op) {
731719
// Target at certain kind of tiled op, such as matmul/conv implemented
732720
// by multiple level of nest loops and candidate slices for better
733721
// utilization of parallelism and memory hierarchy.
734-
if (succeeded(isSingleTiledOpInLoop(op))) {
735-
singleTiledOpInLoop.insert(op);
722+
if (succeeded(isSelfTiledOp(op))) {
723+
selfTiledOp.insert(op);
736724
}
737725
});
738-
return !singleTiledOpInLoop.empty();
726+
return !selfTiledOp.empty();
739727
};
740728

741729
// Iterative tiling and fusion until exhaustion.
742730
while (collectUnTiledOps()) {
743731
// If any tiled op already exists before tiling.
744-
if (collectSingleTiledOpInLoop()) {
732+
if (collectSelfTiledOp()) {
745733
// Sort by topology
746-
mlir::topologicalSort(singleTiledOpInLoop);
734+
mlir::topologicalSort(selfTiledOp);
747735
// Record if any fusion happens
748736
bool changed = false;
749737
// Iteratively fuse in forward and backward fashion.
750-
llvm::for_each(singleTiledOpInLoop, [&rewriter, &sliceOptions,
751-
&changed](Operation *tiledOp) {
738+
llvm::for_each(selfTiledOp, [&rewriter, &sliceOptions,
739+
&changed](Operation *tiledOp) {
752740
changed |= succeeded(iterativelyFuseProducerAndConsumerOfTiledOp(
753741
rewriter, tiledOp, sliceOptions));
754742
});
@@ -774,7 +762,7 @@ void iterativeTilingAndFusionUntilExhaustion(
774762
std::cref(tsMap)));
775763
})) {
776764
// If no op can be tiled
777-
break;
765+
return;
778766
}
779767
}
780768
}

0 commit comments

Comments
 (0)