Skip to content

[mlir][core|ptr] Add PtrLikeTypeInterface and casting ops to the ptr dialect #137469

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

fabianmcg
Copy link
Contributor

@fabianmcg fabianmcg commented Apr 26, 2025

This patch adds the PtrLikeTypeInterface type interface to identify pointer-like types. This interface is defined as:

A ptr-like type represents an object storing a memory address. This object
is constituted by:
- A memory address called the base pointer. The base pointer is an
  indivisible object.
- Optional metadata about the pointer. For example, the size of the  memory
  region associated with the pointer.

Furthermore, all ptr-like types have two properties:
- The memory space associated with the address held by the pointer.
- An optional element type. If the element type is not specified, the
  pointer is considered opaque.

This patch adds this interface to !ptr.ptr and the memref type.

Furthermore, this patch adds necessary ops and type to handle casting between !ptr.ptr and ptr-like types.

First, it defines the !ptr.ptr_metadata type. An opaque type to represent the metadata of a ptr-like type. The rationale behind adding this type, is that at high-level the metadata of a type like memref cannot be specified, as its structure is tied to its lowering.

The ptr.get_metadata operation was added to extract the opaque pointer metadata. The concrete structure of the metadata is only known when the op is lowered.

Finally, this patch adds the ptr.from_ptr and ptr.to_ptr operations. Allowing to cast back and forth between !ptr.ptr and ptr-like types.

func.func @func(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
  %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
  %mda = ptr.get_metadata %mr : memref<f32, #ptr.generic_space>
  %res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
  return %res : memref<f32, #ptr.generic_space>
}

It's future work to replace and remove the bare-ptr-convention through the use of these ops.

…tr` dialect

This patch adds the `PtrLikeTypeInterface` type interface to identify pointer-like types.
This interface is defined as:

```
A ptr-like type represents an object storing a memory address. This object
is constituted by:
- A memory address called the base pointer. The base pointer is an
  indivisible object.
- Optional metadata about the pointer. For example, the size of the  memory
  region associated with the pointer.

Furthermore, all ptr-like types have two properties:
- The memory space associated with the address held by the pointer.
- An optional element type. If the element type is not specified, the
  pointer is considered opaque.
```

This patch adds this interface to `!ptr.ptr` and the `memref` type.

Furthermore, this patch adds necessary ops and type to handle casting between `!ptr.ptr`
and ptr-like types.

First, it defines the `!ptr.ptr_metadata` type. An opaque type to represent the metadata
of a ptr-like type. The rationale behind adding this type, is that at high-level the
metadata of a type like `memref` cannot be specified, as its structure is tied to its
lowering.

The `ptr.get_metadata` operation was added to extract the opaque pointer metadata. The
concrete structure of the metadata is only known when the op is lowered.

Finally, this patch adds the `ptr.from_ptr` and `ptr.to_ptr` operations. Allowing to
cast back and forth between `!ptr.ptr` and ptr-liker types.

```mlir
func.func @func(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
  %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
  %mda = ptr.get_metadata %mr : memref<f32, #ptr.generic_space>
  %res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
  return %res : memref<f32, #ptr.generic_space>
}
```
Copy link

⚠️ We detected that you are using a GitHub private e-mail address to contribute to the repo.
Please turn off Keep my email addresses private setting in your account.
See LLVM Discourse for more information.

@fabianmcg fabianmcg marked this pull request as ready for review April 26, 2025 19:42
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir mlir:ods labels Apr 26, 2025
@llvmbot
Copy link
Member

llvmbot commented Apr 26, 2025

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-ods

@llvm/pr-subscribers-mlir

Author: Fabian Mora (fabianmcg)

Changes

This patch adds the PtrLikeTypeInterface type interface to identify pointer-like types. This interface is defined as:

A ptr-like type represents an object storing a memory address. This object
is constituted by:
- A memory address called the base pointer. The base pointer is an
  indivisible object.
- Optional metadata about the pointer. For example, the size of the  memory
  region associated with the pointer.

Furthermore, all ptr-like types have two properties:
- The memory space associated with the address held by the pointer.
- An optional element type. If the element type is not specified, the
  pointer is considered opaque.

This patch adds this interface to !ptr.ptr and the memref type.

Furthermore, this patch adds necessary ops and type to handle casting between !ptr.ptr and ptr-like types.

First, it defines the !ptr.ptr_metadata type. An opaque type to represent the metadata of a ptr-like type. The rationale behind adding this type, is that at high-level the metadata of a type like memref cannot be specified, as its structure is tied to its lowering.

The ptr.get_metadata operation was added to extract the opaque pointer metadata. The concrete structure of the metadata is only known when the op is lowered.

Finally, this patch adds the ptr.from_ptr and ptr.to_ptr operations. Allowing to cast back and forth between !ptr.ptr and ptr-liker types.

func.func @<!-- -->func(%mr: memref&lt;f32, #ptr.generic_space&gt;) -&gt; memref&lt;f32, #ptr.generic_space&gt; {
  %ptr = ptr.to_ptr %mr : memref&lt;f32, #ptr.generic_space&gt; -&gt; !ptr.ptr&lt;#ptr.generic_space&gt;
  %mda = ptr.get_metadata %mr : memref&lt;f32, #ptr.generic_space&gt;
  %res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr&lt;#ptr.generic_space&gt; -&gt; memref&lt;f32, #ptr.generic_space&gt;
  return %res : memref&lt;f32, #ptr.generic_space&gt;
}

It's future work to replace and remove the bare-ptr-convetion through the use of this ops.


Patch is 23.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/137469.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td (+49)
  • (modified) mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td (+99)
  • (modified) mlir/include/mlir/IR/BuiltinTypeInterfaces.td (+49)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (+17-1)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+2)
  • (modified) mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp (+75)
  • (modified) mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp (+12)
  • (modified) mlir/lib/IR/BuiltinTypes.cpp (+14)
  • (modified) mlir/test/Dialect/Ptr/canonicalize.mlir (+58)
  • (added) mlir/test/Dialect/Ptr/invalid.mlir (+33)
  • (modified) mlir/test/Dialect/Ptr/ops.mlir (+10)
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
index 73b2a0857cef3..6631b338db199 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
@@ -37,6 +37,7 @@ class Ptr_Type<string name, string typeMnemonic, list<Trait> traits = []>
 
 def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
     MemRefElementTypeInterface,
+    PtrLikeTypeInterface,
     VectorElementTypeInterface,
     DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
       "areCompatible", "getIndexBitwidth", "verifyEntries",
@@ -63,6 +64,54 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
       return $_get(memorySpace.getContext(), memorySpace);
     }]>
   ];
+  let extraClassDeclaration = [{
+    // `PtrLikeTypeInterface` interface methods.
+    /// Returns `Type()` as this pointer type is opaque.
+    Type getElementType() const {
+      return Type();
+    }
+    /// Clones the pointer with specified memory space or returns failure
+    /// if an `elementType` was specified or if the memory space doesn't
+    /// implement `MemorySpaceAttrInterface`.
+    FailureOr<PtrLikeTypeInterface> clonePtrWith(Attribute memorySpace,
+      std::optional<Type> elementType) const {
+      if (elementType)
+        return failure();
+      if (auto ms = dyn_cast<MemorySpaceAttrInterface>(memorySpace))
+        return cast<PtrLikeTypeInterface>(get(ms));
+      return failure();
+    }
+    /// `!ptr.ptr` types are seen as ptr-like objects with no metadata.
+    bool hasPtrMetadata() const {
+      return false;
+    }
+  }];
+}
+
+def Ptr_PtrMetadata : Ptr_Type<"PtrMetadata", "ptr_metadata"> {
+  let summary = "Pointer metadata type";
+  let description = [{
+    The `ptr_metadata` type represents an opaque-view of the metadata associated
+    with a `ptr-like` object type.
+    It's an error to get a `ptr_metadata` using `ptr-like` type with no
+    metadata.
+
+    Example:
+
+    ```mlir
+    // The metadata associated with a `memref` type.
+    !ptr.ptr_metadata<memref<f32>>
+    ```
+  }];
+  let parameters = (ins "PtrLikeTypeInterface":$type);
+  let assemblyFormat = "`<` $type `>`";
+  let builders = [
+    TypeBuilderWithInferredContext<(ins
+      "PtrLikeTypeInterface":$ptrLike), [{
+      return $_get(ptrLike.getContext(), ptrLike);
+    }]>
+  ];
+  let genVerifyDecl = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index 791b95ad3559e..8ad475c41c8d3 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -17,6 +17,75 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
 include "mlir/IR/OpAsmInterface.td"
 
+//===----------------------------------------------------------------------===//
+// FromPtrOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [
+    Pure, OptionalTypesMatchWith<"metadata type", "result", "metadata",
+            "PtrMetadataType::get(cast<PtrLikeTypeInterface>($_self))">
+  ]> {
+  let summary = "Casts a `!ptr.ptr` value to a ptr-like value.";
+  let description = [{
+    The `from_ptr` operation casts a `ptr` value to a ptr-like object. It's
+    important to note that:
+    - The ptr-like object cannot be a `!ptr.ptr`.
+    - The memory-space of both the `ptr` and ptr-like object must match.
+    - The cast is side-effect free.
+
+    If the ptr-like object type has metadata, then the operation expects the
+    metadata as an argument or expects that the flag `trivial_metadata` is set.
+    If `trivial_metadata` is set, then it is assumed that the metadata can be
+    reconstructed statically from the pointer-like type.
+
+    Example:
+
+    ```mlir
+    %typed_ptr = ptr.from_ptr %ptr : !ptr.ptr<0> -> !my.ptr<f32, 0>
+    %memref = ptr.from_ptr %ptr metadata %md : !ptr.ptr<0> -> memref<f32, 0>
+    %memref = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<0> -> memref<f32, 0>
+    ```
+  }];
+
+  let arguments = (ins Ptr_PtrType:$ptr,
+                       Optional<Ptr_PtrMetadata>:$metadata,
+                       UnitProp:$hasTrivialMetadata);
+  let results = (outs PtrLikeTypeInterface:$result);
+  let assemblyFormat = [{
+    $ptr (`metadata` $metadata^)? (`trivial_metadata` $hasTrivialMetadata^)?
+    attr-dict `:` type($ptr) `->` type($result)
+  }];
+  let hasFolder = 1;
+  let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// GetMetadataOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_GetMetadataOp : Pointer_Op<"get_metadata", [
+    Pure, TypesMatchWith<"metadata type", "ptr", "result",
+            "PtrMetadataType::get(cast<PtrLikeTypeInterface>($_self))">
+  ]> {
+  let summary = "SSA value representing pointer metadata.";
+  let description = [{
+    The `get_metadata` operation produces an opaque value that encodes the
+    metadata of the ptr-like type.
+
+    Example:
+
+    ```mlir
+    %metadata = ptr.get_metadata %memref : memref<?x?xf32>
+    ```
+  }];
+
+  let arguments = (ins PtrLikeTypeInterface:$ptr);
+  let results = (outs Ptr_PtrMetadata:$result);
+  let assemblyFormat = [{
+    $ptr attr-dict `:` type($ptr)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // PtrAddOp
 //===----------------------------------------------------------------------===//
@@ -52,6 +121,36 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// ToPtrOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_ToPtrOp : Pointer_Op<"to_ptr", [Pure]> {
+  let summary = "Casts a ptr-like value to a `!ptr.ptr` value.";
+  let description = [{
+    The `to_ptr` operation casts a ptr-like object to a `!ptr.ptr`. It's
+    important to note that:
+    - The ptr-like object cannot be a `!ptr.ptr`.
+    - The memory-space of both the `ptr` and ptr-like object must match.
+    - The cast is side-effect free.
+
+    Example:
+
+    ```mlir
+    %ptr0 = ptr.to_ptr %my_ptr : !my.ptr<f32, 0> -> !ptr.ptr<0>
+    %ptr1 = ptr.to_ptr %memref : memref<f32, 0> -> !ptr.ptr<0>
+    ```
+  }];
+
+  let arguments = (ins PtrLikeTypeInterface:$ptr);
+  let results = (outs Ptr_PtrType:$result);
+  let assemblyFormat = [{
+    $ptr attr-dict `:` type($ptr) `->` type($result)
+  }];
+  let hasFolder = 1;
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // TypeOffsetOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 4a4f818b46c57..d058f6c4d9651 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -110,6 +110,55 @@ def MemRefElementTypeInterface : TypeInterface<"MemRefElementTypeInterface"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// PtrLikeTypeInterface
+//===----------------------------------------------------------------------===//
+
+def PtrLikeTypeInterface : TypeInterface<"PtrLikeTypeInterface"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    A ptr-like type represents an object storing a memory address. This object
+    is constituted by:
+    - A memory address called the base pointer. The base pointer is an 
+      indivisible object.
+    - Optional metadata about the pointer. For example, the size of the  memory
+      region associated with the pointer.
+
+    Furthermore, all ptr-like types have two properties:
+    - The memory space associated with the address held by the pointer.
+    - An optional element type. If the element type is not specified, the
+      pointer is considered opaque.
+  }];
+  let methods = [
+    InterfaceMethod<[{
+      Returns the memory space of this ptr-like type.
+    }],
+    "::mlir::Attribute", "getMemorySpace">,
+    InterfaceMethod<[{
+      Returns the element type of this ptr-like type. Note: this method can
+      return `::mlir::Type()`, in which case the pointer is considered opaque.
+    }],
+    "::mlir::Type", "getElementType">,
+    InterfaceMethod<[{
+      Returns whether this ptr-like type has non-empty metadata.
+    }],
+    "bool", "hasPtrMetadata">,
+    InterfaceMethod<[{
+      Returns a clone of this type with the given memory space and element type,
+      or `failure` if the type cannot be cloned with the specified arguments.
+      If the pointer is opaque and `elementType` is not `std::nullopt` the
+      method will return `failure`.
+
+      If no `elementType` is provided and ptr is not opaque, the `elementType`
+      of this type is used.
+    }],
+    "::llvm::FailureOr<::mlir::PtrLikeTypeInterface>", "clonePtrWith", (ins
+      "::mlir::Attribute":$memorySpace,
+      "::std::optional<::mlir::Type>":$elementType
+    )>
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // ShapedType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..86ec5c43970b1 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -99,7 +99,9 @@ class TensorType : public Type, public ShapedType::Trait<TensorType> {
 /// Note: This class attaches the ShapedType trait to act as a mixin to
 ///       provide many useful utility functions. This inheritance has no effect
 ///       on derived memref types.
-class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
+class BaseMemRefType : public Type,
+                       public PtrLikeTypeInterface::Trait<BaseMemRefType>,
+                       public ShapedType::Trait<BaseMemRefType> {
 public:
   using Type::Type;
 
@@ -117,6 +119,12 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
   BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
                            Type elementType) const;
 
+  /// Clone this type with the given memory space and element type. If the
+  /// provided element type is `std::nullopt`, the current element type of the
+  /// type is used.
+  FailureOr<PtrLikeTypeInterface>
+  clonePtrWith(Attribute memorySpace, std::optional<Type> elementType) const;
+
   // Make sure that base class overloads are visible.
   using ShapedType::Trait<BaseMemRefType>::clone;
 
@@ -141,8 +149,16 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
   /// New `Attribute getMemorySpace()` method should be used instead.
   unsigned getMemorySpaceAsInt() const;
 
+  /// Returns that this ptr-like object has non-empty ptr metadata.
+  bool hasPtrMetadata() const { return true; }
+
   /// Allow implicit conversion to ShapedType.
   operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
+
+  /// Allow implicit conversion to PtrLikeTypeInterface.
+  operator PtrLikeTypeInterface() const {
+    return llvm::cast<PtrLikeTypeInterface>(*this);
+  }
 };
 
 } // namespace mlir
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 771de01fc8d5d..9ad24e45c8315 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -562,6 +562,7 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
 //===----------------------------------------------------------------------===//
 
 def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
+    PtrLikeTypeInterface,
     ShapedTypeInterface
   ], "BaseMemRefType"> {
   let summary = "Shaped reference to a region of memory";
@@ -1143,6 +1144,7 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> {
 //===----------------------------------------------------------------------===//
 
 def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
+    PtrLikeTypeInterface,
     ShapedTypeInterface
   ], "BaseMemRefType"> {
   let summary = "Shaped reference, with unknown rank, to a region of memory";
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index c21783011452f..80fd7617c9354 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -41,6 +41,54 @@ void PtrDialect::initialize() {
       >();
 }
 
+//===----------------------------------------------------------------------===//
+// FromPtrOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
+  // Fold the pattern:
+  // %ptr = ptr.to_ptr %v : type -> ptr
+  // (%mda = ptr.get_metadata %v : type)?
+  // %val = ptr.from_ptr %ptr (metadata %mda)? : ptr -> type
+  // To:
+  // %val -> %v
+  auto toPtr = dyn_cast_or_null<ToPtrOp>(getPtr().getDefiningOp());
+  // Cannot fold if it's not a `to_ptr` op or the initial and final types are
+  // different.
+  if (!toPtr || toPtr.getPtr().getType() != getType())
+    return nullptr;
+  Value md = getMetadata();
+  if (!md)
+    return toPtr.getPtr();
+  // Fold if the metadata can be verified to be equal.
+  if (auto mdOp = dyn_cast_or_null<GetMetadataOp>(md.getDefiningOp());
+      mdOp && mdOp.getPtr() == toPtr.getPtr())
+    return toPtr.getPtr();
+  return nullptr;
+}
+
+LogicalResult FromPtrOp::verify() {
+  if (isa<PtrType>(getType()))
+    return emitError() << "the result type cannot be `!ptr.ptr`";
+  if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) {
+    return emitError()
+           << "expected the input and output to have the same memory space";
+  }
+  bool hasMD = getMetadata() != Value();
+  bool hasTrivialMD = getHasTrivialMetadata();
+  if (hasMD && hasTrivialMD) {
+    return emitError() << "expected either a metadata argument or the "
+                          "`trivial_metadata` flag, not both";
+  }
+  if (getType().hasPtrMetadata() && !(hasMD || hasTrivialMD)) {
+    return emitError() << "expected either a metadata argument or the "
+                          "`trivial_metadata` flag to be set";
+  }
+  if (!getType().hasPtrMetadata() && (hasMD || hasTrivialMD))
+    return emitError() << "expected no metadata specification";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // PtrAddOp
 //===----------------------------------------------------------------------===//
@@ -55,6 +103,33 @@ OpFoldResult PtrAddOp::fold(FoldAdaptor adaptor) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// ToPtrOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult ToPtrOp::fold(FoldAdaptor adaptor) {
+  // Fold the pattern:
+  // %val = ptr.from_ptr %p (metadata ...)? : ptr -> type
+  // %ptr = ptr.to_ptr %val : type -> ptr
+  // To:
+  // %ptr -> %p
+  auto fromPtr = dyn_cast_or_null<FromPtrOp>(getPtr().getDefiningOp());
+  // Cannot fold if it's not a `from_ptr` op.
+  if (!fromPtr)
+    return nullptr;
+  return fromPtr.getPtr();
+}
+
+LogicalResult ToPtrOp::verify() {
+  if (isa<PtrType>(getPtr().getType()))
+    return emitError() << "the input value cannot be of type `!ptr.ptr`";
+  if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) {
+    return emitError()
+           << "expected the input and output to have the same memory space";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TypeOffsetOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
index cab9ca11e679e..7ad2a6bc4c80b 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
@@ -151,3 +151,15 @@ LogicalResult PtrType::verifyEntries(DataLayoutEntryListRef entries,
   }
   return success();
 }
+
+//===----------------------------------------------------------------------===//
+// Pointer metadata
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+PtrMetadataType::verify(function_ref<InFlightDiagnostic()> emitError,
+                        PtrLikeTypeInterface type) {
+  if (!type.hasPtrMetadata())
+    return emitError() << "the ptr-like type has no metadata";
+  return success();
+}
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 3924d082f0628..3032b68c1fdd4 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -376,6 +376,20 @@ BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
   return builder;
 }
 
+FailureOr<PtrLikeTypeInterface>
+BaseMemRefType::clonePtrWith(Attribute memorySpace,
+                             std::optional<Type> elementType) const {
+  Type eTy = elementType ? *elementType : getElementType();
+  if (llvm::dyn_cast<UnrankedMemRefType>(*this))
+    return cast<PtrLikeTypeInterface>(
+        UnrankedMemRefType::get(eTy, memorySpace));
+
+  MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
+  builder.setElementType(eTy);
+  builder.setMemorySpace(memorySpace);
+  return cast<PtrLikeTypeInterface>(static_cast<MemRefType>(builder));
+}
+
 MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape,
                                  Type elementType) const {
   return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
diff --git a/mlir/test/Dialect/Ptr/canonicalize.mlir b/mlir/test/Dialect/Ptr/canonicalize.mlir
index ad363d554f247..837f364242beb 100644
--- a/mlir/test/Dialect/Ptr/canonicalize.mlir
+++ b/mlir/test/Dialect/Ptr/canonicalize.mlir
@@ -13,3 +13,61 @@ func.func @zero_offset(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.gene
   %res0 = ptr.ptr_add %ptr, %off : !ptr.ptr<#ptr.generic_space>, index
   return %res0 : !ptr.ptr<#ptr.generic_space>
 }
+
+/// Tests the the `from_ptr` folder.
+// CHECK-LABEL: @test_from_ptr_0
+// CHECK-SAME: (%[[MEM_REF:.*]]: memref<f32, #ptr.generic_space>)
+func.func @test_from_ptr_0(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
+  // CHECK-NOT: ptr.to_ptr
+  // CHECK-NOT: ptr.get_metadata
+  // CHECK-NOT: ptr.from_ptr
+  // CHECK: return %[[MEM_REF]]
+  %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+  %mda = ptr.get_metadata %mr : memref<f32, #ptr.generic_space>
+  %res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  return %res : memref<f32, #ptr.generic_space>
+}
+
+// CHECK-LABEL: @test_from_ptr_1
+// CHECK-SAME: (%[[MEM_REF:.*]]: memref<f32, #ptr.generic_space>)
+func.func @test_from_ptr_1(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
+  // CHECK-NOT: ptr.to_ptr
+  // CHECK-NOT: ptr.from_ptr
+  // CHECK: return %[[MEM_REF]]
+  %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+  %res = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  return %res : memref<f32, #ptr.generic_space>
+}
+
+/// Check that the ops cannot be folded because the metadata cannot be guaranteed to be the same.
+// CHECK-LABEL: @test_from_ptr_2
+func.func @test_from_ptr_2(%mr: memref<f32, #ptr.generic_space>, %md: !ptr.ptr_metadata<memref<f32, #ptr.generic_space>>) -> memref<f32, #ptr.generic_space> {
+  // CHECK: ptr.to_ptr
+  // CHECK: ptr.from_ptr
+  %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
+  %res = ptr.from_ptr %ptr metadata %md : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
+  return %res : memref<f32, #ptr.generic_space>
+}
+
+/// Tests the the `to_ptr` folder.
+// CHECK-LABEL: @test_to_ptr_0
+// CHECK-SAME: (%[[PTR:.*]]: !ptr.ptr<#ptr.generic_space>
+func.func @test_to_ptr_0(%ptr: !ptr.ptr<#ptr.generic_space>, %md: !ptr.ptr_metadata<memref<f32, #ptr.generic_space>>) -> !ptr.ptr...
[truncated]

return nullptr;
Value md = getMetadata();
if (!md)
return toPtr.getPtr();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure about this condition.
It seems to imply that all the metadata are statically encoded in the type. Couldn't there be a type where some runtime metadata are default initialized? In which case round-tripping to a ptr would be erasing them and folding would undo this erasing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The verifier of from_ptr will check whether PtrLikeType expects non-trivial metadata, and if it does, it always asks for it. The only way to not provide the metadata is if trivial_metadata is set, which tells the compiler it's safe to assume that all info is statically encoded in the type.

Therefore in this case, it is known from the verifier that the op can be folded.

If the ptr-like object type has metadata, then the operation expects the
metadata as an argument or expects that the flag `trivial_metadata` is set.
If `trivial_metadata` is set, then it is assumed that the metadata can be
reconstructed statically from the pointer-like type.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need the trivial_metadata flag instead of just assuming it when metadata isn't provided?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's suppose that we have %v0: memref<f32, 0> in which alignedPtr != allocatedPtr. Then in the following example %v1 != %v0:

%p = to.ptr %v : memref<f32, 0> -> !ptr.ptr<0>
%v1 = from_ptr %p : !ptr.ptr<0> -> memref<f32, 0>

So in that case the cast sequence is a lossy conversion, which might be fine for the user and it should be possible to do. However, I think it makes for a less buggy experience to explicitly state that the user is ignoring the metadata.

Currently, the above IR will generate a verification error saying that from_ptr requires metadata because the type says so.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

owever, I think it makes for a less buggy experience to explicitly state that the user is ignoring the metadata.

For a programming language, I guess so, but for the IR that seems like redundant information here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It also tells the compiler that the folding is valid, see my comment in the folder of FromPtrOp.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It also tells the compiler that the folding is valid

I don't quite follow: if the absence of the flag implies trivial_metadata, then you can fold just as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not an implementation impossibility of the folding logic, but rather an affirmation to the compiler that the user guarantees it's safe to do it, even when the cast is lossy like with memref.

My main point with this flag is to make it explicit that there's something potentially hazardous taking place. Otherwise an user could end up with something like:

%v1 = from_ptr %p : !ptr.ptr<0> -> memref<f32, 0>
memref.dealloc %v1 : memref<f32, 0>

Which is potentially dangerous, and no warnings were raised. Instead having:

%v1 = from_ptr %p trivial_metadata : !ptr.ptr<0> -> memref<f32, 0>
memref.dealloc %v1 : memref<f32, 0>

At least places the burden on the user of the op.

Copy link
Collaborator

@joker-eph joker-eph May 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again: this is an IR, not a programming language: people aren't writing textual IR by hand.
We don't need to have redundant information encoded on an operation.
(even if you want to have something printed out, this does not need to be reflected as an extra property).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have removed the trivial_metadata flag.

fabianmcg and others added 2 commits April 27, 2025 07:29
@fabianmcg fabianmcg requested a review from joker-eph May 13, 2025 17:35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:ods mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants