Skip to content

[DO NOT REVIEW][mlir][scf] Extend consumer fuse to nested loop structure #94460

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed

Conversation

Yun-Fly
Copy link
Contributor

@Yun-Fly Yun-Fly commented Jun 5, 2024

The CI failure of #94190 could not be reproduced after several attempts on local, just create this mirror PR particularly for debugging purpose.

Please DO NOT REVIEW and ignore all changes for this.

Once root cause is figured out, I will close this PR.

Sorry for disturb again. Thanks.

Copy link

github-actions bot commented Jun 5, 2024

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be
notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write
permissions for the repository. In which case you can instead tag reviewers by
name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review
by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate
is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Copy link
Member

llvmbot commented Jun 5, 2024

@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir

Author: None (Yun-Fly)

Changes

The CI failure of #94190 could not be reproduced after several attempts on local, just create this mirror PR particularly for debugging purpose.

Please DO NOT REVIEW and ignore all changes for this.

Once root cause is figured out, I will close this PR.

Sorry for disturb again. Thanks.


Patch is 47.42 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/94460.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+517-221)
  • (modified) mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir (+96)
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index f3d6b7a530117..488a1c6585790 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -1103,98 +1104,6 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
 // tileAndFuseConsumerUsingSCF implementation.
 //===----------------------------------------------------------------------===//
 
-/// A utility function that checks whether the only use of the result of a
-/// tensor.insert_slice op is in a scf.yield op.
-static LogicalResult
-checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
-  Value result = candidateSliceOp.getResult();
-  Value::use_range uses = result.getUses();
-  if (!llvm::hasSingleElement(uses)) {
-    LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
-    return failure();
-  }
-  OpOperand &operandUse = (*uses.begin());
-  Operation *userOp = operandUse.getOwner();
-  if (!isa<scf::YieldOp>(userOp)) {
-    LLVM_DEBUG(llvm::dbgs()
-               << "Expected scf.yield to be the only user, but got -> "
-               << (*userOp));
-    return failure();
-  }
-  if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
-    LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
-                               "be in the same block\n");
-    return failure();
-  }
-  return success();
-}
-
-/// Fetches the OpOperand of the only user (and use) of the value `val` which
-/// implements `TilingInterface` and `DestinationStyleOpInterface`. Returns
-/// failure otherwise.
-static FailureOr<OpOperand *> getConsumerFromUses(Value val,
-                                                  Block *containingOpBlock) {
-  // Step 1. Check that the value has exactly one use.
-  if (!llvm::hasSingleElement(val.getUses()))
-    return failure();
-  // Step 2. Get uses.
-  OpOperand &operand = (*val.getUses().begin());
-  Operation *consumerOp = operand.getOwner();
-  // TODO: We have to init result of consumer before scf.for, use
-  //       DestinationStyleOpInterface to get result shape from init for now.
-  //       Add support for other op such as op has InferTypeOpInterface.
-  if (!isa<TilingInterface>(consumerOp) ||
-      !isa<DestinationStyleOpInterface>(consumerOp))
-    return failure();
-  if (containingOpBlock != consumerOp->getBlock())
-    return failure();
-  return &operand;
-}
-
-/// Fetch the untiled consumer of a scf.for's result which is yielded by a
-/// tensor.insert_slice. This function makes the following assumptions :
-/// 1.  tensor.insert_slice has scf.yield as its only user.
-/// 2.  scf.for's corresponding result has only one use.
-static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
-  if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
-    return failure();
-  Value sliceResult = candidateSliceOp.getResult();
-  // Step 1. Fetch the corresponding output.
-  OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
-  unsigned resultNumber = yieldOpOperand.getOperandNumber();
-  // Step 2. Check containing op is scf.for.
-  Operation *containingOp = candidateSliceOp->getParentOp();
-  auto forOp = dyn_cast<scf::ForOp>(containingOp);
-  if (!forOp)
-    return failure();
-  Value resultingValue = forOp->getResult(resultNumber);
-
-  return getConsumerFromUses(resultingValue, containingOp->getBlock());
-}
-
-/// Fetch the first untiled consumer of a scf.forall's result which is yielded
-/// by a tensor.parallel_insert_slice.
-static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
-  // Step 1. Fetch the corresponding output
-  Value sliceDest = candidateSliceOp.getDest();
-  auto iterArg = dyn_cast<BlockArgument>(sliceDest);
-  if (!iterArg)
-    return failure();
-  Operation *containingOp = iterArg.getOwner()->getParentOp();
-  if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
-    return failure();
-  // Step 2. Check that the containing op is scf.forall.
-  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
-  if (!forallOp)
-    return failure();
-  Value resultingValue =
-      forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
-
-  return getConsumerFromUses(resultingValue, containingOp->getBlock());
-}
-
 /// This utility currently checks whether the loop either :-
 /// 1. Yields exactly one result.
 /// 2. Has consumer op as its first user and other users to be in the same
