-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][core|ptr] Add PtrLikeTypeInterface
and casting ops to the ptr
dialect
#137469
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
…tr` dialect This patch adds the `PtrLikeTypeInterface` type interface to identify pointer-like types. This interface is defined as: ``` A ptr-like type represents an object storing a memory address. This object is constituted by: - A memory address called the base pointer. The base pointer is an indivisible object. - Optional metadata about the pointer. For example, the size of the memory region associated with the pointer. Furthermore, all ptr-like types have two properties: - The memory space associated with the address held by the pointer. - An optional element type. If the element type is not specified, the pointer is considered opaque. ``` This patch adds this interface to `!ptr.ptr` and the `memref` type. Furthermore, this patch adds necessary ops and type to handle casting between `!ptr.ptr` and ptr-like types. First, it defines the `!ptr.ptr_metadata` type. An opaque type to represent the metadata of a ptr-like type. The rationale behind adding this type, is that at high-level the metadata of a type like `memref` cannot be specified, as its structure is tied to its lowering. The `ptr.get_metadata` operation was added to extract the opaque pointer metadata. The concrete structure of the metadata is only known when the op is lowered. Finally, this patch adds the `ptr.from_ptr` and `ptr.to_ptr` operations. Allowing to cast back and forth between `!ptr.ptr` and ptr-liker types. ```mlir func.func @func(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> { %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space> %mda = ptr.get_metadata %mr : memref<f32, #ptr.generic_space> %res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space> return %res : memref<f32, #ptr.generic_space> } ```
|
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Fabian Mora (fabianmcg) ChangesThis patch adds the
This patch adds this interface to Furthermore, this patch adds necessary ops and type to handle casting between First, it defines the The Finally, this patch adds the func.func @<!-- -->func(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
%ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
%mda = ptr.get_metadata %mr : memref<f32, #ptr.generic_space>
%res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
return %res : memref<f32, #ptr.generic_space>
} It's future work to replace and remove the Patch is 23.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/137469.diff 11 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
index 73b2a0857cef3..6631b338db199 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
@@ -37,6 +37,7 @@ class Ptr_Type<string name, string typeMnemonic, list<Trait> traits = []>
def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
MemRefElementTypeInterface,
+ PtrLikeTypeInterface,
VectorElementTypeInterface,
DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
"areCompatible", "getIndexBitwidth", "verifyEntries",
@@ -63,6 +64,54 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
return $_get(memorySpace.getContext(), memorySpace);
}]>
];
+ let extraClassDeclaration = [{
+ // `PtrLikeTypeInterface` interface methods.
+ /// Returns `Type()` as this pointer type is opaque.
+ Type getElementType() const {
+ return Type();
+ }
+ /// Clones the pointer with specified memory space or returns failure
+ /// if an `elementType` was specified or if the memory space doesn't
+ /// implement `MemorySpaceAttrInterface`.
+ FailureOr<PtrLikeTypeInterface> clonePtrWith(Attribute memorySpace,
+ std::optional<Type> elementType) const {
+ if (elementType)
+ return failure();
+ if (auto ms = dyn_cast<MemorySpaceAttrInterface>(memorySpace))
+ return cast<PtrLikeTypeInterface>(get(ms));
+ return failure();
+ }
+ /// `!ptr.ptr` types are seen as ptr-like objects with no metadata.
+ bool hasPtrMetadata() const {
+ return false;
+ }
+ }];
+}
+
+def Ptr_PtrMetadata : Ptr_Type<"PtrMetadata", "ptr_metadata"> {
+ let summary = "Pointer metadata type";
+ let description = [{
+ The `ptr_metadata` type represents an opaque-view of the metadata associated
+ with a `ptr-like` object type.
+ It's an error to get a `ptr_metadata` using `ptr-like` type with no
+ metadata.
+
+ Example:
+
+ ```mlir
+ // The metadata associated with a `memref` type.
+ !ptr.ptr_metadata<memref<f32>>
+ ```
+ }];
+ let parameters = (ins "PtrLikeTypeInterface":$type);
+ let assemblyFormat = "`<` $type `>`";
+ let builders = [
+ TypeBuilderWithInferredContext<(ins
+ "PtrLikeTypeInterface":$ptrLike), [{
+ return $_get(ptrLike.getContext(), ptrLike);
+ }]>
+ ];
+ let genVerifyDecl = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index 791b95ad3559e..8ad475c41c8d3 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -17,6 +17,75 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/IR/OpAsmInterface.td"
+//===----------------------------------------------------------------------===//
+// FromPtrOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [
+ Pure, OptionalTypesMatchWith<"metadata type", "result", "metadata",
+ "PtrMetadataType::get(cast<PtrLikeTypeInterface>($_self))">
+ ]> {
+ let summary = "Casts a `!ptr.ptr` value to a ptr-like value.";
+ let description = [{
+ The `from_ptr` operation casts a `ptr` value to a ptr-like object. It's
+ important to note that:
+ - The ptr-like object cannot be a `!ptr.ptr`.
+ - The memory-space of both the `ptr` and ptr-like object must match.
+ - The cast is side-effect free.
+
+ If the ptr-like object type has metadata, then the operation expects the
+ metadata as an argument or expects that the flag `trivial_metadata` is set.
+ If `trivial_metadata` is set, then it is assumed that the metadata can be
+ reconstructed statically from the pointer-like type.
+
+ Example:
+
+ ```mlir
+ %typed_ptr = ptr.from_ptr %ptr : !ptr.ptr<0> -> !my.ptr<f32, 0>
+ %memref = ptr.from_ptr %ptr metadata %md : !ptr.ptr<0> -> memref<f32, 0>
+ %memref = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<0> -> memref<f32, 0>
+ ```
+ }];
+
+ let arguments = (ins Ptr_PtrType:$ptr,
+ Optional<Ptr_PtrMetadata>:$metadata,
+ UnitProp:$hasTrivialMetadata);
+ let results = (outs PtrLikeTypeInterface:$result);
+ let assemblyFormat = [{
+ $ptr (`metadata` $metadata^)? (`trivial_metadata` $hasTrivialMetadata^)?
+ attr-dict `:` type($ptr) `->` type($result)
+ }];
+ let hasFolder = 1;
+ let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// GetMetadataOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_GetMetadataOp : Pointer_Op<"get_metadata", [
+ Pure, TypesMatchWith<"metadata type", "ptr", "result",
+ "PtrMetadataType::get(cast<PtrLikeTypeInterface>($_self))">
+ ]> {
+ let summary = "SSA value representing pointer metadata.";
+ let description = [{
+ The `get_metadata` operation produces an opaque value that encodes the
+ metadata of the ptr-like type.
+
+ Example:
+
+ ```mlir
+ %metadata = ptr.get_metadata %memref : memref<?x?xf32>
+ ```
+ }];
+
+ let arguments = (ins PtrLikeTypeInterface:$ptr);
+ let results = (outs Ptr_PtrMetadata:$result);
+ let assemblyFormat = [{
+ $ptr attr-dict `:` type($ptr)
+ }];
+}
+
//===----------------------------------------------------------------------===//
// PtrAddOp
//===----------------------------------------------------------------------===//
@@ -52,6 +121,36 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
}];
}
+//===----------------------------------------------------------------------===//
+// ToPtrOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_ToPtrOp : Pointer_Op<"to_ptr", [Pure]> {
+ let summary = "Casts a ptr-like value to a `!ptr.ptr` value.";
+ let description = [{
+ The `to_ptr` operation casts a ptr-like object to a `!ptr.ptr`. It's
+ important to note that:
+ - The ptr-like object cannot be a `!ptr.ptr`.
+ - The memory-space of both the `ptr` and ptr-like object must match.
+ - The cast is side-effect free.
+
+ Example:
+
+ ```mlir
+ %ptr0 = ptr.to_ptr %my_ptr : !my.ptr<f32, 0> -> !ptr.ptr<0>
+ %ptr1 = ptr.to_ptr %memref : memref<f32, 0> -> !ptr.ptr<0>
+ ```
+ }];
+
+ let arguments = (ins PtrLikeTypeInterface:$ptr);
+ let results = (outs Ptr_PtrType:$result);
+ let assemblyFormat = [{
+ $ptr attr-dict `:` type($ptr) `->` type($result)
+ }];
+ let hasFolder = 1;
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// TypeOffsetOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 4a4f818b46c57..d058f6c4d9651 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -110,6 +110,55 @@ def MemRefElementTypeInterface : TypeInterface<"MemRefElementTypeInterface"> {
}];
}
+//===----------------------------------------------------------------------===//
+// PtrLikeTypeInterface
+//===----------------------------------------------------------------------===//
+
+def PtrLikeTypeInterface : TypeInterface<"PtrLikeTypeInterface"> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ A ptr-like type represents an object storing a memory address. This object
+ is constituted by:
+ - A memory address called the base pointer. The base pointer is an
+ indivisible object.
+ - Optional metadata about the pointer. For example, the size of the memory
+ region associated with the pointer.
+
+ Furthermore, all ptr-like types have two properties:
+ - The memory space associated with the address held by the pointer.
+ - An optional element type. If the element type is not specified, the
+ pointer is considered opaque.
+ }];
+ let methods = [
+ InterfaceMethod<[{
+ Returns the memory space of this ptr-like type.
+ }],
+ "::mlir::Attribute", "getMemorySpace">,
+ InterfaceMethod<[{
+ Returns the element type of this ptr-like type. Note: this method can
+ return `::mlir::Type()`, in which case the pointer is considered opaque.
+ }],
+ "::mlir::Type", "getElementType">,
+ InterfaceMethod<[{
+ Returns whether this ptr-like type has non-empty metadata.
+ }],
+ "bool", "hasPtrMetadata">,
+ InterfaceMethod<[{
+ Returns a clone of this type with the given memory space and element type,
+ or `failure` if the type cannot be cloned with the specified arguments.
+ If the pointer is opaque and `elementType` is not `std::nullopt` the
+ method will return `failure`.
+
+ If no `elementType` is provided and ptr is not opaque, the `elementType`
+ of this type is used.
+ }],
+ "::llvm::FailureOr<::mlir::PtrLikeTypeInterface>", "clonePtrWith", (ins
+ "::mlir::Attribute":$memorySpace,
+ "::std::optional<::mlir::Type>":$elementType
+ )>
+ ];
+}
+
//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..86ec5c43970b1 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -99,7 +99,9 @@ class TensorType : public Type, public ShapedType::Trait<TensorType> {
/// Note: This class attaches the ShapedType trait to act as a mixin to
/// provide many useful utility functions. This inheritance has no effect
/// on derived memref types.
-class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
+class BaseMemRefType : public Type,
+ public PtrLikeTypeInterface::Trait<BaseMemRefType>,
+ public ShapedType::Trait<BaseMemRefType> {
public:
using Type::Type;
@@ -117,6 +119,12 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const;
+ /// Clone this type with the given memory space and element type. If the
+ /// provided element type is `std::nullopt`, the current element type of the
+ /// type is used.
+ FailureOr<PtrLikeTypeInterface>
+ clonePtrWith(Attribute memorySpace, std::optional<Type> elementType) const;
+
// Make sure that base class overloads are visible.
using ShapedType::Trait<BaseMemRefType>::clone;
@@ -141,8 +149,16 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
/// New `Attribute getMemorySpace()` method should be used instead.
unsigned getMemorySpaceAsInt() const;
+ /// Returns that this ptr-like object has non-empty ptr metadata.
+ bool hasPtrMetadata() const { return true; }
+
/// Allow implicit conversion to ShapedType.
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
+
+ /// Allow implicit conversion to PtrLikeTypeInterface.
+ operator PtrLikeTypeInterface() const {
+ return llvm::cast<PtrLikeTypeInterface>(*this);
+ }
};
} // namespace mlir
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 771de01fc8d5d..9ad24e45c8315 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -562,6 +562,7 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
//===----------------------------------------------------------------------===//
def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
+ PtrLikeTypeInterface,
ShapedTypeInterface
], "BaseMemRefType"> {
let summary = "Shaped reference to a region of memory";
@@ -1143,6 +1144,7 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> {
//===----------------------------------------------------------------------===//
def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
+ PtrLikeTypeInterface,
ShapedTypeInterface
], "BaseMemRefType"> {
let summary = "Shaped reference, with unknown rank, to a region of memory";
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index c21783011452f..80fd7617c9354 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -41,6 +41,54 @@ void PtrDialect::initialize() {
>();
}
+//===----------------------------------------------------------------------===//
+// FromPtrOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
+ // Fold the pattern:
+ // %ptr = ptr.to_ptr %v : type -> ptr
+ // (%mda = ptr.get_metadata %v : type)?
+ // %val = ptr.from_ptr %ptr (metadata %mda)? : ptr -> type
+ // To:
+ // %val -> %v
+ auto toPtr = dyn_cast_or_null<ToPtrOp>(getPtr().getDefiningOp());
+ // Cannot fold if it's not a `to_ptr` op or the initial and final types are
+ // different.
+ if (!toPtr || toPtr.getPtr().getType() != getType())
+ return nullptr;
+ Value md = getMetadata();
+ if (!md)
+ return toPtr.getPtr();
+ // Fold if the metadata can be verified to be equal.
+ if (auto mdOp = dyn_cast_or_null<GetMetadataOp>(md.getDefiningOp());
+ mdOp && mdOp.getPtr() == toPtr.getPtr())
+ return toPtr.getPtr();
+ return nullptr;
+}
+
+LogicalResult FromPtrOp::verify() {
+ if (isa<PtrType>(getType()))
+ return emitError() << "the result type cannot be `!ptr.ptr`";
+ if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) {
+ return emitError()
+ << "expected the input and output to have the same memory space";
+ }
+ bool hasMD = getMetadata() != Value();
+ bool hasTrivialMD = getHasTrivialMetadata();
+ if (hasMD && hasTrivialMD) {
+ return emitError() << "expected either a metadata argument or the "
+ "`trivial_metadata` flag, not both";
+ }
+ if (getType().hasPtrMetadata() && !(hasMD || hasTrivialMD)) {
+ return emitError() << "expected either a metadata argument or the "
+ "`trivial_metadata` flag to be set";
+ }
+ if (!getType().hasPtrMetadata() && (hasMD || hasTrivialMD))
+ return emitError() << "expected no metadata specification";
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// PtrAddOp
//===----------------------------------------------------------------------===//
@@ -55,6 +103,33 @@ OpFoldResult PtrAddOp::fold(FoldAdaptor adaptor) {
return nullptr;
}
+//===----------------------------------------------------------------------===//
+// ToPtrOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult ToPtrOp::fold(FoldAdaptor adaptor) {
+ // Fold the pattern:
+ // %val = ptr.from_ptr %p (metadata ...)? : ptr -> type
+ // %ptr = ptr.to_ptr %val : type -> ptr
+ // To:
+ // %ptr -> %p
+ auto fromPtr = dyn_cast_or_null<FromPtrOp>(getPtr().getDefiningOp());
+ // Cannot fold if it's not a `from_ptr` op.
+ if (!fromPtr)
+ return nullptr;
+ return fromPtr.getPtr();
+}
+
+LogicalResult ToPtrOp::verify() {
+ if (isa<PtrType>(getPtr().getType()))
+ return emitError() << "the input value cannot be of type `!ptr.ptr`";
+ if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) {
+ return emitError()
+ << "expected the input and output to have the same memory space";
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// TypeOffsetOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
index cab9ca11e679e..7ad2a6bc4c80b 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
@@ -151,3 +151,15 @@ LogicalResult PtrType::verifyEntries(DataLayoutEntryListRef entries,
}
return success();
}
+
+//===----------------------------------------------------------------------===//
+// Pointer metadata
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+PtrMetadataType::verify(function_ref<InFlightDiagnostic()> emitError,
+ PtrLikeTypeInterface type) {
+ if (!type.hasPtrMetadata())
+ return emitError() << "the ptr-like type has no metadata";
+ return success();
+}
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 3924d082f0628..3032b68c1fdd4 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -376,6 +376,20 @@ BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
return builder;
}
+FailureOr<PtrLikeTypeInterface>
+BaseMemRefType::clonePtrWith(Attribute memorySpace,
+ std::optional<Type> elementType) const {
+ Type eTy = elementType ? *elementType : getElementType();
+ if (llvm::dyn_cast<UnrankedMemRefType>(*this))
+ return cast<PtrLikeTypeInterface>(
+ UnrankedMemRefType::get(eTy, memorySpace));
+
+ MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
+ builder.setElementType(eTy);
+ builder.setMemorySpace(memorySpace);
+ return cast<PtrLikeTypeInterface>(static_cast<MemRefType>(builder));
+}
+
MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape,
Type elementType) const {
return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
diff --git a/mlir/test/Dialect/Ptr/canonicalize.mlir b/mlir/test/Dialect/Ptr/canonicalize.mlir
index ad363d554f247..837f364242beb 100644
--- a/mlir/test/Dialect/Ptr/canonicalize.mlir
+++ b/mlir/test/Dialect/Ptr/canonicalize.mlir
@@ -13,3 +13,61 @@ func.func @zero_offset(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.gene
%res0 = ptr.ptr_add %ptr, %off : !ptr.ptr<#ptr.generic_space>, index
return %res0 : !ptr.ptr<#ptr.generic_space>
}
+
+/// Tests the the `from_ptr` folder.
+// CHECK-LABEL: @test_from_ptr_0
+// CHECK-SAME: (%[[MEM_REF:.*]]: memref<f32, #ptr.generic_space>)
+func.func @test_from_ptr_0(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
+ // CHECK-NOT: ptr.to_ptr
+ // CHECK-NOT: ptr.get_metadata
+ // CHECK-NOT: ptr.from_ptr
+ // CHECK: return %[[MEM_REF]]
+ %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+ %mda = ptr.get_metadata %mr : memref<f32, #ptr.generic_space>
+ %res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+ return %res : memref<f32, #ptr.generic_space>
+}
+
+// CHECK-LABEL: @test_from_ptr_1
+// CHECK-SAME: (%[[MEM_REF:.*]]: memref<f32, #ptr.generic_space>)
+func.func @test_from_ptr_1(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
+ // CHECK-NOT: ptr.to_ptr
+ // CHECK-NOT: ptr.from_ptr
+ // CHECK: return %[[MEM_REF]]
+ %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+ %res = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+ return %res : memref<f32, #ptr.generic_space>
+}
+
+/// Check that the ops cannot be folded because the metadata cannot be guaranteed to be the same.
+// CHECK-LABEL: @test_from_ptr_2
+func.func @test_from_ptr_2(%mr: memref<f32, #ptr.generic_space>, %md: !ptr.ptr_metadata<memref<f32, #ptr.generic_space>>) -> memref<f32, #ptr.generic_space> {
+ // CHECK: ptr.to_ptr
+ // CHECK: ptr.from_ptr
+ %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+ %res = ptr.from_ptr %ptr metadata %md : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+ return %res : memref<f32, #ptr.generic_space>
+}
+
+/// Tests the the `to_ptr` folder.
+// CHECK-LABEL: @test_to_ptr_0
+// CHECK-SAME: (%[[PTR:.*]]: !ptr.ptr<#ptr.generic_space>
+func.func @test_to_ptr_0(%ptr: !ptr.ptr<#ptr.generic_space>, %md: !ptr.ptr_metadata<memref<f32, #ptr.generic_space>>) -> !ptr.ptr...
[truncated]
|
return nullptr; | ||
Value md = getMetadata(); | ||
if (!md) | ||
return toPtr.getPtr(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure about this condition.
It seems to imply that all the metadata are statically encoded in the type. Couldn't there be a type where some runtime metadata are default initialized? In which case round-tripping to a ptr
would be erasing them and folding would undo this erasing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The verifier of from_ptr
will check whether PtrLikeType
expects non-trivial metadata, and if it does, it always asks for it. The only way to not provide the metadata is if trivial_metadata
is set, which tells the compiler it's safe to assume that all info is statically encoded in the type.
Therefore in this case, it is known from the verifier that the op can be folded.
If the ptr-like object type has metadata, then the operation expects the | ||
metadata as an argument or expects that the flag `trivial_metadata` is set. | ||
If `trivial_metadata` is set, then it is assumed that the metadata can be | ||
reconstructed statically from the pointer-like type. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need the trivial_metadata
flag instead of just assuming it when metadata isn't provided?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's suppose that we have %v0: memref<f32, 0>
in which alignedPtr != allocatedPtr
. Then in the following example %v1 != %v0
:
%p = to.ptr %v : memref<f32, 0> -> !ptr.ptr<0>
%v1 = from_ptr %p : !ptr.ptr<0> -> memref<f32, 0>
So in that case the cast sequence is a lossy conversion, which might be fine for the user and it should be possible to do. However, I think it makes for a less buggy experience to explicitly state that the user is ignoring the metadata.
Currently, the above IR will generate a verification error saying that from_ptr
requires metadata because the type says so.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
owever, I think it makes for a less buggy experience to explicitly state that the user is ignoring the metadata.
For a programming language, I guess so, but for the IR that seems like redundant information here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It also tells the compiler that the folding is valid, see my comment in the folder of FromPtrOp
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It also tells the compiler that the folding is valid
I don't quite follow: if the absence of the flag implies trivial_metadata, then you can fold just as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not an implementation impossibility of the folding logic, but rather an affirmation to the compiler that the user guarantees it's safe to do it, even when the cast is lossy like with memref
.
My main point with this flag is to make it explicit that there's something potentially hazardous taking place. Otherwise an user could end up with something like:
%v1 = from_ptr %p : !ptr.ptr<0> -> memref<f32, 0>
memref.dealloc %v1 : memref<f32, 0>
Which is potentially dangerous, and no warnings were raised. Instead having:
%v1 = from_ptr %p trivial_metadata : !ptr.ptr<0> -> memref<f32, 0>
memref.dealloc %v1 : memref<f32, 0>
At least places the burden on the user of the op.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Again: this is an IR, not a programming language: people aren't writing textual IR by hand.
We don't need to have redundant information encoded on an operation.
(even if you want to have something printed out, this does not need to be reflected as an extra property).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have removed the trivial_metadata
flag.
Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
This patch adds the
PtrLikeTypeInterface
type interface to identify pointer-like types. This interface is defined as:This patch adds this interface to
!ptr.ptr
and thememref
type.Furthermore, this patch adds necessary ops and type to handle casting between
!ptr.ptr
and ptr-like types.First, it defines the
!ptr.ptr_metadata
type. An opaque type to represent the metadata of a ptr-like type. The rationale behind adding this type, is that at high-level the metadata of a type likememref
cannot be specified, as its structure is tied to its lowering.The
ptr.get_metadata
operation was added to extract the opaque pointer metadata. The concrete structure of the metadata is only known when the op is lowered.Finally, this patch adds the
ptr.from_ptr
andptr.to_ptr
operations. Allowing to cast back and forth between!ptr.ptr
and ptr-like types.It's future work to replace and remove the
bare-ptr-convention
through the use of these ops.