@@ -832,12 +832,12 @@ getUntiledProducerFromSliceSource(OpOperand *source,
832
832
return {dyn_cast<OpResult>(source->get ()), destinationIterArg};
833
833
}
834
834
835
- // / Implementation of fusing producer of a single slice by computing the
835
+ // / Basic implementation of fusing producer of a single slice by computing the
836
836
// / slice of the producer in-place.
837
- std::optional<scf::SCFFuseProducerOfSliceResult>
838
- mlir::scf::tileAndFuseProducerOfSlice (
839
- RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
840
- MutableArrayRef<LoopLikeOpInterface> loops) {
837
+ static std::optional<scf::SCFFuseProducerOfSliceResult>
838
+ tileAndFuseProducerOfSliceImpl (RewriterBase &rewriter,
839
+ tensor::ExtractSliceOp candidateSliceOp,
840
+ MutableArrayRef<LoopLikeOpInterface> loops) {
841
841
// 1. Get the producer of the source (potentially walking through
842
842
// `iter_args` of nested `scf.for`)
843
843
auto [fusableProducer, destinationInitArg] =
@@ -949,6 +949,139 @@ mlir::scf::tileAndFuseProducerOfSlice(
949
949
tileAndFuseResult->tiledOps };
950
950
}
951
951
952
+ // / Get the Root source of target ExtractSliceOp
953
+ // / %0 =
954
+ // / %1 = scf.for(%arg1 = %0)
955
+ // / %2 = extract %arg1
956
+ // / %3 = scf.for(%arg2 = %2)
957
+ // / %4 = extract %args2
958
+ // / ...
959
+ // / @param targetSliceOp: %4 = extract %args2
960
+ // / @param extractSliceOpChain: chain of all related extract sliceOp
961
+ // / @return Value of Root Source : %0
962
+ static FailureOr<Value> getRootSourceOfExtractSliceOp (
963
+ Operation *targetSliceOp,
964
+ SmallVectorImpl<tensor::ExtractSliceOp> &extractSliceOpChain,
965
+ int curDepth = 0 , int maxDepth = 5 ) {
966
+ assert (isa<tensor::ExtractSliceOp>(targetSliceOp));
967
+ // control recursive time in avoid of stack overflow
968
+ if (curDepth > maxDepth)
969
+ return failure ();
970
+
971
+ auto extractOp = cast<tensor::ExtractSliceOp>(targetSliceOp);
972
+ extractSliceOpChain.push_back (extractOp);
973
+ Value rootSource = extractOp.getSourceMutable ().get ();
974
+
975
+ while (true ) {
976
+ if (auto iterArg = dyn_cast<BlockArgument>(rootSource)) {
977
+ if (auto outerLoop = dyn_cast<LoopLikeOpInterface>(
978
+ iterArg.getOwner ()->getParentOp ())) {
979
+ rootSource = outerLoop.getTiedLoopInit (iterArg)->get ();
980
+ continue ;
981
+ }
982
+ return failure ();
983
+ } else if (auto sliceOp =
984
+ rootSource.getDefiningOp <tensor::ExtractSliceOp>()) {
985
+ // walk up loop to find larger candidate extractSliceOp
986
+ return getRootSourceOfExtractSliceOp (sliceOp, extractSliceOpChain,
987
+ curDepth + 1 );
988
+ }
989
+ break ;
990
+ }
991
+ return rootSource;
992
+ }
993
+
994
+ // / Recursively find the outer nest loops of given loop(included) while the
995
+ // / predict function succeed, sorted from outer to inner.
996
+ // /
997
+ // / @param loop: target loop, note that this loop will be also included. I.e.
998
+ // / if no other nest loops were found, just return itself.
999
+ // / @param pred: predict function, the termination condition of recursive
1000
+ // / process.
1001
+ // / @return Outer Nest Loops: nest loops outside given target loop(included).
1002
+ // /
1003
+ // / E.g.
1004
+ // /
1005
+ // / ```
1006
+ // / %0 = scf.for()
1007
+ // / %1 = scf.for()
1008
+ // / %2 = scf.for()
1009
+ // / ```
1010
+ // /
1011
+ // / If `%2 = scf.for` is given without specific prediction function, this
1012
+ // / function will return three nest loops: %0 + %1 + %2.
1013
+ static SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile (
1014
+ LoopLikeOpInterface loop,
1015
+ const std::function<LogicalResult(LoopLikeOpInterface)> &pred) {
1016
+ SmallVector<LoopLikeOpInterface> nestLoops = {loop};
1017
+ auto outerLoop = dyn_cast<LoopLikeOpInterface>(loop->getParentOp ());
1018
+ while (outerLoop && succeeded (pred (outerLoop))) {
1019
+ nestLoops.push_back (outerLoop);
1020
+ outerLoop = dyn_cast<LoopLikeOpInterface>(outerLoop->getParentOp ());
1021
+ }
1022
+ // sorted from outer to inner
1023
+ return {nestLoops.rbegin (), nestLoops.rend ()};
1024
+ }
1025
+
1026
+ // / Enhanced version for basic implementation of fusing producer, which can deal
1027
+ // / with multi-level candidates. E.g.
1028
+ // /
1029
+ // / ```
1030
+ // / %0 = untiled_producer
1031
+ // / %1 = scf.for(%arg1 = %0)
1032
+ // / %2 = tensor.extract_slice %arg1
1033
+ // / %3 = scf.for(%arg2 = %2)
1034
+ // / %4 = tensor.extract_slice %args2
1035
+ // / %5 = tiled_consumer ins(%4)
1036
+ // / ```
1037
+ // /
1038
+ // / This utility can fuse untiled producer at `%4 = tensor.extract_slice` within
1039
+ // / inner loop `%3 = scf.for`.
1040
+ std::optional<scf::SCFFuseProducerOfSliceResult>
1041
+ mlir::scf::tileAndFuseProducerOfSlice (RewriterBase &rewriter,
1042
+ Operation *candidateSliceOp) {
1043
+ SmallVector<tensor::ExtractSliceOp> sliceOpChain;
1044
+ if (failed (getRootSourceOfExtractSliceOp (candidateSliceOp, sliceOpChain))) {
1045
+ return std::nullopt;
1046
+ }
1047
+
1048
+ std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResult;
1049
+ // reverse from outer to inner
1050
+ std::reverse (sliceOpChain.begin (), sliceOpChain.end ());
1051
+ // multiple application of `tileAndFuseProducerOfSliceImpl`
1052
+ for (auto &&[index, sliceOp] : llvm::enumerate (sliceOpChain)) {
1053
+ // get nest loops between next candidate sliceOp and tiled producer.
1054
+ auto whileProducerOutOfLoopBlock =
1055
+ [&fuseProducerResult](LoopLikeOpInterface loop) -> LogicalResult {
1056
+ if (fuseProducerResult) {
1057
+ Block &body = loop->getRegion (0 ).front ();
1058
+ if (fuseProducerResult->tiledAndFusedProducer .getDefiningOp ()
1059
+ ->getBlock () == &body)
1060
+ return failure ();
1061
+ }
1062
+ return success ();
1063
+ };
1064
+ SmallVector<LoopLikeOpInterface> outerLoops =
1065
+ getOuterNestLoopsWhile (sliceOp->getParentOfType <LoopLikeOpInterface>(),
1066
+ whileProducerOutOfLoopBlock);
1067
+ fuseProducerResult =
1068
+ tileAndFuseProducerOfSliceImpl (rewriter, sliceOp, outerLoops);
1069
+ if (!fuseProducerResult) {
1070
+ return std::nullopt;
1071
+ }
1072
+ }
1073
+ return fuseProducerResult;
1074
+ }
1075
+
1076
+ // / Implementation of fusing producer of a single slice by computing the
1077
+ // / slice of the producer in-place.
1078
+ std::optional<scf::SCFFuseProducerOfSliceResult>
1079
+ mlir::scf::tileAndFuseProducerOfSlice (
1080
+ RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1081
+ MutableArrayRef<LoopLikeOpInterface> loops) {
1082
+ return tileAndFuseProducerOfSliceImpl (rewriter, candidateSliceOp, loops);
1083
+ }
1084
+
952
1085
// / Reconstruct the fused producer from within the tiled-and-fused code.
953
1086
LogicalResult mlir::scf::yieldReplacementForFusedProducer (
954
1087
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
0 commit comments