-
Notifications
You must be signed in to change notification settings - Fork 13.3k
[mlir][bufferization] Use TensorLike, BufferLike type interfaces #136736
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir][bufferization] Use TensorLike, BufferLike type interfaces #136736
Conversation
The general idea is to replace most of the places that rely on builtin's TensorType / BaseMemRefType with the newly added type interfaces. Thus far, do the bare minimum: refactor (almost) "blindly" the API of the dialect and options, leaving most of the logic "as is". The exceptions are the bufferization.{to_tensor, to_memref} ops that act as "glue" when bufferizing neighbouring operations and the enclosing functions.
@llvm/pr-subscribers-mlir-scf @llvm/pr-subscribers-mlir Author: Andrei Golubev (andrey-golubev) ChangesThe general idea is to replace most of the places that rely on builtin's TensorType / BaseMemRefType with the newly added type interfaces. Thus far, do the bare minimum: refactor (almost) "blindly" the API of the dialect and options, leaving most of the logic "as is". The exceptions are the bufferization.{to_tensor, to_memref} ops that act as "glue" when bufferizing neighbouring operations and the enclosing functions. Patch is 75.29 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136736.diff 26 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index ada9539e87121..70092908d961f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -17,6 +17,7 @@
#include <optional>
#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
namespace mlir {
class OpBuilder;
@@ -259,18 +260,18 @@ struct BufferizationOptions {
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
/// Initializer function for analysis state.
using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
- /// Tensor -> MemRef type converter.
- /// Parameters: tensor type, memory space, func op, bufferization options
+ /// TensorLike -> BufferLike type converter.
+ /// Parameters: tensor like type, memory space, func op, bufferization options
using FunctionArgTypeConverterFn =
- std::function<BaseMemRefType(TensorType, Attribute memorySpace,
+ std::function<BufferLikeType(TensorLikeType, Attribute memorySpace,
func::FuncOp, const BufferizationOptions &)>;
- /// Tensor -> MemRef type converter.
+ /// TensorLike -> BufferLike type converter.
/// Parameters: Value, memory space, bufferization options
- using UnknownTypeConverterFn = std::function<BaseMemRefType(
+ using UnknownTypeConverterFn = std::function<BufferLikeType(
Value, Attribute memorySpace, const BufferizationOptions &)>;
// Produce a MemorySpace attribute from a tensor type
using DefaultMemorySpaceFn =
- std::function<std::optional<Attribute>(TensorType t)>;
+ std::function<std::optional<Attribute>(TensorLikeType t)>;
BufferizationOptions();
@@ -360,7 +361,7 @@ struct BufferizationOptions {
// Returning std::nullopt will cause bufferization to fail (useful to indicate
// failure to determine memory space for a tensor type).
DefaultMemorySpaceFn defaultMemorySpaceFn =
- [](TensorType t) -> std::optional<Attribute> { return Attribute(); };
+ [](TensorLikeType t) -> std::optional<Attribute> { return Attribute(); };
/// If set to `true`, the analysis is skipped. A buffer is copied before every
/// write. This flag cannot be used together with `testAnalysisOnly = true`.
@@ -600,7 +601,7 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
/// IR, this function can be used.
///
/// This function is a wrapper around BufferizableOpInterface::getBufferType.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
const BufferizationOptions &options);
/// Return the buffer type for a given Value (tensor) after bufferization
@@ -613,7 +614,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
/// IR, this function can be used.
///
/// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
const BufferizationOptions &options,
SmallVector<Value> &invocationStack);
@@ -693,7 +694,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
/// This is the default implementation of
/// BufferizableOpInterface::getBufferType. Should not be called from other
/// places.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
defaultGetBufferType(Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack);
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index 95022d7d665d2..1de1742fab81a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -518,7 +518,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
Note: This interface method should never be called directly from user
code. Always use `bufferization::getBufferType`.
}],
- /*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>",
+ /*retType=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
/*methodName=*/"getBufferType",
/*args=*/(ins "::mlir::Value":$value,
"const ::mlir::bufferization::BufferizationOptions &":$options,
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index fad78a63444b9..81ce0f3fb650b 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -13,6 +13,7 @@ include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
+include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -109,7 +110,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
AliasingValueList getAliasingValues(
OpOperand &opOperand, const AnalysisState &state);
- FailureOr<BaseMemRefType> getBufferType(
+ FailureOr<BufferLikeType> getBufferType(
Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack);
@@ -438,11 +439,11 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
away. However, such IR is no longer bufferizable with One-Shot Bufferize.
}];
- let arguments = (ins Arg<AnyRankedOrUnrankedMemRef,
+ let arguments = (ins Arg<Bufferization_BufferLikeTypeInterface,
"the reference to load from",
[MemReadAt<0, FullEffect>]>:$memref,
UnitAttr:$restrict, UnitAttr:$writable);
- let results = (outs AnyTensor:$result);
+ let results = (outs Bufferization_TensorLikeTypeInterface:$result);
let extraClassDeclaration = [{
/// The result of a to_tensor is always a tensor.
@@ -465,10 +466,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
bool isWritable(Value value, const AnalysisState &state);
- FailureOr<BaseMemRefType> getBufferType(
+ FailureOr<BufferLikeType> getBufferType(
Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) {
- return ::llvm::cast<BaseMemRefType>(getMemref().getType());
+ return ::llvm::cast<BufferLikeType>(getMemref().getType());
}
}];
@@ -493,6 +494,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
// ToMemrefOp
//===----------------------------------------------------------------------===//
+// TODO: rename to "to_buffer"
def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
BufferizableOpInterface,
SameOperandsAndResultShape,
@@ -519,8 +521,9 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
the returned buffer) will not be written to.
}];
- let arguments = (ins AnyTensor:$tensor, UnitAttr:$read_only);
- let results = (outs AnyRankedOrUnrankedMemRef:$memref);
+ let arguments = (ins Bufferization_TensorLikeTypeInterface:$tensor,
+ UnitAttr:$read_only);
+ let results = (outs Bufferization_BufferLikeTypeInterface:$memref);
let extraClassDeclaration = [{
//===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index 5faa1479ee542..290f1298f2501 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -13,6 +13,7 @@
// Bufferization Type Interfaces
//===----------------------------------------------------------------------===//
+#include "mlir/IR/Attributes.h" // mlir::Attribute
#include "mlir/IR/Types.h"
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h.inc"
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
index f19224a295648..c053a6bdc1a91 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
@@ -33,10 +33,17 @@ def Bufferization_BufferLikeTypeInterface
let description = [{
Indicates that this type is a buffer type (similarly to a MLIR builtin
memref) for bufferization purposes.
-
- The interface currently has no methods as it is used by types to opt into
- being supported by the bufferization procedures.
}];
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns the memory space in which data referred to by this buffer resides.
+ }],
+ /*retType=*/"::mlir::Attribute",
+ /*methodName=*/"getMemorySpace"
+ >,
+ ];
}
#endif // BUFFERIZATION_TYPE_INTERFACES
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index 78109770efab7..89eb65c4a0942 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -32,7 +32,7 @@ template <typename ConcreteModel, typename ConcreteOp>
struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
: public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) const {
// Note: The user may want to override this function for OpResults in
@@ -46,7 +46,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
// operand types of all forwarded values. If these are all the same type,
// take that type. Otherwise, take only the memory space and fall back to a
// buffer type with a fully dynamic layout map.
- BaseMemRefType bufferType;
+ BufferLikeType bufferType;
auto tensorType = cast<TensorType>(value.getType());
for (OpOperand *opOperand :
detail::getCallerOpOperands(cast<BlockArgument>(value))) {
@@ -59,13 +59,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
continue;
// Compute the bufferized type of the forwarded operand.
- BaseMemRefType callerType;
- if (auto memrefType =
- dyn_cast<BaseMemRefType>(opOperand->get().getType())) {
+ BufferLikeType callerType;
+ if (auto bufferLikeType =
+ dyn_cast<BufferLikeType>(opOperand->get().getType())) {
// The operand was already bufferized. Take its type directly.
- callerType = memrefType;
+ callerType = bufferLikeType;
} else {
- FailureOr<BaseMemRefType> maybeCallerType =
+ FailureOr<BufferLikeType> maybeCallerType =
bufferization::getBufferType(opOperand->get(), options,
invocationStack);
if (failed(maybeCallerType))
@@ -86,14 +86,20 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
// of the earlier forwarded operands, fall back to a buffer type with a
// fully dynamic layout map.
#ifndef NDEBUG
+ assert(mlir::isa<BaseMemRefType>(bufferType) &&
+ mlir::isa<BaseMemRefType>(callerType) && "expected memrefs");
+ auto memrefType = mlir::cast<BaseMemRefType>(bufferType);
+ auto callerMemrefType = mlir::cast<BaseMemRefType>(callerType);
+
if (auto rankedTensorType = dyn_cast<RankedTensorType>(tensorType)) {
- assert(bufferType.hasRank() && callerType.hasRank() &&
+ assert(memrefType.hasRank() && callerMemrefType.hasRank() &&
"expected ranked memrefs");
- assert(llvm::all_equal({bufferType.getShape(), callerType.getShape(),
- rankedTensorType.getShape()}) &&
- "expected same shape");
+ assert(
+ llvm::all_equal({memrefType.getShape(), callerMemrefType.getShape(),
+ rankedTensorType.getShape()}) &&
+ "expected same shape");
} else {
- assert(!bufferType.hasRank() && !callerType.hasRank() &&
+ assert(!memrefType.hasRank() && !callerMemrefType.hasRank() &&
"expected unranked memrefs");
}
#endif // NDEBUG
@@ -102,8 +108,9 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
return op->emitOpError("incoming operands of block argument have "
"inconsistent memory spaces");
- bufferType = getMemRefTypeWithFullyDynamicLayout(
- tensorType, bufferType.getMemorySpace());
+ bufferType =
+ mlir::cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+ tensorType, bufferType.getMemorySpace()));
}
if (!bufferType)
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 5e69a98db8f1e..433757192bfd1 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -26,7 +26,7 @@ struct ConstantOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto constantOp = cast<arith::ConstantOp>(op);
- auto type = dyn_cast<RankedTensorType>(constantOp.getType());
+ auto type = dyn_cast<TensorLikeType>(constantOp.getType());
// Only ranked tensors are supported.
if (!type)
@@ -176,7 +176,7 @@ struct SelectOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<bufferization::BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) const {
auto selectOp = cast<arith::SelectOp>(op);
@@ -195,10 +195,11 @@ struct SelectOpInterface
// If the buffers have different types, they differ only in their layout
// map.
auto memrefType = llvm::cast<MemRefType>(*trueType);
- return getMemRefTypeWithFullyDynamicLayout(
- RankedTensorType::get(memrefType.getShape(),
- memrefType.getElementType()),
- memrefType.getMemorySpace());
+ return mlir::cast<bufferization::BufferLikeType>(
+ getMemRefTypeWithFullyDynamicLayout(
+ RankedTensorType::get(memrefType.getShape(),
+ memrefType.getElementType()),
+ memrefType.getMemorySpace()));
}
};
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 99ffa62c41a4d..82ff1bdfe5fd7 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -206,12 +206,13 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
// Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
if (copy)
return allocTensorOp.getResult();
- FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
+ FailureOr<BufferLikeType> copyBufferType = getBufferType(tensor, options);
if (failed(copyBufferType))
return failure();
std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
if (!memorySpace)
- memorySpace = options.defaultMemorySpaceFn(tensorType);
+ memorySpace =
+ options.defaultMemorySpaceFn(mlir::cast<TensorLikeType>(tensorType));
if (memorySpace.has_value())
allocTensorOp.setMemorySpaceAttr(memorySpace.value());
return allocTensorOp.getResult();
@@ -229,6 +230,7 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
// Find all out-of-place OpOperands.
for (OpOperand &opOperand : op->getOpOperands()) {
Type operandType = opOperand.get().getType();
+ // Note: can only copy TensorType (any other TensorLikeType is rejected)
if (!llvm::isa<TensorType>(operandType))
continue;
if (state.isInPlace(opOperand))
@@ -328,18 +330,21 @@ bool OpFilter::isOpAllowed(Operation *op) const {
namespace {
/// Default function arg type converter: Use a fully dynamic layout map.
-BaseMemRefType
-defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
- func::FuncOp funcOp,
+bufferization::BufferLikeType
+defaultFunctionArgTypeConverter(bufferization::TensorLikeType type,
+ Attribute memorySpace, func::FuncOp funcOp,
const BufferizationOptions &options) {
- return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
+ return mlir::cast<bufferization::BufferLikeType>(
+ getMemRefTypeWithFullyDynamicLayout(mlir::cast<TensorType>(type),
+ memorySpace));
}
/// Default unknown type converter: Use a fully dynamic layout map.
-BaseMemRefType
+BufferLikeType
defaultUnknownTypeConverter(Value value, Attribute memorySpace,
const BufferizationOptions &options) {
- return getMemRefTypeWithFullyDynamicLayout(
- llvm::cast<TensorType>(value.getType()), memorySpace);
+ return mlir::cast<bufferization::BufferLikeType>(
+ getMemRefTypeWithFullyDynamicLayout(
+ llvm::cast<TensorType>(value.getType()), memorySpace));
}
} // namespace
@@ -376,14 +381,16 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const {
void BufferizationOptions::setFunctionBoundaryTypeConversion(
LayoutMapOption layoutMapOption) {
- functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
- func::FuncOp funcOp,
+ functionArgTypeConverterFn = [=](TensorLikeType tensorType,
+ Attribute memorySpace, func::FuncOp funcOp,
const BufferizationOptions &options) {
if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
- return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
- memorySpace);
- return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
- memorySpace);
+ return mlir::cast<bufferization::BufferLikeType>(
+ bufferization::getMemRefTypeWithStaticIdentityLayout(
+ mlir::cast<TensorType>(tensorType), memorySpace));
+ return mlir::cast<bufferization::BufferLikeType>(
+ bufferization::getMemRefTypeWithFullyDynamicLayout(
+ mlir::cast<TensorType>(tensorType), memorySpace));
};
inferFunctionResultLayout =
layoutMapOption == LayoutMapOption::InferLayoutMap;
@@ -473,7 +480,8 @@ bool AnalysisState::bufferizesToMemoryWrite(Value value) const {
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
bool AnalysisState::isValueRead(Value value) const {
- assert(llvm::isa<TensorType>(value.getType()) && "expected TensorType");
+ assert(llvm::isa<bufferization::TensorLikeType>(value.getType()) &&
+ "expected TensorLikeType");
SmallVector<OpOperand *> workingSet;
DenseSet<OpOperand *> visited;
for (OpOperand &use : value.getUses())
@@ -66...
[truncated]
|
@llvm/pr-subscribers-mlir-arith Author: Andrei Golubev (andrey-golubev) ChangesThe general idea is to replace most of the places that rely on builtin's TensorType / BaseMemRefType with the newly added type interfaces. Thus far, do the bare minimum: refactor (almost) "blindly" the API of the dialect and options, leaving most of the logic "as is". The exceptions are the bufferization.{to_tensor, to_memref} ops that act as "glue" when bufferizing neighbouring operations and the enclosing functions. Patch is 75.29 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136736.diff 26 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index ada9539e87121..70092908d961f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -17,6 +17,7 @@
#include <optional>
#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
namespace mlir {
class OpBuilder;
@@ -259,18 +260,18 @@ struct BufferizationOptions {
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
/// Initializer function for analysis state.
using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
- /// Tensor -> MemRef type converter.
- /// Parameters: tensor type, memory space, func op, bufferization options
+ /// TensorLike -> BufferLike type converter.
+ /// Parameters: tensor like type, memory space, func op, bufferization options
using FunctionArgTypeConverterFn =
- std::function<BaseMemRefType(TensorType, Attribute memorySpace,
+ std::function<BufferLikeType(TensorLikeType, Attribute memorySpace,
func::FuncOp, const BufferizationOptions &)>;
- /// Tensor -> MemRef type converter.
+ /// TensorLike -> BufferLike type converter.
/// Parameters: Value, memory space, bufferization options
- using UnknownTypeConverterFn = std::function<BaseMemRefType(
+ using UnknownTypeConverterFn = std::function<BufferLikeType(
Value, Attribute memorySpace, const BufferizationOptions &)>;
// Produce a MemorySpace attribute from a tensor type
using DefaultMemorySpaceFn =
- std::function<std::optional<Attribute>(TensorType t)>;
+ std::function<std::optional<Attribute>(TensorLikeType t)>;
BufferizationOptions();
@@ -360,7 +361,7 @@ struct BufferizationOptions {
// Returning std::nullopt will cause bufferization to fail (useful to indicate
// failure to determine memory space for a tensor type).
DefaultMemorySpaceFn defaultMemorySpaceFn =
- [](TensorType t) -> std::optional<Attribute> { return Attribute(); };
+ [](TensorLikeType t) -> std::optional<Attribute> { return Attribute(); };
/// If set to `true`, the analysis is skipped. A buffer is copied before every
/// write. This flag cannot be used together with `testAnalysisOnly = true`.
@@ -600,7 +601,7 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
/// IR, this function can be used.
///
/// This function is a wrapper around BufferizableOpInterface::getBufferType.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
const BufferizationOptions &options);
/// Return the buffer type for a given Value (tensor) after bufferization
@@ -613,7 +614,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
/// IR, this function can be used.
///
/// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
const BufferizationOptions &options,
SmallVector<Value> &invocationStack);
@@ -693,7 +694,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
/// This is the default implementation of
/// BufferizableOpInterface::getBufferType. Should not be called from other
/// places.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
defaultGetBufferType(Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack);
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index 95022d7d665d2..1de1742fab81a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -518,7 +518,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
Note: This interface method should never be called directly from user
code. Always use `bufferization::getBufferType`.
}],
- /*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>",
+ /*retType=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
/*methodName=*/"getBufferType",
/*args=*/(ins "::mlir::Value":$value,
"const ::mlir::bufferization::BufferizationOptions &":$options,
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index fad78a63444b9..81ce0f3fb650b 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -13,6 +13,7 @@ include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
+include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -109,7 +110,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
AliasingValueList getAliasingValues(
OpOperand &opOperand, const AnalysisState &state);
- FailureOr<BaseMemRefType> getBufferType(
+ FailureOr<BufferLikeType> getBufferType(
Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack);
@@ -438,11 +439,11 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
away. However, such IR is no longer bufferizable with One-Shot Bufferize.
}];
- let arguments = (ins Arg<AnyRankedOrUnrankedMemRef,
+ let arguments = (ins Arg<Bufferization_BufferLikeTypeInterface,
"the reference to load from",
[MemReadAt<0, FullEffect>]>:$memref,
UnitAttr:$restrict, UnitAttr:$writable);
- let results = (outs AnyTensor:$result);
+ let results = (outs Bufferization_TensorLikeTypeInterface:$result);
let extraClassDeclaration = [{
/// The result of a to_tensor is always a tensor.
@@ -465,10 +466,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
bool isWritable(Value value, const AnalysisState &state);
- FailureOr<BaseMemRefType> getBufferType(
+ FailureOr<BufferLikeType> getBufferType(
Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) {
- return ::llvm::cast<BaseMemRefType>(getMemref().getType());
+ return ::llvm::cast<BufferLikeType>(getMemref().getType());
}
}];
@@ -493,6 +494,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
// ToMemrefOp
//===----------------------------------------------------------------------===//
+// TODO: rename to "to_buffer"
def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
BufferizableOpInterface,
SameOperandsAndResultShape,
@@ -519,8 +521,9 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
the returned buffer) will not be written to.
}];
- let arguments = (ins AnyTensor:$tensor, UnitAttr:$read_only);
- let results = (outs AnyRankedOrUnrankedMemRef:$memref);
+ let arguments = (ins Bufferization_TensorLikeTypeInterface:$tensor,
+ UnitAttr:$read_only);
+ let results = (outs Bufferization_BufferLikeTypeInterface:$memref);
let extraClassDeclaration = [{
//===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index 5faa1479ee542..290f1298f2501 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -13,6 +13,7 @@
// Bufferization Type Interfaces
//===----------------------------------------------------------------------===//
+#include "mlir/IR/Attributes.h" // mlir::Attribute
#include "mlir/IR/Types.h"
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h.inc"
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
index f19224a295648..c053a6bdc1a91 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
@@ -33,10 +33,17 @@ def Bufferization_BufferLikeTypeInterface
let description = [{
Indicates that this type is a buffer type (similarly to a MLIR builtin
memref) for bufferization purposes.
-
- The interface currently has no methods as it is used by types to opt into
- being supported by the bufferization procedures.
}];
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns the memory space in which data referred to by this buffer resides.
+ }],
+ /*retType=*/"::mlir::Attribute",
+ /*methodName=*/"getMemorySpace"
+ >,
+ ];
}
#endif // BUFFERIZATION_TYPE_INTERFACES
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index 78109770efab7..89eb65c4a0942 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -32,7 +32,7 @@ template <typename ConcreteModel, typename ConcreteOp>
struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
: public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) const {
// Note: The user may want to override this function for OpResults in
@@ -46,7 +46,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
// operand types of all forwarded values. If these are all the same type,
// take that type. Otherwise, take only the memory space and fall back to a
// buffer type with a fully dynamic layout map.
- BaseMemRefType bufferType;
+ BufferLikeType bufferType;
auto tensorType = cast<TensorType>(value.getType());
for (OpOperand *opOperand :
detail::getCallerOpOperands(cast<BlockArgument>(value))) {
@@ -59,13 +59,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
continue;
// Compute the bufferized type of the forwarded operand.
- BaseMemRefType callerType;
- if (auto memrefType =
- dyn_cast<BaseMemRefType>(opOperand->get().getType())) {
+ BufferLikeType callerType;
+ if (auto bufferLikeType =
+ dyn_cast<BufferLikeType>(opOperand->get().getType())) {
// The operand was already bufferized. Take its type directly.
- callerType = memrefType;
+ callerType = bufferLikeType;
} else {
- FailureOr<BaseMemRefType> maybeCallerType =
+ FailureOr<BufferLikeType> maybeCallerType =
bufferization::getBufferType(opOperand->get(), options,
invocationStack);
if (failed(maybeCallerType))
@@ -86,14 +86,20 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
// of the earlier forwarded operands, fall back to a buffer type with a
// fully dynamic layout map.
#ifndef NDEBUG
+ assert(mlir::isa<BaseMemRefType>(bufferType) &&
+ mlir::isa<BaseMemRefType>(callerType) && "expected memrefs");
+ auto memrefType = mlir::cast<BaseMemRefType>(bufferType);
+ auto callerMemrefType = mlir::cast<BaseMemRefType>(callerType);
+
if (auto rankedTensorType = dyn_cast<RankedTensorType>(tensorType)) {
- assert(bufferType.hasRank() && callerType.hasRank() &&
+ assert(memrefType.hasRank() && callerMemrefType.hasRank() &&
"expected ranked memrefs");
- assert(llvm::all_equal({bufferType.getShape(), callerType.getShape(),
- rankedTensorType.getShape()}) &&
- "expected same shape");
+ assert(
+ llvm::all_equal({memrefType.getShape(), callerMemrefType.getShape(),
+ rankedTensorType.getShape()}) &&
+ "expected same shape");
} else {
- assert(!bufferType.hasRank() && !callerType.hasRank() &&
+ assert(!memrefType.hasRank() && !callerMemrefType.hasRank() &&
"expected unranked memrefs");
}
#endif // NDEBUG
@@ -102,8 +108,9 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
return op->emitOpError("incoming operands of block argument have "
"inconsistent memory spaces");
- bufferType = getMemRefTypeWithFullyDynamicLayout(
- tensorType, bufferType.getMemorySpace());
+ bufferType =
+ mlir::cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+ tensorType, bufferType.getMemorySpace()));
}
if (!bufferType)
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 5e69a98db8f1e..433757192bfd1 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -26,7 +26,7 @@ struct ConstantOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto constantOp = cast<arith::ConstantOp>(op);
- auto type = dyn_cast<RankedTensorType>(constantOp.getType());
+ auto type = dyn_cast<TensorLikeType>(constantOp.getType());
// Only ranked tensors are supported.
if (!type)
@@ -176,7 +176,7 @@ struct SelectOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<bufferization::BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) const {
auto selectOp = cast<arith::SelectOp>(op);
@@ -195,10 +195,11 @@ struct SelectOpInterface
// If the buffers have different types, they differ only in their layout
// map.
auto memrefType = llvm::cast<MemRefType>(*trueType);
- return getMemRefTypeWithFullyDynamicLayout(
- RankedTensorType::get(memrefType.getShape(),
- memrefType.getElementType()),
- memrefType.getMemorySpace());
+ return mlir::cast<bufferization::BufferLikeType>(
+ getMemRefTypeWithFullyDynamicLayout(
+ RankedTensorType::get(memrefType.getShape(),
+ memrefType.getElementType()),
+ memrefType.getMemorySpace()));
}
};
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 99ffa62c41a4d..82ff1bdfe5fd7 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -206,12 +206,13 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
// Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
if (copy)
return allocTensorOp.getResult();
- FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
+ FailureOr<BufferLikeType> copyBufferType = getBufferType(tensor, options);
if (failed(copyBufferType))
return failure();
std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
if (!memorySpace)
- memorySpace = options.defaultMemorySpaceFn(tensorType);
+ memorySpace =
+ options.defaultMemorySpaceFn(mlir::cast<TensorLikeType>(tensorType));
if (memorySpace.has_value())
allocTensorOp.setMemorySpaceAttr(memorySpace.value());
return allocTensorOp.getResult();
@@ -229,6 +230,7 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
// Find all out-of-place OpOperands.
for (OpOperand &opOperand : op->getOpOperands()) {
Type operandType = opOperand.get().getType();
+ // Note: can only copy TensorType (any other TensorLikeType is rejected)
if (!llvm::isa<TensorType>(operandType))
continue;
if (state.isInPlace(opOperand))
@@ -328,18 +330,21 @@ bool OpFilter::isOpAllowed(Operation *op) const {
namespace {
/// Default function arg type converter: Use a fully dynamic layout map.
-BaseMemRefType
-defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
- func::FuncOp funcOp,
+bufferization::BufferLikeType
+defaultFunctionArgTypeConverter(bufferization::TensorLikeType type,
+ Attribute memorySpace, func::FuncOp funcOp,
const BufferizationOptions &options) {
- return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
+ return mlir::cast<bufferization::BufferLikeType>(
+ getMemRefTypeWithFullyDynamicLayout(mlir::cast<TensorType>(type),
+ memorySpace));
}
/// Default unknown type converter: Use a fully dynamic layout map.
-BaseMemRefType
+BufferLikeType
defaultUnknownTypeConverter(Value value, Attribute memorySpace,
const BufferizationOptions &options) {
- return getMemRefTypeWithFullyDynamicLayout(
- llvm::cast<TensorType>(value.getType()), memorySpace);
+ return mlir::cast<bufferization::BufferLikeType>(
+ getMemRefTypeWithFullyDynamicLayout(
+ llvm::cast<TensorType>(value.getType()), memorySpace));
}
} // namespace
@@ -376,14 +381,16 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const {
void BufferizationOptions::setFunctionBoundaryTypeConversion(
LayoutMapOption layoutMapOption) {
- functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
- func::FuncOp funcOp,
+ functionArgTypeConverterFn = [=](TensorLikeType tensorType,
+ Attribute memorySpace, func::FuncOp funcOp,
const BufferizationOptions &options) {
if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
- return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
- memorySpace);
- return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
- memorySpace);
+ return mlir::cast<bufferization::BufferLikeType>(
+ bufferization::getMemRefTypeWithStaticIdentityLayout(
+ mlir::cast<TensorType>(tensorType), memorySpace));
+ return mlir::cast<bufferization::BufferLikeType>(
+ bufferization::getMemRefTypeWithFullyDynamicLayout(
+ mlir::cast<TensorType>(tensorType), memorySpace));
};
inferFunctionResultLayout =
layoutMapOption == LayoutMapOption::InferLayoutMap;
@@ -473,7 +480,8 @@ bool AnalysisState::bufferizesToMemoryWrite(Value value) const {
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
bool AnalysisState::isValueRead(Value value) const {
- assert(llvm::isa<TensorType>(value.getType()) && "expected TensorType");
+ assert(llvm::isa<bufferization::TensorLikeType>(value.getType()) &&
+ "expected TensorLikeType");
SmallVector<OpOperand *> workingSet;
DenseSet<OpOperand *> visited;
for (OpOperand &use : value.getUses())
@@ -66...
[truncated]
|
This is the initial patch to adopt TensorLike / BufferLike. So far not done:
I'd like to collect some initial feedback, whether the general direction makes sense (still), and whether there's something else I am missing. Overall, I feel like in order to support custom types the users would have to implement a custom pass as there's no other way to specify |
@@ -493,6 +494,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [ | |||
// ToMemrefOp | |||
//===----------------------------------------------------------------------===// | |||
|
|||
// TODO: rename to "to_buffer" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
side-note: I would prefer to do it in a separate patch (either before or after this one) since renaming would be (almost?) NFC
// replace op with memref analogy, preserve correct types at the boundaries | ||
auto toMemref = rewriter.create<::mlir::bufferization::ToMemrefOp>( | ||
getLoc(), bufferizedInType, getInput()); | ||
auto dummyMemrefOp = rewriter.create<test::TestDummyMemrefOp>( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self-review: to memref / to tensor calls must come from the infrastructure itself. here, one should just create the TestDummyMemrefOp
?
The general idea is to replace most of the places that rely on builtin's TensorType / BaseMemRefType with the newly added type interfaces.
Thus far, do the bare minimum: refactor (almost) "blindly" the API of the dialect and options, leaving most of the logic "as is". The exceptions are the bufferization.{to_tensor, to_memref} ops that act as "glue" when bufferizing neighbouring operations and the enclosing functions.