Skip to content

Revert "[mlir][linalg] Enable fuse consumer" #89722

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 23, 2024
Merged

Revert "[mlir][linalg] Enable fuse consumer" #89722

merged 1 commit into from
Apr 23, 2024

Conversation

ftynse
Copy link
Member

@ftynse ftynse commented Apr 23, 2024

Reverts #85528. This was committed without tests, despite reviewers requesting tests to be added. The post-commit discussion leans towards revert, which would be consistent with the policy.

@llvmbot
Copy link
Member

llvmbot commented Apr 23, 2024

@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Oleksandr "Alex" Zinenko (ftynse)

Changes

Reverts llvm/llvm-project#85528. This was committed without tests, despite reviewers requesting tests to be added. The post-commit discussion leans towards revert, which would be consistent with the policy.


Full diff: https://github.com/llvm/llvm-project/pull/89722.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Interfaces/TilingInterface.td (+6-61)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+26-80)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp (+6-6)
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 84f7dec2f4003d..66382f29c24249 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -63,7 +63,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           The method returns the operation that is the tiled
           implementation.
         }],
-        /*retType=*/"FailureOr<::mlir::TilingResult>",
+        /*retType=*/"FailureOr<TilingResult>",
         /*methodName=*/"getTiledImplementation",
         /*args=*/(ins
             "OpBuilder &":$b,
@@ -82,34 +82,15 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           by the tiled implementation. Expects the same `offsets` and `sizes` as
           used to obtain the tiled implementation of the operation.
         }],
-        /*retType=*/"::mlir::LogicalResult",
+        /*retType=*/"LogicalResult",
         /*methodName=*/"getResultTilePosition",
         /*args=*/(ins
           "OpBuilder &":$b,
           "unsigned":$resultNumber,
           "ArrayRef<OpFoldResult> ":$offsets,
           "ArrayRef<OpFoldResult> ":$sizes,
-          "SmallVectorImpl<OpFoldResult> &":$resultOffsets,
-          "SmallVectorImpl<OpFoldResult> &":$resultSizes),
-        /*methodBody=*/"",
-        /*defaultImplementation=*/[{
-          return failure();
-        }]
-      >,
-      InterfaceMethod<
-        /*desc=*/[{
-          Method to return the position of iteration domain tile computed by the
-          tiled operation.
-        }],
-        /*retType=*/"::mlir::LogicalResult",
-        /*methodName=*/"getIterationDomainTileFromOperandTile",
-        /*args=*/(ins
-          "OpBuilder &":$b,
-          "unsigned":$operandNumber,
-          "ArrayRef<OpFoldResult> ":$offsets,
-          "ArrayRef<OpFoldResult> ":$sizes,
-          "SmallVectorImpl<OpFoldResult> &":$iterDomainOffsets,
-          "SmallVectorImpl<OpFoldResult> &":$iterDomainSizes),
+          "SmallVector<OpFoldResult> &":$resultOffsets,
+          "SmallVector<OpFoldResult> &":$resultSizes),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return failure();
@@ -138,7 +119,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
             iteration space).
           - `sizes` provides the size of the tile.
         }],
-        /*retType=*/"FailureOr<::mlir::TilingResult>",
+        /*retType=*/"FailureOr<TilingResult>",
         /*methodName=*/"generateResultTileValue",
         /*args=*/(ins
           "OpBuilder &":$b,
@@ -150,42 +131,6 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           return failure();
         }]
       >,
