Skip to content

Commit d130996

Browse files
committed
support fuse consumer into innerMost ConsumerAnchor
1 parent 5cc9bca commit d130996

File tree

3 files changed

+418
-377
lines changed

3 files changed

+418
-377
lines changed

lib/gc/Transforms/AnyTilableFusion.cpp

Lines changed: 17 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -39,26 +39,13 @@ namespace gc {
3939
#define GEN_PASS_DEF_ANYTILABLEFUSION
4040
#include "gc/Transforms/Passes.h.inc"
4141

42-
namespace {
43-
4442
struct SystemDesc {
4543
// get runtime OMP_NUM_THREADS
4644
uint32_t getNumThreads();
4745
// get cache size by cacheLevel
4846
size_t getCacheSize(uint8_t cacheLevel);
4947
};
5048

51-
SmallVector<LoopLikeOpInterface> static getOuterLoopsOfSliceOp(
52-
OffsetSizeAndStrideOpInterface sliceOp) {
53-
SmallVector<LoopLikeOpInterface> outerLoops;
54-
auto forOp = sliceOp->getParentOfType<LoopLikeOpInterface>();
55-
while (forOp) {
56-
outerLoops.push_back(forOp);
57-
forOp = forOp->getParentOfType<LoopLikeOpInterface>();
58-
}
59-
return {outerLoops.rbegin(), outerLoops.rend()};
60-
}
61-
6249
template <typename T> class FusionAnchorBase {
6350
static_assert(
6451
llvm::is_one_of<T, tensor::ExtractSliceOp, tensor::InsertSliceOp,
@@ -139,15 +126,15 @@ verifyTilableOpTileSizesOnDimAndTileMap(RewriterBase &rewriter, Operation *op,
139126
targetTileSizes =
140127
llvm::to_vector(tileSizes.take_back(targetInnerTileSizes.size()));
141128
} else // tileSize comes from OpOperand
142-
if (std::is_same<T, OffsetSizeAndStrideOpInterface>::value) {
143-
targetTileSizes = llvm::to_vector(tileSizes);
144-
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
145-
packOp.getDimAndTileMapping();
146-
targetInnerTileSizes.resize(dimAndTileMapping.size());
147-
for (const auto &dimAndTile : dimAndTileMapping) {
148-
targetInnerTileSizes[dimAndTile.first] = dimAndTile.second;
129+
if (std::is_same<T, OffsetSizeAndStrideOpInterface>::value) {
130+
targetTileSizes = llvm::to_vector(tileSizes);
131+
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
132+
packOp.getDimAndTileMapping();
133+
targetInnerTileSizes.resize(dimAndTileMapping.size());
134+
for (const auto &dimAndTile : dimAndTileMapping) {
135+
targetInnerTileSizes[dimAndTile.first] = dimAndTile.second;
136+
}
149137
}
150-
}
151138
} else if (auto unPackOp = dyn_cast<tensor::UnPackOp>(op)) {
152139
// tileSize comes from OpResult
153140
if (std::is_same<T, tensor::ExtractSliceOp>::value) {
@@ -159,11 +146,11 @@ verifyTilableOpTileSizesOnDimAndTileMap(RewriterBase &rewriter, Operation *op,
159146
targetInnerTileSizes[dimAndTile.first] = dimAndTile.second;
160147
}
161148
} else // tileSize comes from OpOperand
162-
if (std::is_same<T, OffsetSizeAndStrideOpInterface>::value) {
163-
targetInnerTileSizes = unPackOp.getInnerTiles();
164-
targetTileSizes =
165-
llvm::to_vector(tileSizes.take_back(targetInnerTileSizes.size()));
166-
}
149+
if (std::is_same<T, OffsetSizeAndStrideOpInterface>::value) {
150+
targetInnerTileSizes = unPackOp.getInnerTiles();
151+
targetTileSizes =
152+
llvm::to_vector(tileSizes.take_back(targetInnerTileSizes.size()));
153+
}
167154
}
168155

169156
// check that tileSizes is either full or a multiple of `inner_tile_size`
@@ -205,11 +192,7 @@ FusionAnchorBase<T>::selectCandidateByCostModel(RewriterBase &rewriter,
205192
if (candidateSliceOpList.empty())
206193
return failure();
207194
/// TODO: use cost model
208-
if (std::is_same<T, tensor::ExtractSliceOp>::value) {
209-
return cast<T>(candidateSliceOpList.front());
210-
} else {
211-
return cast<T>(candidateSliceOpList.back());
212-
}
195+
return cast<T>(candidateSliceOpList.front());
213196
}
214197

215198
// Target at tensor.extract_slice
@@ -361,98 +344,6 @@ getProducerFusionAnchorFromOpOperand(RewriterBase &rewriter,
361344
return failure();
362345
}
363346

364-
/** Get the Result of top-level Loop which yield the target InsertSliceOp
365-
*
366-
* %1 = scf.for
367-
* %2 = scf.for
368-
* %3 = scf.for
369-
* ...
370-
* %4 = insert
371-
* yield %4
372-
* %5 = insert %3
373-
* yield %5
374-
* yield %2
375-
*
376-
* @param targetSliceOp: %4 = insert
377-
* @return Result Value: %1
378-
* Collected insertSliceOp List during walk including targetSliceOp:
379-
* %4 = insert and %5 = insert %3
380-
*/
381-
FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
382-
getResultOfTopLevelLoopYieldInsertSliceOp(
383-
OffsetSizeAndStrideOpInterface targetSliceOp, int curDepth = 0) {
384-
// control recursive time in avoid of stack overflow
385-
if (curDepth > MAX_DEPTH)
386-
return failure();
387-
388-
SmallVector<OffsetSizeAndStrideOpInterface> candidateSliceOpList;
389-
candidateSliceOpList.push_back(targetSliceOp);
390-
Value resultOfLoop;
391-
if (auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(
392-
targetSliceOp.getOperation())) {
393-
Value destValue = sliceOp.getDest();
394-
auto iterArg = cast<BlockArgument>(destValue);
395-
auto forallOp = dyn_cast<scf::ForallOp>(iterArg.getOwner()->getParentOp());
396-
if (!forallOp)
397-
return failure();
398-
resultOfLoop = forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
399-
} else if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(
400-
targetSliceOp.getOperation())) {
401-
Value resultValue = sliceOp.getResult();
402-
for (auto &useOperand : resultValue.getUses()) {
403-
if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
404-
if (llvm::detail::isPresent(resultOfLoop))
405-
return failure();
406-
auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
407-
if (!forOp)
408-
return failure();
409-
resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
410-
}
411-
}
412-
}
413-
414-
if (!llvm::detail::isPresent(resultOfLoop))
415-
return failure();
416-
417-
while (true) {
418-
bool walkThroughOuterLoop = false;
419-
for (auto &useOperand : resultOfLoop.getUses()) {
420-
if (auto sliceOp =
421-
dyn_cast<tensor::ParallelInsertSliceOp>(useOperand.getOwner())) {
422-
auto resultAndSliceOpsPair =
423-
getResultOfTopLevelLoopYieldInsertSliceOp(sliceOp, curDepth + 1);
424-
if (failed(resultAndSliceOpsPair))
425-
return failure();
426-
candidateSliceOpList.append((*resultAndSliceOpsPair).second.begin(),
427-
(*resultAndSliceOpsPair).second.end());
428-
return std::make_pair((*resultAndSliceOpsPair).first,
429-
candidateSliceOpList);
430-
} else if (auto sliceOp =
431-
dyn_cast<tensor::InsertSliceOp>(useOperand.getOwner())) {
432-
auto resultAndSliceOpsPair =
433-
getResultOfTopLevelLoopYieldInsertSliceOp(sliceOp, curDepth + 1);
434-
if (failed(resultAndSliceOpsPair))
435-
return failure();
436-
candidateSliceOpList.append((*resultAndSliceOpsPair).second.begin(),
437-
(*resultAndSliceOpsPair).second.end());
438-
return std::make_pair((*resultAndSliceOpsPair).first,
439-
candidateSliceOpList);
440-
} else if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
441-
// walk through outer loop
442-
auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
443-
if (!forOp)
444-
return failure();
445-
resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
446-
walkThroughOuterLoop = true;
447-
break;
448-
}
449-
}
450-
if (!walkThroughOuterLoop)
451-
break;
452-
}
453-
return std::make_pair(resultOfLoop, candidateSliceOpList);
454-
}
455-
456347
/**
457348
* Find the untiled Consumer op based on given OpResult of Tiled Op, E.g.
458349
*
@@ -493,7 +384,7 @@ getConsumerFusionAnchorFromOpResult(RewriterBase &rewriter,
493384
return failure();
494385

495386
auto resultAndSliceOpsPair =
496-
getResultOfTopLevelLoopYieldInsertSliceOp(sliceOp);
387+
scfX::getResultOfTopLevelLoopYieldInsertSliceOp(sliceOp);
497388
if (failed(resultAndSliceOpsPair))
498389
return failure();
499390

@@ -530,7 +421,7 @@ static Operation *preOpFuseProducerOfOpOperand(
530421
if (failed(candidateSliceOp)) {
531422
return nullptr;
532423
}
533-
auto outerLoops = getOuterLoopsOfSliceOp(*candidateSliceOp);
424+
auto outerLoops = scfX::getOuterLoopsOfSliceOp(*candidateSliceOp);
534425
std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
535426
scfX::tileAndFuseProducerOfSlice(rewriter, *candidateSliceOp, outerLoops);
536427

@@ -562,7 +453,7 @@ static SmallVector<Operation *> postOpFuseConsumerOfOpResult(
562453
std::optional<scf::SCFFuseConsumerOfSliceResult> fusedResult =
563454
scfX::tileAndFuseConsumerOfSlice(rewriter, *candidateSliceOp);
564455
if (fusedResult) {
565-
tiledConsumerList.push_back(fusedResult.value().tiledAndFusedConsumer);
456+
tiledConsumerList.push_back(fusedResult.value().tiledOps[0]);
566457
rewriter.eraseOp(consAnchor.getFusableOp());
567458
}
568459
}
@@ -712,6 +603,5 @@ struct AnyTilableFusion : public impl::AnyTilableFusionBase<AnyTilableFusion> {
712603
}
713604
};
714605

715-
} // namespace
716606
} // namespace gc
717607
} // namespace mlir

0 commit comments

Comments
 (0)