-
Notifications
You must be signed in to change notification settings - Fork 13.7k
[mlir][linalg] Expose transform.fuse_into_containing_op helpers: NFC #72473
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This allows use of the fusion implementations in `transform.fuse_into_containing_op` in other contexts.
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: Quinn Dawkins (qedawkins) ChangesThis allows use of the various fusion helpers called by Full diff: https://github.com/llvm/llvm-project/pull/72473.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index 12923663b3fb6ce..f9f96ae0b7b6494 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -59,6 +59,33 @@ tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state,
std::optional<ArrayAttr> mapping,
linalg::ForallTilingResult &tilingResult);
+/// Find the first "extract" user of `producerOp` and tile it right before its
+/// use. The tiled op is fused under the `containingOp`.
+/// Return this fused op on success or nullptr if anything fails.
+/// If tiled op has uses that are dominated by `containingOp`, return
+/// a new `containingOp` with results of the fused op appended to
+/// results of the `containingOp` or nullptr if there are no dominated uses.
+std::tuple<SmallVector<Operation *>, Operation *>
+tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
+ Operation *producerOp, Operation *containingOp);
+
+/// First, find the first "scf::ForallOp" user of `producerOp` and ensure
+/// it is exactly the `containingOp`, otherwise bail.
+/// Then, find the first "extract" user of the tied block argument and tile it
+/// right before its "extract" use. The tiled op is fused under the
+/// `containingOp`.
+/// Return this fused op on success or nullptr if anything fails.
+SmallVector<Operation *>
+tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
+ RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
+ Operation *containingOp);
+
+/// Find the first use of `producerOp` inside `containingOp` and fuse into
+/// the containing op by cloning the producer. Return nullptr if no such
+/// fusion opportunity exists.
+Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
+ Operation *producerOp, Operation *containingOp);
+
} // namespace transform
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index de4965f937162ea..fe98dfbf287a8ad 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -628,15 +628,11 @@ static Operation *replaceForAllWithNewSignature(
return newforallOp;
}
-/// Find the first "extract" user of `producerOp` and tile it right before its
-/// use. The tiled op is fused under the `containingOp`.
-/// Return this fused op on success or nullptr if anything fails.
-/// If tiled op has uses that are dominated by `containingOp`, return
-/// a new `containingOp` with results of the fused op appended to
-/// results of the `containingOp` or nullptr if there are no dominated uses.
-static std::tuple<SmallVector<Operation *>, Operation *>
-tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
- Operation *producerOp, Operation *containingOp) {
+std::tuple<SmallVector<Operation *>, Operation *>
+mlir::transform::tileAndFuseFirstExtractUse(RewriterBase &rewriter,
+ Diagnostic &diag,
+ Operation *producerOp,
+ Operation *containingOp) {
LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n");
auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
if (!tileableProducer) {
@@ -710,14 +706,8 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
}
-/// First, find the first "scf::ForallOp" user of `producerOp` and ensure
-/// it is exactly the `containingOp`, otherwise bail.
-/// Then, find the first "extract" user of the tied block argument and tile it
-/// right before its "extract" use. The tiled op is fused under the
-/// `containingOp`.
-/// Return this fused op on success or nullptr if anything fails.
-static SmallVector<Operation *>
-tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
+SmallVector<Operation *>
+mlir::transform::tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
Operation *containingOp) {
LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n");
@@ -819,9 +809,10 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
return tileAndFuseResult->tiledOps;
}
-static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
- Operation *producerOp,
- Operation *containingOp) {
+Operation *mlir::transform::cloneAndFuseFirstUse(RewriterBase &rewriter,
+ Diagnostic &diag,
+ Operation *producerOp,
+ Operation *containingOp) {
LLVM_DEBUG(DBGS() << "Try to fuse an use by cloning\n");
// Gather all uses inside the containing op.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you explain the intended use of these internal implementation details ?
I'd prefer to see this extended to also work properly with scf::ForOp or more generally than make the current implementation load-bearing to downstream clients by exposing the API.
The only helper I really want here is Makes sense that wildly exposing internal implementation details would either subject downstream to frequent API breakages, or hinder upstream's ability to make changes. If I had a way to invoke
|
These things need to be decoupled. I am trying to unify everything in a way that can be used directly and the transform dialect operations just call into these methods. There are a lot of threads to untangle here, but I am starting with these here #72178 . I still have to untangle things, and this is definitely part of the entanglement, but I dont yet know how. Maybe we can just use one of the methods here https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h |
@qedawkins ok, I'll play along, but please make sure this is properly tested standalone in-tree in the way you want to use it. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMO, these should be moved under lib/Linalg/Transforms
. If their caller is not using transform ops, it shouldn't need to include TransformOps.h
. They don't seem to be dependent on any transform dialect functionality, not even DiagnosedSilenceableFailure
.
Discussed some more offline with @MaheshRavishankar and there are some longer term plans to unify the various tiling/fusion implementations in a more common location, and I've managed to work around the direct need for this change until later. Sorry for noise :( closing for now. |
This allows use of the various fusion helpers called by
transform.fuse_into_containing_op
in other contexts.