-      InterfaceMethod<
-        /*desc=*/[{
-          Method to generate the tiled implementation of an operation from
-          operand tile position.
-
-          Generates the IR that computes the tiled implementation of an
-          operation from operand tile.  The `offsets` and `sizes`
-          describe the tile of the operand required. This is different from
-          `getTiledImplementation` which generates the tiled
-          implementation of the operation given a tile of the
-          iteration space. This method generates a tiled
-          implementation of the operation based on the tile of the
-          operand required. This method enables consumer fusion by using
-          tile and fuse. The method returns failure if the operation
-          can't be tiled to generate the operand tile. In practical terms
-          this implies it cannot be tiled and fused with its producers.
-
-          - `offsets` provides the offset of the tile in the coordinate system
-            of the original iteration space, i.e., if an iteration space
-            dimension had non-zero offset, it must be included in the offset
-            provided here (as opposed to zero-based offset "relative" to the
-            iteration space).
-          - `sizes` provides the size of the tile.
-        }],
-        /*retType=*/"FailureOr<::mlir::TilingResult>",
-        /*methodName=*/"getTiledImplementationFromOperandTile",
-        /*args=*/(ins
-          "OpBuilder &":$b,
-          "unsigned":$operandNumber,
-          "ArrayRef<OpFoldResult>":$offsets,
-          "ArrayRef<OpFoldResult>":$sizes),
-        /*methodBody=*/"",
-        /*defaultImplementation=*/[{
-          return failure();
-        }]
-      >,
       InterfaceMethod<
         /*desc=*/[{
           Generates the scalar implementation of the operation.
@@ -197,7 +142,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           transformations are done, this method can be used to lower to scalar
           code that can then be lowered to LLVM or SPIR-V dialects.
         }],
-        /*retType=*/"::mlir::LogicalResult",
+        /*retType=*/"LogicalResult",
         /*methodName=*/"generateScalarImplementation",
         /*args=*/(ins
             "OpBuilder &":$b,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e9999c34d0face..9c5c58fa1fabfb 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2425,8 +2425,8 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
 
 LogicalResult SoftmaxOp::getResultTilePosition(
     OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
-    ArrayRef<OpFoldResult> sizes, SmallVectorImpl<OpFoldResult> &resultOffsets,
-    SmallVectorImpl<OpFoldResult> &resultSizes) {
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
   if (resultNumber == 0) {
     resultOffsets.assign(offsets.begin(), offsets.end());
     resultSizes.assign(sizes.begin(), sizes.end());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 71e9c3771dcded..bd870d4f982e5d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -110,7 +110,7 @@ struct LinalgOpTilingInterface
         }));
   }
 
-  /// Instantiate the tiled implementation of the operation.
+  // Instantiate the tiled implementation of the operation.
   FailureOr<TilingResult>
   getTiledImplementation(Operation *op, OpBuilder &b,
                          ArrayRef<OpFoldResult> offsets,
@@ -132,66 +132,14 @@ struct LinalgOpTilingInterface
     return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
   }
 
-  void
-  getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
-                         ArrayRef<OpFoldResult> offsets,
-                         ArrayRef<OpFoldResult> sizes,
-                         SmallVectorImpl<OpFoldResult> &mappedOffsets,
-                         SmallVectorImpl<OpFoldResult> &mappedSizes) const {
-    unsigned numLoops = linalgOp.getNumLoops();
-    auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
-    mappedOffsets.resize(numLoops);
-    mappedSizes.resize(numLoops);
-    if (!indexingMap.isPermutation()) {
-      SmallVector<Range> iterationDomain =
-          tilingInterfaceOp.getIterationDomain(b);
-      for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
-        mappedOffsets[index] = value.offset;
-        mappedSizes[index] = value.size;
-      }
-    }
-    for (const auto &&[index, value] :
-         llvm::enumerate(indexingMap.getResults())) {
-      unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
-      mappedOffsets[dimPosition] = offsets[index];
-      mappedSizes[dimPosition] = sizes[index];
-    }
-  }
-
-  /// Return the details of the output tile generated by the tiled
-  /// implementation.
-  LogicalResult getIterationDomainTileFromOperandTile(
-      Operation *op, OpBuilder &b, unsigned operandNumber,
-      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
-      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
-      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
-    auto linalgOp = cast<LinalgOp>(op);
-
-    // Check that the indexing map used for the operand is a projected
-    // permutation. This could be relaxed with a more general approach that can
-    // map the offsets and sizes from the operand to iteration space tiles
-    // (filling in full extent for dimensions not used to access the result).
-    AffineMap indexingMap =
-        linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
-    if (!indexingMap.isProjectedPermutation()) {
-      return emitError(op->getLoc(),
-                       "unhandled get iter domain position when operand is not "
-                       "accessed using a permuted projection");
-    }
-
-    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
-                           iterDomainOffsets, iterDomainSizes);
-    return success();
-  }
-
-  /// Return the details of the output tile generated by the tiled
-  /// implementation.
+  // Return the details of the output tile generated by the tiled
+  // implementation.
   LogicalResult
   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
-                        SmallVectorImpl<OpFoldResult> &resultOffsets,
-                        SmallVectorImpl<OpFoldResult> &resultSizes) const {
+                        SmallVector<OpFoldResult> &resultOffsets,
+                        SmallVector<OpFoldResult> &resultSizes) const {
     Location loc = op->getLoc();
     LinalgOp linalgOp = cast<LinalgOp>(op);
 
@@ -212,21 +160,6 @@ struct LinalgOpTilingInterface
     return success();
   }
 
-  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
-      Operation *op, OpBuilder &b, unsigned operandNumber,
-      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
-    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
-    auto tilingInterfaceOp = cast<TilingInterface>(op);
-    if (failed(tilingInterfaceOp.getIterationDomainTileFromOperandTile(
-            b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
-      return emitError(
-          op->getLoc(),
-          "unable to obtain the iter domain position of the operation.");
-    }
-    return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
-                                                    mappedSizes);
-  }
-
   FailureOr<TilingResult>
   generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                           ArrayRef<OpFoldResult> offsets,
@@ -244,16 +177,29 @@ struct LinalgOpTilingInterface
           "unhandled tiled implementation generation when result is not "
           "accessed using a permuted projection");
     }
-    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
-    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
-                           mappedOffsets, mappedSizes);
-    auto tilingInterfaceOp = cast<TilingInterface>(op);
-    FailureOr<TilingResult> tilingResult =
-        tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
 
-    if (failed(tilingResult))
-      return failure();
+    auto numLoops = linalgOp.getNumLoops();
+    auto tilingInterfaceOp = cast<TilingInterface>(op);
+    SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
+        iterationTileSizes(numLoops);
+    if (!indexingMap.isPermutation()) {
+      SmallVector<Range> iterationDomain =
+          tilingInterfaceOp.getIterationDomain(b);
+      for (const auto &range : llvm::enumerate(iterationDomain)) {
+        iterationTileOffsets[range.index()] = range.value().offset;
+        iterationTileSizes[range.index()] = range.value().size;
+      }
+    }
+    for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
+      unsigned dimPosition =
+          cast<AffineDimExpr>(resultExpr.value()).getPosition();
+      iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
+      iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
+    }
 
+    FailureOr<TilingResult> tilingResult =
+        tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
+                                                 iterationTileSizes);
     if (tilingResult->tiledOps.size() != 1)
       return op->emitOpError("failed to generate tiled implementation");
 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 296c5fc7a5c2bd..d25efcf50ec566 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -61,8 +61,8 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
-                        SmallVectorImpl<OpFoldResult> &resultOffsets,
-                        SmallVectorImpl<OpFoldResult> &resultSizes) const {
+                        SmallVector<OpFoldResult> &resultOffsets,
+                        SmallVector<OpFoldResult> &resultSizes) const {
     resultOffsets.assign(offsets.begin(), offsets.end());
     resultSizes.assign(sizes.begin(), sizes.end());
     return success();
@@ -199,8 +199,8 @@ struct PackOpTiling
   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
-                        SmallVectorImpl<OpFoldResult> &resultOffsets,
-                        SmallVectorImpl<OpFoldResult> &resultSizes) const {
+                        SmallVector<OpFoldResult> &resultOffsets,
+                        SmallVector<OpFoldResult> &resultSizes) const {
     // The iteration domain is over outer dimensions of packed layout. In this
     // context, the outer dimensions of `resultOffsets` are `offsets`. The
     // inner dimensions of `resultOffsets` are zeros because tiling is not
@@ -452,8 +452,8 @@ struct UnPackOpTiling
   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
-                        SmallVectorImpl<OpFoldResult> &resultOffsets,
-                        SmallVectorImpl<OpFoldResult> &resultSizes) const {
+                        SmallVector<OpFoldResult> &resultOffsets,
+                        SmallVector<OpFoldResult> &resultSizes) const {
     resultOffsets = llvm::to_vector(offsets);
     resultSizes = llvm::to_vector(sizes);
     return success();

@ftynse ftynse merged commit f220c35 into main Apr 23, 2024
@ftynse ftynse deleted the revert-85528-main branch April 23, 2024 09:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants