Skip to content

[mlir][core|ptr] Add PtrLikeTypeInterface and casting ops to the ptr dialect #137469

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class Ptr_Type<string name, string typeMnemonic, list<Trait> traits = []>

def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
MemRefElementTypeInterface,
PtrLikeTypeInterface,
VectorElementTypeInterface,
DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
"areCompatible", "getIndexBitwidth", "verifyEntries",
Expand All @@ -63,6 +64,54 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
return $_get(memorySpace.getContext(), memorySpace);
}]>
];
let extraClassDeclaration = [{
// `PtrLikeTypeInterface` interface methods.
/// Returns `Type()` as this pointer type is opaque.
Type getElementType() const {
return Type();
}
/// Clones the pointer with specified memory space or returns failure
/// if an `elementType` was specified or if the memory space doesn't
/// implement `MemorySpaceAttrInterface`.
FailureOr<PtrLikeTypeInterface> clonePtrWith(Attribute memorySpace,
std::optional<Type> elementType) const {
if (elementType)
return failure();
if (auto ms = dyn_cast<MemorySpaceAttrInterface>(memorySpace))
return cast<PtrLikeTypeInterface>(get(ms));
return failure();
}
/// `!ptr.ptr` types are seen as ptr-like objects with no metadata.
bool hasPtrMetadata() const {
return false;
}
}];
}

def Ptr_PtrMetadata : Ptr_Type<"PtrMetadata", "ptr_metadata"> {
let summary = "Pointer metadata type";
let description = [{
The `ptr_metadata` type represents an opaque-view of the metadata associated
with a `ptr-like` object type.
It's an error to get a `ptr_metadata` using `ptr-like` type with no
metadata.

Example:

```mlir
// The metadata associated with a `memref` type.
!ptr.ptr_metadata<memref<f32>>
```
}];
let parameters = (ins "PtrLikeTypeInterface":$type);
let assemblyFormat = "`<` $type `>`";
let builders = [
TypeBuilderWithInferredContext<(ins
"PtrLikeTypeInterface":$ptrLike), [{
return $_get(ptrLike.getContext(), ptrLike);
}]>
];
let genVerifyDecl = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
96 changes: 96 additions & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,72 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/IR/OpAsmInterface.td"

//===----------------------------------------------------------------------===//
// FromPtrOp
//===----------------------------------------------------------------------===//

def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [
Pure, OptionalTypesMatchWith<"metadata type", "result", "metadata",
"PtrMetadataType::get(cast<PtrLikeTypeInterface>($_self))">
]> {
let summary = "Casts a `!ptr.ptr` value to a ptr-like value.";
let description = [{
The `from_ptr` operation casts a `ptr` value to a ptr-like object. It's
important to note that:
- The ptr-like object cannot be a `!ptr.ptr`.
- The memory-space of both the `ptr` and ptr-like object must match.
- The cast is Pure (no UB and side-effect free).

The optional `metadata` operand exists to provide any ptr-like metadata
that might be required to perform the cast.

Example:

```mlir
%typed_ptr = ptr.from_ptr %ptr : !ptr.ptr<0> -> !my.ptr<f32, 0>
%memref = ptr.from_ptr %ptr metadata %md : !ptr.ptr<0> -> memref<f32, 0>

// Cast the `%ptr` to a memref without utilizing metadata.
%memref = ptr.from_ptr %ptr : !ptr.ptr<0> -> memref<f32, 0>
```
}];

let arguments = (ins Ptr_PtrType:$ptr, Optional<Ptr_PtrMetadata>:$metadata);
let results = (outs PtrLikeTypeInterface:$result);
let assemblyFormat = [{
$ptr (`metadata` $metadata^)? attr-dict `:` type($ptr) `->` type($result)
}];
let hasFolder = 1;
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// GetMetadataOp
//===----------------------------------------------------------------------===//

def Ptr_GetMetadataOp : Pointer_Op<"get_metadata", [
Pure, TypesMatchWith<"metadata type", "ptr", "result",
"PtrMetadataType::get(cast<PtrLikeTypeInterface>($_self))">
]> {
let summary = "SSA value representing pointer metadata.";
let description = [{
The `get_metadata` operation produces an opaque value that encodes the
metadata of the ptr-like type.

Example:

```mlir
%metadata = ptr.get_metadata %memref : memref<?x?xf32>
```
}];

let arguments = (ins PtrLikeTypeInterface:$ptr);
let results = (outs Ptr_PtrMetadata:$result);
let assemblyFormat = [{
$ptr attr-dict `:` type($ptr)
}];
}

//===----------------------------------------------------------------------===//
// PtrAddOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -52,6 +118,36 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
}];
}

//===----------------------------------------------------------------------===//
// ToPtrOp
//===----------------------------------------------------------------------===//

def Ptr_ToPtrOp : Pointer_Op<"to_ptr", [Pure]> {
let summary = "Casts a ptr-like value to a `!ptr.ptr` value.";
let description = [{
The `to_ptr` operation casts a ptr-like object to a `!ptr.ptr`. It's
important to note that:
- The ptr-like object cannot be a `!ptr.ptr`.
- The memory-space of both the `ptr` and ptr-like object must match.
- The cast is side-effect free.

Example:

```mlir
%ptr0 = ptr.to_ptr %my_ptr : !my.ptr<f32, 0> -> !ptr.ptr<0>
%ptr1 = ptr.to_ptr %memref : memref<f32, 0> -> !ptr.ptr<0>
```
}];

let arguments = (ins PtrLikeTypeInterface:$ptr);
let results = (outs Ptr_PtrType:$result);
let assemblyFormat = [{
$ptr attr-dict `:` type($ptr) `->` type($result)
}];
let hasFolder = 1;
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// TypeOffsetOp
//===----------------------------------------------------------------------===//
Expand Down
49 changes: 49 additions & 0 deletions mlir/include/mlir/IR/BuiltinTypeInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,55 @@ def MemRefElementTypeInterface : TypeInterface<"MemRefElementTypeInterface"> {
}];
}

//===----------------------------------------------------------------------===//
// PtrLikeTypeInterface
//===----------------------------------------------------------------------===//

def PtrLikeTypeInterface : TypeInterface<"PtrLikeTypeInterface"> {
let cppNamespace = "::mlir";
let description = [{
A ptr-like type represents an object storing a memory address. This object
is constituted by:
- A memory address called the base pointer. The base pointer is an
indivisible object.
- Optional metadata about the pointer. For example, the size of the memory
region associated with the pointer.

Furthermore, all ptr-like types have two properties:
- The memory space associated with the address held by the pointer.
- An optional element type. If the element type is not specified, the
pointer is considered opaque.
}];
let methods = [
InterfaceMethod<[{
Returns the memory space of this ptr-like type.
}],
"::mlir::Attribute", "getMemorySpace">,
InterfaceMethod<[{
Returns the element type of this ptr-like type. Note: this method can
return `::mlir::Type()`, in which case the pointer is considered opaque.
}],
"::mlir::Type", "getElementType">,
InterfaceMethod<[{
Returns whether this ptr-like type has non-empty metadata.
}],
"bool", "hasPtrMetadata">,
InterfaceMethod<[{
Returns a clone of this type with the given memory space and element type,
or `failure` if the type cannot be cloned with the specified arguments.
If the pointer is opaque and `elementType` is not `std::nullopt` the
method will return `failure`.

If no `elementType` is provided and ptr is not opaque, the `elementType`
of this type is used.
}],
"::llvm::FailureOr<::mlir::PtrLikeTypeInterface>", "clonePtrWith", (ins
"::mlir::Attribute":$memorySpace,
"::std::optional<::mlir::Type>":$elementType
)>
];
}

//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
Expand Down
18 changes: 17 additions & 1 deletion mlir/include/mlir/IR/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ class TensorType : public Type, public ShapedType::Trait<TensorType> {
/// Note: This class attaches the ShapedType trait to act as a mixin to
/// provide many useful utility functions. This inheritance has no effect
/// on derived memref types.
class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
class BaseMemRefType : public Type,
public PtrLikeTypeInterface::Trait<BaseMemRefType>,
public ShapedType::Trait<BaseMemRefType> {
public:
using Type::Type;

Expand All @@ -117,6 +119,12 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const;

/// Clone this type with the given memory space and element type. If the
/// provided element type is `std::nullopt`, the current element type of the
/// type is used.
FailureOr<PtrLikeTypeInterface>
clonePtrWith(Attribute memorySpace, std::optional<Type> elementType) const;

// Make sure that base class overloads are visible.
using ShapedType::Trait<BaseMemRefType>::clone;

Expand All @@ -141,8 +149,16 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
/// New `Attribute getMemorySpace()` method should be used instead.
unsigned getMemorySpaceAsInt() const;

/// Returns that this ptr-like object has non-empty ptr metadata.
bool hasPtrMetadata() const { return true; }

/// Allow implicit conversion to ShapedType.
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }

/// Allow implicit conversion to PtrLikeTypeInterface.
operator PtrLikeTypeInterface() const {
return llvm::cast<PtrLikeTypeInterface>(*this);
}
};

} // namespace mlir
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/IR/BuiltinTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,7 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
//===----------------------------------------------------------------------===//

def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
PtrLikeTypeInterface,
ShapedTypeInterface
], "BaseMemRefType"> {
let summary = "Shaped reference to a region of memory";
Expand Down Expand Up @@ -1143,6 +1144,7 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> {
//===----------------------------------------------------------------------===//

def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
PtrLikeTypeInterface,
ShapedTypeInterface
], "BaseMemRefType"> {
let summary = "Shaped reference, with unknown rank, to a region of memory";
Expand Down
78 changes: 78 additions & 0 deletions mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,50 @@ void PtrDialect::initialize() {
>();
}

//===----------------------------------------------------------------------===//
// FromPtrOp
//===----------------------------------------------------------------------===//

OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
// Fold the pattern:
// %ptr = ptr.to_ptr %v : type -> ptr
// (%mda = ptr.get_metadata %v : type)?
// %val = ptr.from_ptr %ptr (metadata %mda)? : ptr -> type
// To:
// %val -> %v
Value ptrLike;
FromPtrOp fromPtr = *this;
while (fromPtr != nullptr) {
auto toPtr = dyn_cast_or_null<ToPtrOp>(fromPtr.getPtr().getDefiningOp());
// Cannot fold if it's not a `to_ptr` op or the initial and final types are
// different.
if (!toPtr || toPtr.getPtr().getType() != fromPtr.getType())
return ptrLike;
Value md = fromPtr.getMetadata();
// If there's no metadata in the op fold the op.
if (!md)
ptrLike = toPtr.getPtr();
// Fold if the metadata can be verified to be equal.
else if (auto mdOp = dyn_cast_or_null<GetMetadataOp>(md.getDefiningOp());
mdOp && mdOp.getPtr() == toPtr.getPtr())
ptrLike = toPtr.getPtr();
// Check for a sequence of casts.
fromPtr = dyn_cast_or_null<FromPtrOp>(ptrLike ? ptrLike.getDefiningOp()
: nullptr);
}
return ptrLike;
}

LogicalResult FromPtrOp::verify() {
if (isa<PtrType>(getType()))
return emitError() << "the result type cannot be `!ptr.ptr`";
if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) {
return emitError()
<< "expected the input and output to have the same memory space";
}
return success();
}

//===----------------------------------------------------------------------===//
// PtrAddOp
//===----------------------------------------------------------------------===//
Expand All @@ -55,6 +99,40 @@ OpFoldResult PtrAddOp::fold(FoldAdaptor adaptor) {
return nullptr;
}

//===----------------------------------------------------------------------===//
// ToPtrOp
//===----------------------------------------------------------------------===//

OpFoldResult ToPtrOp::fold(FoldAdaptor adaptor) {
// Fold the pattern:
// %val = ptr.from_ptr %p (metadata ...)? : ptr -> type
// %ptr = ptr.to_ptr %val : type -> ptr
// To:
// %ptr -> %p
Value ptr;
ToPtrOp toPtr = *this;
while (toPtr != nullptr) {
auto fromPtr = dyn_cast_or_null<FromPtrOp>(toPtr.getPtr().getDefiningOp());
// Cannot fold if it's not a `from_ptr` op.
if (!fromPtr)
return ptr;
ptr = fromPtr.getPtr();
// Check for chains of casts.
toPtr = dyn_cast_or_null<ToPtrOp>(ptr.getDefiningOp());
}
return ptr;
}

LogicalResult ToPtrOp::verify() {
if (isa<PtrType>(getPtr().getType()))
return emitError() << "the input value cannot be of type `!ptr.ptr`";
if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) {
return emitError()
<< "expected the input and output to have the same memory space";
}
return success();
}

//===----------------------------------------------------------------------===//
// TypeOffsetOp
//===----------------------------------------------------------------------===//
Expand Down
12 changes: 12 additions & 0 deletions mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,15 @@ LogicalResult PtrType::verifyEntries(DataLayoutEntryListRef entries,
}
return success();
}

//===----------------------------------------------------------------------===//
// Pointer metadata
//===----------------------------------------------------------------------===//

LogicalResult
PtrMetadataType::verify(function_ref<InFlightDiagnostic()> emitError,
PtrLikeTypeInterface type) {
if (!type.hasPtrMetadata())
return emitError() << "the ptr-like type has no metadata";
return success();
}
Loading