[mlir][bufferization] Use TensorLike, BufferLike type interfaces #136736


Draft · wants to merge 1 commit into main

Conversation

andrey-golubev (Contributor)

The general idea is to replace most of the places that rely on builtin's TensorType / BaseMemRefType with the newly added type interfaces.

Thus far, do the bare minimum: refactor (almost) "blindly" the API of the dialect and options, leaving most of the logic "as is". The exceptions are the bufferization.{to_tensor, to_memref} ops that act as "glue" when bufferizing neighbouring operations and the enclosing functions.

llvmbot (Member) commented Apr 22, 2025

@llvm/pr-subscribers-mlir-scf
@llvm/pr-subscribers-mlir-bufferization
@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Andrei Golubev (andrey-golubev)

Changes

The general idea is to replace most of the places that rely on builtin's TensorType / BaseMemRefType with the newly added type interfaces.

Thus far, do the bare minimum: refactor (almost) "blindly" the API of the dialect and options, leaving most of the logic "as is". The exceptions are the bufferization.{to_tensor, to_memref} ops that act as "glue" when bufferizing neighbouring operations and the enclosing functions.


Patch is 75.29 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136736.diff

26 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (+11-10)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td (+1-1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+10-7)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h (+1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td (+10-3)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h (+21-14)
  • (modified) mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp (+7-6)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+54-40)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp (+5-1)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+16-14)
  • (added) mlir/lib/Dialect/Bufferization/IR/BufferizationTypeInterfaces.cpp (+21)
  • (modified) mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp (+9-8)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+10-8)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+15-15)
  • (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+60-49)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp (+2-2)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+5-4)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir (+20-1)
  • (modified) mlir/test/Dialect/Bufferization/invalid.mlir (+4-4)
  • (modified) mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp (+3-1)
  • (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+24)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.h (+1)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (+54-1)
  • (modified) mlir/test/lib/Dialect/Test/TestTypeDefs.td (+3)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index ada9539e87121..70092908d961f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -17,6 +17,7 @@
 #include <optional>
 
 #include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
 
 namespace mlir {
 class OpBuilder;
@@ -259,18 +260,18 @@ struct BufferizationOptions {
       std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
   /// Initializer function for analysis state.
   using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
-  /// Tensor -> MemRef type converter.
-  /// Parameters: tensor type, memory space, func op, bufferization options
+  /// TensorLike -> BufferLike type converter.
+  /// Parameters: tensor like type, memory space, func op, bufferization options
   using FunctionArgTypeConverterFn =
-      std::function<BaseMemRefType(TensorType, Attribute memorySpace,
+      std::function<BufferLikeType(TensorLikeType, Attribute memorySpace,
                                    func::FuncOp, const BufferizationOptions &)>;
-  /// Tensor -> MemRef type converter.
+  /// TensorLike -> BufferLike type converter.
   /// Parameters: Value, memory space, bufferization options
-  using UnknownTypeConverterFn = std::function<BaseMemRefType(
+  using UnknownTypeConverterFn = std::function<BufferLikeType(
       Value, Attribute memorySpace, const BufferizationOptions &)>;
   // Produce a MemorySpace attribute from a tensor type
   using DefaultMemorySpaceFn =
-      std::function<std::optional<Attribute>(TensorType t)>;
+      std::function<std::optional<Attribute>(TensorLikeType t)>;
 
   BufferizationOptions();
 
@@ -360,7 +361,7 @@ struct BufferizationOptions {
   // Returning std::nullopt will cause bufferization to fail (useful to indicate
   // failure to determine memory space for a tensor type).
   DefaultMemorySpaceFn defaultMemorySpaceFn =
-      [](TensorType t) -> std::optional<Attribute> { return Attribute(); };
+      [](TensorLikeType t) -> std::optional<Attribute> { return Attribute(); };
 
   /// If set to `true`, the analysis is skipped. A buffer is copied before every
   /// write. This flag cannot be used together with `testAnalysisOnly = true`.
@@ -600,7 +601,7 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
 /// IR, this function can be used.
 ///
 /// This function is a wrapper around BufferizableOpInterface::getBufferType.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
                                         const BufferizationOptions &options);
 
 /// Return the buffer type for a given Value (tensor) after bufferization
@@ -613,7 +614,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
 /// IR, this function can be used.
 ///
 /// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
                                         const BufferizationOptions &options,
                                         SmallVector<Value> &invocationStack);
 
@@ -693,7 +694,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
 /// This is the default implementation of
 /// BufferizableOpInterface::getBufferType. Should not be called from other
 /// places.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 defaultGetBufferType(Value value, const BufferizationOptions &options,
                      SmallVector<Value> &invocationStack);
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index 95022d7d665d2..1de1742fab81a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -518,7 +518,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           Note: This interface method should never be called directly from user
           code. Always use `bufferization::getBufferType`.
         }],
-        /*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>",
+        /*retType=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
         /*methodName=*/"getBufferType",
         /*args=*/(ins "::mlir::Value":$value,
                       "const ::mlir::bufferization::BufferizationOptions &":$options,
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index fad78a63444b9..81ce0f3fb650b 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -13,6 +13,7 @@ include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
+include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -109,7 +110,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
     AliasingValueList getAliasingValues(
         OpOperand &opOperand, const AnalysisState &state);
 
-    FailureOr<BaseMemRefType> getBufferType(
+    FailureOr<BufferLikeType> getBufferType(
         Value value, const BufferizationOptions &options,
         SmallVector<Value> &invocationStack);
 
@@ -438,11 +439,11 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     away. However, such IR is no longer bufferizable with One-Shot Bufferize.
   }];
 
-  let arguments = (ins Arg<AnyRankedOrUnrankedMemRef,
+  let arguments = (ins Arg<Bufferization_BufferLikeTypeInterface,
                            "the reference to load from",
                            [MemReadAt<0, FullEffect>]>:$memref,
                        UnitAttr:$restrict, UnitAttr:$writable);
-  let results = (outs AnyTensor:$result);
+  let results = (outs Bufferization_TensorLikeTypeInterface:$result);
 
   let extraClassDeclaration = [{
     /// The result of a to_tensor is always a tensor.
@@ -465,10 +466,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
 
     bool isWritable(Value value, const AnalysisState &state);
 
-    FailureOr<BaseMemRefType> getBufferType(
+    FailureOr<BufferLikeType> getBufferType(
         Value value, const BufferizationOptions &options,
         SmallVector<Value> &invocationStack) {
-      return ::llvm::cast<BaseMemRefType>(getMemref().getType());
+      return ::llvm::cast<BufferLikeType>(getMemref().getType());
     }
   }];
 
@@ -493,6 +494,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
 // ToMemrefOp
 //===----------------------------------------------------------------------===//
 
+// TODO: rename to "to_buffer"
 def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
     BufferizableOpInterface,
     SameOperandsAndResultShape,
@@ -519,8 +521,9 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
     the returned buffer) will not be written to.
   }];
 
-  let arguments = (ins AnyTensor:$tensor, UnitAttr:$read_only);
-  let results = (outs AnyRankedOrUnrankedMemRef:$memref);
+  let arguments = (ins Bufferization_TensorLikeTypeInterface:$tensor,
+                       UnitAttr:$read_only);
+  let results = (outs Bufferization_BufferLikeTypeInterface:$memref);
 
   let extraClassDeclaration = [{
     //===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index 5faa1479ee542..290f1298f2501 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -13,6 +13,7 @@
 // Bufferization Type Interfaces
 //===----------------------------------------------------------------------===//
 
+#include "mlir/IR/Attributes.h" // mlir::Attribute
 #include "mlir/IR/Types.h"
 
 #include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h.inc"
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
index f19224a295648..c053a6bdc1a91 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
@@ -33,10 +33,17 @@ def Bufferization_BufferLikeTypeInterface
   let description = [{
     Indicates that this type is a buffer type (similarly to a MLIR builtin
     memref) for bufferization purposes.
-
-    The interface currently has no methods as it is used by types to opt into
-    being supported by the bufferization procedures.
   }];
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns the memory space in which data referred to by this buffer resides.
+      }],
+      /*retType=*/"::mlir::Attribute",
+      /*methodName=*/"getMemorySpace"
+    >,
+  ];
 }
 
 #endif // BUFFERIZATION_TYPE_INTERFACES
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index 78109770efab7..89eb65c4a0942 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -32,7 +32,7 @@ template <typename ConcreteModel, typename ConcreteOp>
 struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
     : public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 SmallVector<Value> &invocationStack) const {
     // Note: The user may want to override this function for OpResults in
@@ -46,7 +46,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
     // operand types of all forwarded values. If these are all the same type,
     // take that type. Otherwise, take only the memory space and fall back to a
     // buffer type with a fully dynamic layout map.
-    BaseMemRefType bufferType;
+    BufferLikeType bufferType;
     auto tensorType = cast<TensorType>(value.getType());
     for (OpOperand *opOperand :
          detail::getCallerOpOperands(cast<BlockArgument>(value))) {
@@ -59,13 +59,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
         continue;
 
       // Compute the bufferized type of the forwarded operand.
-      BaseMemRefType callerType;
-      if (auto memrefType =
-              dyn_cast<BaseMemRefType>(opOperand->get().getType())) {
+      BufferLikeType callerType;
+      if (auto bufferLikeType =
+              dyn_cast<BufferLikeType>(opOperand->get().getType())) {
         // The operand was already bufferized. Take its type directly.
-        callerType = memrefType;
+        callerType = bufferLikeType;
       } else {
-        FailureOr<BaseMemRefType> maybeCallerType =
+        FailureOr<BufferLikeType> maybeCallerType =
             bufferization::getBufferType(opOperand->get(), options,
                                          invocationStack);
         if (failed(maybeCallerType))
@@ -86,14 +86,20 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
         // of the earlier forwarded operands, fall back to a buffer type with a
         // fully dynamic layout map.
 #ifndef NDEBUG
+      assert(mlir::isa<BaseMemRefType>(bufferType) &&
+             mlir::isa<BaseMemRefType>(callerType) && "expected memrefs");
+      auto memrefType = mlir::cast<BaseMemRefType>(bufferType);
+      auto callerMemrefType = mlir::cast<BaseMemRefType>(callerType);
+
       if (auto rankedTensorType = dyn_cast<RankedTensorType>(tensorType)) {
-        assert(bufferType.hasRank() && callerType.hasRank() &&
+        assert(memrefType.hasRank() && callerMemrefType.hasRank() &&
                "expected ranked memrefs");
-        assert(llvm::all_equal({bufferType.getShape(), callerType.getShape(),
-                                rankedTensorType.getShape()}) &&
-               "expected same shape");
+        assert(
+            llvm::all_equal({memrefType.getShape(), callerMemrefType.getShape(),
+                             rankedTensorType.getShape()}) &&
+            "expected same shape");
       } else {
-        assert(!bufferType.hasRank() && !callerType.hasRank() &&
+        assert(!memrefType.hasRank() && !callerMemrefType.hasRank() &&
                "expected unranked memrefs");
       }
 #endif // NDEBUG
@@ -102,8 +108,9 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
         return op->emitOpError("incoming operands of block argument have "
                                "inconsistent memory spaces");
 
-      bufferType = getMemRefTypeWithFullyDynamicLayout(
-          tensorType, bufferType.getMemorySpace());
+      bufferType =
+          mlir::cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+              tensorType, bufferType.getMemorySpace()));
     }
 
     if (!bufferType)
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 5e69a98db8f1e..433757192bfd1 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -26,7 +26,7 @@ struct ConstantOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto constantOp = cast<arith::ConstantOp>(op);
-    auto type = dyn_cast<RankedTensorType>(constantOp.getType());
+    auto type = dyn_cast<TensorLikeType>(constantOp.getType());
 
     // Only ranked tensors are supported.
     if (!type)
@@ -176,7 +176,7 @@ struct SelectOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<bufferization::BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 SmallVector<Value> &invocationStack) const {
     auto selectOp = cast<arith::SelectOp>(op);
@@ -195,10 +195,11 @@ struct SelectOpInterface
     // If the buffers have different types, they differ only in their layout
     // map.
     auto memrefType = llvm::cast<MemRefType>(*trueType);
-    return getMemRefTypeWithFullyDynamicLayout(
-        RankedTensorType::get(memrefType.getShape(),
-                              memrefType.getElementType()),
-        memrefType.getMemorySpace());
+    return mlir::cast<bufferization::BufferLikeType>(
+        getMemRefTypeWithFullyDynamicLayout(
+            RankedTensorType::get(memrefType.getShape(),
+                                  memrefType.getElementType()),
+            memrefType.getMemorySpace()));
   }
 };
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 99ffa62c41a4d..82ff1bdfe5fd7 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -206,12 +206,13 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
   // Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
   if (copy)
     return allocTensorOp.getResult();
-  FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
+  FailureOr<BufferLikeType> copyBufferType = getBufferType(tensor, options);
   if (failed(copyBufferType))
     return failure();
   std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
   if (!memorySpace)
-    memorySpace = options.defaultMemorySpaceFn(tensorType);
+    memorySpace =
+        options.defaultMemorySpaceFn(mlir::cast<TensorLikeType>(tensorType));
   if (memorySpace.has_value())
     allocTensorOp.setMemorySpaceAttr(memorySpace.value());
   return allocTensorOp.getResult();
@@ -229,6 +230,7 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
   // Find all out-of-place OpOperands.
   for (OpOperand &opOperand : op->getOpOperands()) {
     Type operandType = opOperand.get().getType();
+    // Note: can only copy TensorType (any other TensorLikeType is rejected)
     if (!llvm::isa<TensorType>(operandType))
       continue;
     if (state.isInPlace(opOperand))
@@ -328,18 +330,21 @@ bool OpFilter::isOpAllowed(Operation *op) const {
 namespace {
 
 /// Default function arg type converter: Use a fully dynamic layout map.
-BaseMemRefType
-defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
-                                func::FuncOp funcOp,
+bufferization::BufferLikeType
+defaultFunctionArgTypeConverter(bufferization::TensorLikeType type,
+                                Attribute memorySpace, func::FuncOp funcOp,
                                 const BufferizationOptions &options) {
-  return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
+  return mlir::cast<bufferization::BufferLikeType>(
+      getMemRefTypeWithFullyDynamicLayout(mlir::cast<TensorType>(type),
+                                          memorySpace));
 }
 /// Default unknown type converter: Use a fully dynamic layout map.
-BaseMemRefType
+BufferLikeType
 defaultUnknownTypeConverter(Value value, Attribute memorySpace,
                             const BufferizationOptions &options) {
-  return getMemRefTypeWithFullyDynamicLayout(
-      llvm::cast<TensorType>(value.getType()), memorySpace);
+  return mlir::cast<bufferization::BufferLikeType>(
+      getMemRefTypeWithFullyDynamicLayout(
+          llvm::cast<TensorType>(value.getType()), memorySpace));
 }
 
 } // namespace
@@ -376,14 +381,16 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const {
 
 void BufferizationOptions::setFunctionBoundaryTypeConversion(
     LayoutMapOption layoutMapOption) {
-  functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
-                                   func::FuncOp funcOp,
+  functionArgTypeConverterFn = [=](TensorLikeType tensorType,
+                                   Attribute memorySpace, func::FuncOp funcOp,
                                    const BufferizationOptions &options) {
     if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
-      return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
-                                                                  memorySpace);
-    return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
-                                                              memorySpace);
+      return mlir::cast<bufferization::BufferLikeType>(
+          bufferization::getMemRefTypeWithStaticIdentityLayout(
+              mlir::cast<TensorType>(tensorType), memorySpace));
+    return mlir::cast<bufferization::BufferLikeType>(
+        bufferization::getMemRefTypeWithFullyDynamicLayout(
+            mlir::cast<TensorType>(tensorType), memorySpace));
   };
   inferFunctionResultLayout =
       layoutMapOption == LayoutMapOption::InferLayoutMap;
@@ -473,7 +480,8 @@ bool AnalysisState::bufferizesToMemoryWrite(Value value) const {
 /// read. Also takes into account ops that create an alias but do not read by
 /// themselves (e.g., ExtractSliceOp).
 bool AnalysisState::isValueRead(Value value) const {
-  assert(llvm::isa<TensorType>(value.getType()) && "expected TensorType");
+  assert(llvm::isa<bufferization::TensorLikeType>(value.getType()) &&
+         "expected TensorLikeType");
   SmallVector<OpOperand *> workingSet;
   DenseSet<OpOperand *> visited;
   for (OpOperand &use : value.getUses())
@@ -66...
[truncated]

llvmbot (Member) commented Apr 22, 2025

@llvm/pr-subscribers-mlir-arith

                        UnitAttr:$restrict, UnitAttr:$writable);
-  let results = (outs AnyTensor:$result);
+  let results = (outs Bufferization_TensorLikeTypeInterface:$result);
 
   let extraClassDeclaration = [{
     /// The result of a to_tensor is always a tensor.
@@ -465,10 +466,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
 
     bool isWritable(Value value, const AnalysisState &state);
 
-    FailureOr<BaseMemRefType> getBufferType(
+    FailureOr<BufferLikeType> getBufferType(
         Value value, const BufferizationOptions &options,
         SmallVector<Value> &invocationStack) {
-      return ::llvm::cast<BaseMemRefType>(getMemref().getType());
+      return ::llvm::cast<BufferLikeType>(getMemref().getType());
     }
   }];
 
@@ -493,6 +494,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
 // ToMemrefOp
 //===----------------------------------------------------------------------===//
 
+// TODO: rename to "to_buffer"
 def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
     BufferizableOpInterface,
     SameOperandsAndResultShape,
@@ -519,8 +521,9 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
     the returned buffer) will not be written to.
   }];
 
-  let arguments = (ins AnyTensor:$tensor, UnitAttr:$read_only);
-  let results = (outs AnyRankedOrUnrankedMemRef:$memref);
+  let arguments = (ins Bufferization_TensorLikeTypeInterface:$tensor,
+                       UnitAttr:$read_only);
+  let results = (outs Bufferization_BufferLikeTypeInterface:$memref);
 
   let extraClassDeclaration = [{
     //===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index 5faa1479ee542..290f1298f2501 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -13,6 +13,7 @@
 // Bufferization Type Interfaces
 //===----------------------------------------------------------------------===//
 
+#include "mlir/IR/Attributes.h" // mlir::Attribute
 #include "mlir/IR/Types.h"
 
 #include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h.inc"
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
index f19224a295648..c053a6bdc1a91 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
@@ -33,10 +33,17 @@ def Bufferization_BufferLikeTypeInterface
   let description = [{
     Indicates that this type is a buffer type (similarly to a MLIR builtin
     memref) for bufferization purposes.
-
-    The interface currently has no methods as it is used by types to opt into
-    being supported by the bufferization procedures.
   }];
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns the memory space in which data referred to by this buffer resides.
+      }],
+      /*retType=*/"::mlir::Attribute",
+      /*methodName=*/"getMemorySpace"
+    >,
+  ];
 }
 
 #endif // BUFFERIZATION_TYPE_INTERFACES
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index 78109770efab7..89eb65c4a0942 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -32,7 +32,7 @@ template <typename ConcreteModel, typename ConcreteOp>
 struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
     : public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 SmallVector<Value> &invocationStack) const {
     // Note: The user may want to override this function for OpResults in
@@ -46,7 +46,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
     // operand types of all forwarded values. If these are all the same type,
     // take that type. Otherwise, take only the memory space and fall back to a
     // buffer type with a fully dynamic layout map.
-    BaseMemRefType bufferType;
+    BufferLikeType bufferType;
     auto tensorType = cast<TensorType>(value.getType());
     for (OpOperand *opOperand :
          detail::getCallerOpOperands(cast<BlockArgument>(value))) {
@@ -59,13 +59,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
         continue;
 
       // Compute the bufferized type of the forwarded operand.
-      BaseMemRefType callerType;
-      if (auto memrefType =
-              dyn_cast<BaseMemRefType>(opOperand->get().getType())) {
+      BufferLikeType callerType;
+      if (auto bufferLikeType =
+              dyn_cast<BufferLikeType>(opOperand->get().getType())) {
         // The operand was already bufferized. Take its type directly.
-        callerType = memrefType;
+        callerType = bufferLikeType;
       } else {
-        FailureOr<BaseMemRefType> maybeCallerType =
+        FailureOr<BufferLikeType> maybeCallerType =
             bufferization::getBufferType(opOperand->get(), options,
                                          invocationStack);
         if (failed(maybeCallerType))
@@ -86,14 +86,20 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
         // of the earlier forwarded operands, fall back to a buffer type with a
         // fully dynamic layout map.
 #ifndef NDEBUG
+      assert(mlir::isa<BaseMemRefType>(bufferType) &&
+             mlir::isa<BaseMemRefType>(callerType) && "expected memrefs");
+      auto memrefType = mlir::cast<BaseMemRefType>(bufferType);
+      auto callerMemrefType = mlir::cast<BaseMemRefType>(callerType);
+
       if (auto rankedTensorType = dyn_cast<RankedTensorType>(tensorType)) {
-        assert(bufferType.hasRank() && callerType.hasRank() &&
+        assert(memrefType.hasRank() && callerMemrefType.hasRank() &&
                "expected ranked memrefs");
-        assert(llvm::all_equal({bufferType.getShape(), callerType.getShape(),
-                                rankedTensorType.getShape()}) &&
-               "expected same shape");
+        assert(
+            llvm::all_equal({memrefType.getShape(), callerMemrefType.getShape(),
+                             rankedTensorType.getShape()}) &&
+            "expected same shape");
       } else {
-        assert(!bufferType.hasRank() && !callerType.hasRank() &&
+        assert(!memrefType.hasRank() && !callerMemrefType.hasRank() &&
                "expected unranked memrefs");
       }
 #endif // NDEBUG
@@ -102,8 +108,9 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
         return op->emitOpError("incoming operands of block argument have "
                                "inconsistent memory spaces");
 
-      bufferType = getMemRefTypeWithFullyDynamicLayout(
-          tensorType, bufferType.getMemorySpace());
+      bufferType =
+          mlir::cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+              tensorType, bufferType.getMemorySpace()));
     }
 
     if (!bufferType)
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 5e69a98db8f1e..433757192bfd1 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -26,7 +26,7 @@ struct ConstantOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto constantOp = cast<arith::ConstantOp>(op);
-    auto type = dyn_cast<RankedTensorType>(constantOp.getType());
+    auto type = dyn_cast<TensorLikeType>(constantOp.getType());
 
     // Only ranked tensors are supported.
     if (!type)
@@ -176,7 +176,7 @@ struct SelectOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<bufferization::BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 SmallVector<Value> &invocationStack) const {
     auto selectOp = cast<arith::SelectOp>(op);
@@ -195,10 +195,11 @@ struct SelectOpInterface
     // If the buffers have different types, they differ only in their layout
     // map.
     auto memrefType = llvm::cast<MemRefType>(*trueType);
-    return getMemRefTypeWithFullyDynamicLayout(
-        RankedTensorType::get(memrefType.getShape(),
-                              memrefType.getElementType()),
-        memrefType.getMemorySpace());
+    return mlir::cast<bufferization::BufferLikeType>(
+        getMemRefTypeWithFullyDynamicLayout(
+            RankedTensorType::get(memrefType.getShape(),
+                                  memrefType.getElementType()),
+            memrefType.getMemorySpace()));
   }
 };
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 99ffa62c41a4d..82ff1bdfe5fd7 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -206,12 +206,13 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
   // Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
   if (copy)
     return allocTensorOp.getResult();
-  FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
+  FailureOr<BufferLikeType> copyBufferType = getBufferType(tensor, options);
   if (failed(copyBufferType))
     return failure();
   std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
   if (!memorySpace)
-    memorySpace = options.defaultMemorySpaceFn(tensorType);
+    memorySpace =
+        options.defaultMemorySpaceFn(mlir::cast<TensorLikeType>(tensorType));
   if (memorySpace.has_value())
     allocTensorOp.setMemorySpaceAttr(memorySpace.value());
   return allocTensorOp.getResult();
@@ -229,6 +230,7 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
   // Find all out-of-place OpOperands.
   for (OpOperand &opOperand : op->getOpOperands()) {
     Type operandType = opOperand.get().getType();
+    // Note: can only copy TensorType (any other TensorLikeType is rejected)
     if (!llvm::isa<TensorType>(operandType))
       continue;
     if (state.isInPlace(opOperand))
@@ -328,18 +330,21 @@ bool OpFilter::isOpAllowed(Operation *op) const {
 namespace {
 
 /// Default function arg type converter: Use a fully dynamic layout map.
-BaseMemRefType
-defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
-                                func::FuncOp funcOp,
+bufferization::BufferLikeType
+defaultFunctionArgTypeConverter(bufferization::TensorLikeType type,
+                                Attribute memorySpace, func::FuncOp funcOp,
                                 const BufferizationOptions &options) {
-  return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
+  return mlir::cast<bufferization::BufferLikeType>(
+      getMemRefTypeWithFullyDynamicLayout(mlir::cast<TensorType>(type),
+                                          memorySpace));
 }
 /// Default unknown type converter: Use a fully dynamic layout map.
-BaseMemRefType
+BufferLikeType
 defaultUnknownTypeConverter(Value value, Attribute memorySpace,
                             const BufferizationOptions &options) {
-  return getMemRefTypeWithFullyDynamicLayout(
-      llvm::cast<TensorType>(value.getType()), memorySpace);
+  return mlir::cast<bufferization::BufferLikeType>(
+      getMemRefTypeWithFullyDynamicLayout(
+          llvm::cast<TensorType>(value.getType()), memorySpace));
 }
 
 } // namespace
@@ -376,14 +381,16 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const {
 
 void BufferizationOptions::setFunctionBoundaryTypeConversion(
     LayoutMapOption layoutMapOption) {
-  functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
-                                   func::FuncOp funcOp,
+  functionArgTypeConverterFn = [=](TensorLikeType tensorType,
+                                   Attribute memorySpace, func::FuncOp funcOp,
                                    const BufferizationOptions &options) {
     if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
-      return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
-                                                                  memorySpace);
-    return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
-                                                              memorySpace);
+      return mlir::cast<bufferization::BufferLikeType>(
+          bufferization::getMemRefTypeWithStaticIdentityLayout(
+              mlir::cast<TensorType>(tensorType), memorySpace));
+    return mlir::cast<bufferization::BufferLikeType>(
+        bufferization::getMemRefTypeWithFullyDynamicLayout(
+            mlir::cast<TensorType>(tensorType), memorySpace));
   };
   inferFunctionResultLayout =
       layoutMapOption == LayoutMapOption::InferLayoutMap;
@@ -473,7 +480,8 @@ bool AnalysisState::bufferizesToMemoryWrite(Value value) const {
 /// read. Also takes into account ops that create an alias but do not read by
 /// themselves (e.g., ExtractSliceOp).
 bool AnalysisState::isValueRead(Value value) const {
-  assert(llvm::isa<TensorType>(value.getType()) && "expected TensorType");
+  assert(llvm::isa<bufferization::TensorLikeType>(value.getType()) &&
+         "expected TensorLikeType");
   SmallVector<OpOperand *> workingSet;
   DenseSet<OpOperand *> visited;
   for (OpOperand &use : value.getUses())
@@ -66...
[truncated]

@andrey-golubev (Contributor, Author) commented Apr 22, 2025

This is the initial patch to adopt TensorLike / BufferLike. The following is not done so far:

  • function boundary bufferization with custom types
    • this requires a custom pass that resets options.functionArgTypeConverterFn
  • SCF bufferization
  • tensor / memref copying and allocation (completely unclear what to do here)

I'd like to collect some initial feedback: does the general direction (still) make sense, and is there something else I am missing?

Overall, I feel that in order to support custom types, users would have to implement a custom pass, as there is no other way to specify the BufferizationOptions callbacks, etc.? Thus, the bare minimum for one-shot-bufferize is to just "not break"? Perhaps I am missing something, so let's discuss.
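For illustration, a rough (untested) sketch of what such a custom pass could look like, matching the functionArgTypeConverterFn signature from this patch. MyTensorType / MyBufferType and MyBufferizePass are placeholder names for a downstream dialect's custom types, not anything in this PR:

```cpp
// Sketch only: a downstream pass that runs one-shot bufferization with a
// custom function-boundary type converter for a hypothetical MyTensorType.
struct MyBufferizePass
    : public PassWrapper<MyBufferizePass, OperationPass<ModuleOp>> {
  void runOnOperation() override {
    bufferization::OneShotBufferizationOptions options;
    options.functionArgTypeConverterFn =
        [](bufferization::TensorLikeType type, Attribute memorySpace,
           func::FuncOp funcOp,
           const bufferization::BufferizationOptions &opts)
        -> bufferization::BufferLikeType {
      // Map the custom tensor-like type to its buffer-like counterpart.
      if (auto myType = dyn_cast<MyTensorType>(type))
        return cast<bufferization::BufferLikeType>(
            MyBufferType::get(myType.getContext(), memorySpace));
      // Fall back to the default memref lowering for builtin tensors.
      return cast<bufferization::BufferLikeType>(
          bufferization::getMemRefTypeWithFullyDynamicLayout(
              cast<TensorType>(type), memorySpace));
    };
    if (failed(bufferization::runOneShotModuleBufferize(getOperation(),
                                                        options)))
      signalPassFailure();
  }
};
```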

@@ -493,6 +494,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
// ToMemrefOp
//===----------------------------------------------------------------------===//

// TODO: rename to "to_buffer"
andrey-golubev (Contributor, Author) commented:
Side note: I would prefer to do the rename in a separate patch (either before or after this one), since renaming would be (almost?) NFC.

andrey-golubev marked this pull request as draft on April 25, 2025 at 07:13.
// replace op with memref analogy, preserve correct types at the boundaries
auto toMemref = rewriter.create<::mlir::bufferization::ToMemrefOp>(
getLoc(), bufferizedInType, getInput());
auto dummyMemrefOp = rewriter.create<test::TestDummyMemrefOp>(
andrey-golubev (Contributor, Author) commented:
Self-review: to_memref / to_tensor calls must come from the infrastructure itself; here, one should just create the TestDummyMemrefOp?