@@ -1220,31 +1129,117 @@ static LogicalResult checkAssumptionForLoop(Operation *loopOp,
   return success();
 }
 
-/// A utility to fetch an untiled consumer of
-/// tensor.insert_slice/tensor.parallel_insert_slice.
-static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
-  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
-    return getUntiledConsumerFromSlice(insertSlice);
-  } else if (auto parallelInsertSlice =
-                 dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
-    return getUntiledConsumerFromSlice(parallelInsertSlice);
-  } else {
+/// Traverse and collect all outer loops of given sliceOp, sorted by
+/// outer-to-inner. If `untilLoop` found, stop walk through in advance.
+static SmallVector<LoopLikeOpInterface> getOuterLoopsOfSliceOp(
+    Operation *sliceOp,
+    std::optional<LoopLikeOpInterface> untilLoop = std::nullopt) {
+  assert(isa<OffsetSizeAndStrideOpInterface>(sliceOp));
+  SmallVector<LoopLikeOpInterface> outerLoops;
+  auto forOp = sliceOp->getParentOfType<LoopLikeOpInterface>();
+  while (forOp) {
+    outerLoops.push_back(forOp);
+    if (untilLoop.has_value() && *untilLoop == forOp)
+      break;
+    forOp = forOp->getParentOfType<LoopLikeOpInterface>();
+  }
+  return {outerLoops.rbegin(), outerLoops.rend()};
+}
+
+/// Get the result of top-level loop which yields the target InsertSliceOp. E.g
+/// ```
+/// %1 = scf.for
+///  %2 = scf.for
+///   %3 = scf.for
+///      ...
+///      %4 = insert
+///      yield %4
+///   %5 = insert %3
+///   yield %5
+///  yield %2
+/// ```
+/// @param targetSliceOp: %4 = insert
+/// @return first: resultValue: %1
+///         second: Collected insertSliceOp List during walk including
+///                   targetSliceOp: %4 = insert and %5 = insert %3
+static FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
+getResultOfTopLevelLoopYieldInsertSliceOp(Operation *targetSliceOp,
+                                          int curDepth = 0, int maxDepth = 5) {
+  assert(isa<OffsetSizeAndStrideOpInterface>(targetSliceOp));
+  // Control recursive time in avoid of stack overflow
+  if (curDepth > maxDepth)
+    return failure();
+
+  SmallVector<OffsetSizeAndStrideOpInterface> candidateSliceOpList;
+  candidateSliceOpList.push_back(
+      cast<OffsetSizeAndStrideOpInterface>(targetSliceOp));
+  Value resultOfLoop;
+  if (auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(targetSliceOp)) {
+    Value destValue = sliceOp.getDest();
+    auto iterArg = cast<BlockArgument>(destValue);
+    auto forallOp = dyn_cast<scf::ForallOp>(iterArg.getOwner()->getParentOp());
+    if (!forallOp)
+      return failure();
+    resultOfLoop = forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
+  } else if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(targetSliceOp)) {
+    Value resultValue = sliceOp.getResult();
+    for (auto &useOperand : resultValue.getUses()) {
+      if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
+        if (llvm::detail::isPresent(resultOfLoop))
+          return failure();
+        auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
+        if (!forOp)
+          return failure();
+        resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
+      }
+    }
+  }
+
+  if (!llvm::detail::isPresent(resultOfLoop))
     return failure();
+
+  while (true) {
+    bool walkThroughOuterLoop = false;
+    for (OpOperand &useOperand : resultOfLoop.getUses()) {
+      if (auto sliceOp =
+              dyn_cast<OffsetSizeAndStrideOpInterface>(useOperand.getOwner())) {
+        FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
+            resultAndSliceOpsPair = getResultOfTopLevelLoopYieldInsertSliceOp(
+                sliceOp, curDepth + 1);
+        if (failed(resultAndSliceOpsPair))
+          return failure();
+        candidateSliceOpList.append((*resultAndSliceOpsPair).second.begin(),
+                                    (*resultAndSliceOpsPair).second.end());
+        return std::make_pair((*resultAndSliceOpsPair).first,
+                              candidateSliceOpList);
+      } else if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
+        // walk through outer loop
+        auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
+        if (!forOp)
+          return failure();
+        resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
+        walkThroughOuterLoop = true;
+        break;
+      }
+    }
+    if (!walkThroughOuterLoop)
+      break;
   }
+  return std::make_pair(resultOfLoop, candidateSliceOpList);
 }
 
 /// After fusing consumer into scf.for we want to modify the scf.yield operation
 /// to reflect the same by returning the values yielded by the tiled consumer.
 static void
 fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
-                      TilingResult &tilingResult,
-                      ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
-                      ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
+                      ResultRange tilingResult,
+                      ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
+                      ArrayRef<SmallVector<OpFoldResult>> resultSizes,
                       ArrayRef<BlockArgument> bbArgs) {
   scf::YieldOp oldTerminatorOp =
       cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
   unsigned totalOldResults = oldTerminatorOp->getNumResults();
-  unsigned totalTiledResults = tilingResult.tiledOps[0]->getNumResults();
+  unsigned totalTiledResults = tilingResult.size();
   SmallVector<Value> newYieldOperands;
   newYieldOperands.reserve(totalOldResults + totalTiledResults);
   for (auto oldResult : oldTerminatorOp.getResults()) {
@@ -1253,8 +1248,7 @@ fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
   rewriter.setInsertionPointAfter(oldTerminatorOp);
   Location loc = newForOp.getLoc();
   for (auto [tiledResult, bbArg, resultOffset, resultSize] :
-       llvm::zip_equal(tilingResult.tiledOps[0]->getResults(), bbArgs,
-                       resultOffsets, resultSizes)) {
+       llvm::zip_equal(tilingResult, bbArgs, resultOffsets, resultSizes)) {
     SmallVector<OpFoldResult> strides(resultOffset.size(),
                                       rewriter.getIndexAttr(1));
     Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
@@ -1269,16 +1263,16 @@ fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
 /// values by the tiled consumer within scf.forall.in_parallel region.
 static void
 fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
-                           SmallVector<Value> tiledResults,
-                           ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
-                           ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
+                           ResultRange tilingResult,
+                           ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
+                           ArrayRef<SmallVector<OpFoldResult>> resultSizes,
                            ArrayRef<BlockArgument> bbArgs) {
   scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
   rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
   Location firstYieldOpLoc =
       (*(newTerminatorOp.getYieldingOps().begin())).getLoc();
   for (auto [tiledResult, bbArg, resultOffset, resultSize] :
-       llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) {
+       llvm::zip_equal(tilingResult, bbArgs, resultOffsets, resultSizes)) {
     SmallVector<OpFoldResult> strides(resultOffset.size(),
                                       rewriter.getIndexAttr(1));
     rewriter.create<tensor::ParallelInsertSliceOp>(
@@ -1286,6 +1280,176 @@ fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
   }
 }
 
+/// If the top level loop of nested loop structure is scf.forall, need to create
+/// additional tensor.extract_slice for its new appended `shared_outs` in order
+/// to pass correct local memory for inner loops. E.g.
+///
+/// scf.forall shared_outs(%o1=..., %o2=...) {
+///     %local_o1 = extract_slice %o1
+///     // fix new appended `shared_out` %o2
+///     %local_o2 = extract_slice %o2
+///     scf.for init_args(%init1=%local_o1, %init2=%local_o2) {
+///        ...
+///     }
+///     ...
+/// }
+static SmallVector<tensor::ExtractSliceOp> fixLoopInitFromSharedOutSCFForall(
+    RewriterBase &rewriter, Operation *loop, ValueRange newSharedOuts,
+    ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
+    ArrayRef<SmallVector<OpFoldResult>> resultSizes) {
+  rewriter.setInsertionPoint(loop);
+  Location loc = loop->getLoc();
+  // create new ExtractOps for newInits from scf.forall
+  SmallVector<tensor::ExtractSliceOp> newExtractOps;
+  newExtractOps.reserve(resultOffsets.size());
+  for (auto [bbArg, offset, sizes] :
+       llvm::zip_equal(newSharedOuts, resultOffsets, resultSizes)) {
+    SmallVector<OpFoldResult> strides(offset.size(), rewriter.getIndexAttr(1));
+    auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
+        loc, bbArg, offset, sizes, strides);
+    newExtractOps.push_back(newExtractOp);
+  }
+  return newExtractOps;
+}
+
+/// If outerMost loop of nested loop structure is `scf.forall`, need to deal
+/// with DpsInit of tiled consumer
+static void
+fixDpsInitsOfTiledConsumer(RewriterBase &rewriter, Operation *tiledConsumer,
+                           ArrayRef<BlockArgument> bbArgs,
+                           ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
+                           ArrayRef<SmallVector<OpFoldResult>> resultSizes) {
+  rewriter.setInsertionPoint(tiledConsumer);
+  Location loc = tiledConsumer->getLoc();
+  for (auto &&[bbArg, offset, sizes, dpsInit] :
+       llvm::zip_equal(bbArgs, resultOffsets, resultSizes,
+                       cast<DestinationStyleOpInterface>(tiledConsumer)
+                           .getDpsInitsMutable())) {
+    SmallVector<OpFoldResult> strides(offset.size(), rewriter.getIndexAttr(1));
+    auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
+        loc, bbArg, offset, sizes, strides);
+    dpsInit.set(newExtractOp.getResult());
+  }
+}
+
+/// Compute all results tile by given SliceOp along operand
+static LogicalResult computeAllResultTileForOpGivenOperandSliceOp(
+    RewriterBase &rewriter, TilingInterface tilableOp, unsigned operandNumber,
+    OffsetSizeAndStrideOpInterface ossSliceOp,
+    SmallVector<SmallVector<OpFoldResult>> &allResultOffsets,
+    SmallVector<SmallVector<OpFoldResult>> &allResultSizes) {
+  // 1. Check all stride all 1
+  if (llvm::any_of(ossSliceOp.getMixedStrides(), [](OpFoldResult stride) {
+        return !isConstantIntValue(stride, 1);
+      })) {
+    return rewriter.notifyMatchFailure(ossSliceOp, "ossSliceOp has stride");
+  }
+  // 2. Compute iteration domain tile from the input position
+  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+  if (failed(tilableOp.getIterationDomainTileFromOperandTile(
+          rewriter, operandNumber, ossSliceOp.getMixedOffsets(),
+          ossSliceOp.getMixedSizes(), iterDomainOffsets, iterDomainSizes))) {
+    return rewriter.notifyMatchFailure(
+        tilableOp, "can't get iter domain position from input position");
+  }
+  unsigned totalNumResultsOfConsumer = tilableOp->getNumResults();
+  SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+      totalNumResultsOfConsumer);
+  SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
+  // 3. Compute result Tile by resultNumber
+  for (auto [idx, v] : llvm::enumerate(tilableOp->getResults())) {
+    if (failed(tilableOp.getResultTilePosition(
+            rewriter, idx, iterDomainOffsets, iterDomainSizes,
+            resultOffsets[idx], resultSizes[idx]))) {
+      return rewriter.notifyMatchFailure(
+          tilableOp,
+          "can't get result domain position from iter domain position");
+    }
+  }
+  allResultOffsets = resultOffsets;
+  allResultSizes = resultSizes;
+  return success();
+}
+
+/// Considering multi-level tensor.*SliceOp maybe based on different
+/// coordination, this utility computes the real OFFSET coordinated on ROOT
+/// SliceOp. E.g
+///             %0 = insert_slice %1 into %2[OFFSET1] [SIZE1]
+///         %3 = insert_slice %4 into %5[OFFSET2] [SIZE2]
+///
+/// where the coordination can be illustrated as follow:
+///
+///  %3 ----------------------------------
+///  |         |         |
+///  | OFFSET2 | OFFSET1 |
+///  | ------ %0         |
+///  |                   |
+///  |                   |
+///  |------------------ %1 ------ |
+///  |                   |  SIZE1  |
+///  |                   |         |
+///  |                   |         |
+///  |                   | ------- |
+///  |
+///
+/// The real OFFSET of %1 coordinated on %3 is actually `OFFSET1` + `OFFSET2`
+static FailureOr<SmallVector<OpFoldResult>>
+computeRealOffsetsCoordinatedRootSliceOp(
+    RewriterBase &rewriter, Location loc,
+    OffsetSizeAndStrideOpInterface candidateSliceOp,
+    MutableArrayRef<OffsetSizeAndStrideOpInterface> candidateSliceOpList) {
+  if (llvm::any_of(candidateSliceOp.getMixedStrides(), [](OpFoldResult stride) {
+        return !isConstantIntValue(stride, 1);
+      })) {
+    return rewriter.notifyMatchFailure(candidateSliceOp,
+                                       "candidateSliceOp has stride");
+  }
+  SmallVector<OpFoldResult> realOffsets = candidateSliceOp.getMixedOffsets();
+  // Real offsets equals to accumulative offsets of outer candidates
+  for (auto iter = candidateSliceOpList.rbegin(); *iter != candidateSliceOp;
+       iter++) {
+    // assert each outer candidate slice has no stride
+    if (llvm::any_of(iter->getMixedStrides(), [](OpFoldResult stride) {
+          return !isConstantIntValue(stride, 1);
+        })) {
+      return failure();
+    }
+    for (auto &&[ofr1, ofr2] :
+         llvm::zip_equal(realOffsets, iter->getMixedOffsets())) {
+      using AVE = affine::AffineValueExpr;
+      affine::AffineBuilder ab(rewriter, loc);
+      AffineExpr dim0, dim1, sym;
+      bindDims(rewriter.getContext(), dim0, dim1);
+      bindSymbols(rewriter.getContext(), sym);
+      auto aveOffset1 = AVE(dim0).bind(ofr1), aveOffset2 = AVE(dim1).bind(ofr2);
+      ofr1 = ab.add(aveOffset1, aveOffset2);
+    }
+  }
+  return realOffsets;
+}
+
+/// Get the first tilable user of given Value and check its domination at the
+/// same time
+static FailureOr<OpOperand *>
+getTilableConsumerOperandFirstUseVal(Value val, Operation *loopOp) {
+  for (auto &useOfval : val.getUses()) {
+    Operation *consumerOp = useOfval.getOwner();
+    // 1. Check whether consumerOp is tilable
+    if (!isa<TilingInterface>(consumerOp) ||
+        !isa<DestinationStyleOpInterface>(consumerOp))
+      continue;
+    // 2. Check stay in same block with loopOp
+    if (loopOp->getBlock() != consumerOp->getBlock())
+      continue;
+    // 3. Check no other user before it
+    if (failed(checkAssumptionForLoop(loopOp, consumerOp))) {
+      continue;
+    }
+    return &useOfval;
+  }
+  return failure();
+}
+
 /// Implementation of fusing consumer of a single slice by computing the
 /// slice of the consumer in-place for scf loop.
 FailureOr<scf::SCFFuseConsumerOfSliceResult>
@@ -1297,10 +1461,28 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
 
   bool isInsertSliceOp = isa<tensor...
[truncated]

@Yun-Fly Yun-Fly force-pushed the yunfei/fuse_consumer_nested_loop_ci branch from be4eab1 to afe3a82 Compare June 5, 2024 11:53
@Yun-Fly Yun-Fly force-pushed the yunfei/fuse_consumer_nested_loop_ci branch from afe3a82 to d3daa9d Compare June 5, 2024 12:59
@Yun-Fly
Copy link
Contributor Author

Yun-Fly commented Jun 5, 2024

The reason behind failure has been found, so that close this PR.

@Yun-Fly Yun-Fly closed this Jun 5, 2024
@Yun-Fly Yun-Fly deleted the yunfei/fuse_consumer_nested_loop_ci branch June 5, 2024 13:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants