-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[MLIR][NVVM] Add prefetch Ops #141737
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][NVVM] Add prefetch Ops #141737
Conversation
@llvm/pr-subscribers-mlir-llvm Author: Srinivasa Ravi (Wolfram70) ChangesThis change adds PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prefetch-prefetchu Full diff: https://github.com/llvm/llvm-project/pull/141737.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 408537be0a5e4..311847a27a5f0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -25,6 +25,7 @@ include "mlir/Dialect/LLVMIR/LLVMTypes.td"
def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>;
def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
def LLVM_PointerShared : LLVM_PointerInAddressSpace<3>;
+def LLVM_PointerLocal : LLVM_PointerInAddressSpace<5>;
def LLVM_PointerTensor : LLVM_PointerInAddressSpace<6>;
def LLVM_PointerSharedCluster : LLVM_PointerInAddressSpace<7>;
@@ -2333,6 +2334,90 @@ def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp :
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// NVVM Prefetch Ops
+//===----------------------------------------------------------------------===//
+
+def NVVM_PrefetchL1Op : NVVM_Op<"prefetch.L1"> {
+ let description = [{
+ Brings the cache line containing the specified address into L1 cache.
+
+ Operand `ptr` can be a global, local or generic address pointer.
+ No operation is performed if `ptr` maps to a `shared` memory location.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prefetch-prefetchu)
+ }];
+ let arguments = (ins AnyTypeOf<[LLVM_PointerGlobal,
+ LLVM_PointerLocal,
+ LLVM_PointerGeneric]>:$ptr);
+ let assemblyFormat = "$ptr attr-dict `:` type($ptr)";
+
+ let extraClassDeclaration = [{
+ static llvm::Intrinsic::ID getIntrinsicID(llvm::Type *ptrType);
+ }];
+ let llvmBuilder = [{
+ auto intId = NVVM::PrefetchL1Op::getIntrinsicID($ptr->getType());
+ createIntrinsicCall(builder, intId, $ptr);
+ }];
+}
+
+def EvictLast : I32EnumAttrCase<"EvictLast", 0, "evict_last">;
+def EvictNormal : I32EnumAttrCase<"EvictNormal", 1, "evict_normal">;
+
+def EvictionPriority : I32EnumAttr<"EvictionPriority", "NVVM Eviction Priority",
+ [EvictLast, EvictNormal]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+
+def EvictionPriorityAttr : EnumAttr<NVVM_Dialect, EvictionPriority, "eviction_priority"> {
+ let assemblyFormat = "$value";
+}
+
+def NVVM_PrefetchL2Op : NVVM_Op<"prefetch.L2"> {
+ let description = [{
+ Brings the cache line containing the specified address into L2 cache.
+
+ Operand `ptr` can be a global, local or generic address pointer.
+ No operation is performed if `ptr` maps to a `shared` memory location.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prefetch-prefetchu)
+ }];
+ let arguments = (ins AnyTypeOf<[LLVM_PointerGlobal,
+ LLVM_PointerLocal,
+ LLVM_PointerGeneric]>:$ptr,
+ OptionalAttr<EvictionPriorityAttr>:$evictionPriority);
+ let assemblyFormat = "$ptr (`,` `evict_priority` `=` $evictionPriority^)? attr-dict `:` type($ptr)";
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ static llvm::Intrinsic::ID getIntrinsicID(llvm::Type *ptrType, std::optional<NVVM::EvictionPriority> evictionPriority);
+ }];
+ let llvmBuilder = [{
+ auto intId = NVVM::PrefetchL2Op::getIntrinsicID($ptr->getType(), $evictionPriority);
+ createIntrinsicCall(builder, intId, $ptr);
+ }];
+}
+
+def NVVM_PrefetchL1UniformOp : NVVM_Op<"prefetch.L1.uniform"> {
+ let description = [{
+ Brings the cache line containing the specified address into L1 uniform
+ cache.
+
+ Operand `ptr` is a generic address pointer.
+ No operation is performed if `ptr` maps to a `const`, `local`, or `shared`
+ memory location.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prefetch-prefetchu)
+ }];
+ let arguments = (ins LLVM_PointerGeneric:$ptr);
+ let assemblyFormat = "$ptr attr-dict `:` type($ptr)";
+
+ let llvmBuilder = [{
+ createIntrinsicCall(builder, llvm::Intrinsic::nvvm_prefetchu_L1, $ptr);
+ }];
+}
+
def NVVM_PrefetchTensorMapOp : NVVM_Op<"prefetch.tensormap",
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
Arguments<(ins LLVM_AnyPointer:$tmaDescriptor, PtxPredicate:$predicate)> {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 8036ea27f524f..1e0039a3a0541 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1205,6 +1205,15 @@ LogicalResult NVVM::VoteSyncOp::verify() {
return success();
}
+LogicalResult NVVM::PrefetchL2Op::verify() {
+ if (getEvictionPriority() &&
+ (llvm::cast<LLVM::LLVMPointerType>(getPtr().getType())
+ .getAddressSpace() != 1))
+ return emitOpError(
+ "prefetch with eviction priority requires a global pointer");
+ return success();
+}
+
llvm::Value *
NVVM::DotAccumulate4WayOp::getPackedArg(llvm::Value *arg,
llvm::IRBuilderBase &builder) {
@@ -1712,6 +1721,42 @@ DotAccumulate4WayOp::getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
}
}
+llvm::Intrinsic::ID PrefetchL1Op::getIntrinsicID(llvm::Type *ptrType) {
+ switch (ptrType->getPointerAddressSpace()) {
+ case 0:
+ return llvm::Intrinsic::nvvm_prefetch_L1;
+ case 1:
+ return llvm::Intrinsic::nvvm_prefetch_global_L1;
+ case 5:
+ return llvm::Intrinsic::nvvm_prefetch_local_L1;
+ default:
+ llvm_unreachable("Invalid pointer address space");
+ }
+}
+
+llvm::Intrinsic::ID PrefetchL2Op::getIntrinsicID(
+ llvm::Type *ptrType,
+ std::optional<NVVM::EvictionPriority> evictionPriority) {
+ switch (ptrType->getPointerAddressSpace()) {
+ case 0:
+ return llvm::Intrinsic::nvvm_prefetch_L2;
+ case 1:
+ if (evictionPriority) {
+ if (*evictionPriority == NVVM::EvictionPriority::EvictLast)
+ return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last;
+ else if (*evictionPriority == NVVM::EvictionPriority::EvictNormal)
+ return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal;
+ else
+ llvm_unreachable("Invalid eviction priority");
+ }
+ return llvm::Intrinsic::nvvm_prefetch_global_L2;
+ case 5:
+ return llvm::Intrinsic::nvvm_prefetch_local_L2;
+ default:
+ llvm_unreachable("Invalid pointer address space");
+ }
+}
+
//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index e8425638cc9be..704f9d28fd5ae 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -587,6 +587,29 @@ func.func @dot_accumulate_4way(%a: i32, %a_vec: vector<4xi8>, %b: i32, %b_vec: v
return
}
+// CHECK-LABEL: @prefetch
+func.func @prefetch(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr: !llvm.ptr<1>) {
+ // CHECK: nvvm.prefetch.L1 %{{.*}}
+ nvvm.prefetch.L1 %gen_ptr : !llvm.ptr<0>
+ // CHECK: nvvm.prefetch.L1 %{{.*}}
+ nvvm.prefetch.L1 %local_ptr : !llvm.ptr<5>
+ // CHECK: nvvm.prefetch.L1 %{{.*}}
+ nvvm.prefetch.L1 %global_ptr : !llvm.ptr<1>
+ // CHECK: nvvm.prefetch.L2 %{{.*}}
+ nvvm.prefetch.L2 %gen_ptr : !llvm.ptr<0>
+ // CHECK: nvvm.prefetch.L2 %{{.*}}
+ nvvm.prefetch.L2 %local_ptr : !llvm.ptr<5>
+ // CHECK: nvvm.prefetch.L2 %{{.*}}
+ nvvm.prefetch.L2 %global_ptr : !llvm.ptr<1>
+ // CHECK: nvvm.prefetch.L2 %{{.*}}
+ nvvm.prefetch.L2 %global_ptr, evict_priority = evict_last : !llvm.ptr<1>
+ // CHECK: nvvm.prefetch.L2 %{{.*}}
+ nvvm.prefetch.L2 %global_ptr, evict_priority = evict_normal : !llvm.ptr<1>
+ // CHECK: nvvm.prefetch.L1.uniform %{{.*}}
+ nvvm.prefetch.L1.uniform %gen_ptr : !llvm.ptr
+ return
+}
+
// -----
// Just check these don't emit errors.
diff --git a/mlir/test/Target/LLVMIR/nvvm/prefetch.mlir b/mlir/test/Target/LLVMIR/nvvm/prefetch.mlir
new file mode 100644
index 0000000000000..b6532ff0fbf0d
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/prefetch.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+llvm.func @prefetch_L1(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr: !llvm.ptr<1>) {
+ // CHECK-LABEL: define void @prefetch_L1(ptr %0, ptr addrspace(5) %1, ptr addrspace(1) %2) {
+ // CHECK-NEXT: call void @llvm.nvvm.prefetch.L1(ptr %0)
+ // CHECK-NEXT: call void @llvm.nvvm.prefetch.local.L1(ptr addrspace(5) %1)
+ // CHECK-NEXT: call void @llvm.nvvm.prefetch.global.L1(ptr addrspace(1) %2)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.prefetch.L1 %gen_ptr : !llvm.ptr<0>
+ nvvm.prefetch.L1 %local_ptr : !llvm.ptr<5>
+ nvvm.prefetch.L1 %global_ptr : !llvm.ptr<1>
+ llvm.return
+}
+
+llvm.func @prefetch_L2(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr: !llvm.ptr<1>) {
+ // CHECK-LABEL: define void @prefetch_L2(ptr %0, ptr addrspace(5) %1, ptr addrspace(1) %2) {
+ // CHECK-NEXT: call void @llvm.nvvm.prefetch.L2(ptr %0)
+ // CHECK-NEXT: call void @llvm.nvvm.prefetch.local.L2(ptr addrspace(5) %1)
+ // CHECK-NEXT: call void @llvm.nvvm.prefetch.global.L2(ptr addrspace(1) %2)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.prefetch.L2 %gen_ptr : !llvm.ptr<0>
+ nvvm.prefetch.L2 %local_ptr : !llvm.ptr<5>
+ nvvm.prefetch.L2 %global_ptr : !llvm.ptr<1>
+ llvm.return
+}
+
+llvm.func @prefetch_L2_eviction_priority(%global_ptr: !llvm.ptr<1>) {
+ // CHECK-LABEL: define void @prefetch_L2_eviction_priority(ptr addrspace(1) %0) {
+ // CHECK-NEXT: call void @llvm.nvvm.prefetch.global.L2.evict.last(ptr addrspace(1) %0)
+ // CHECK-NEXT: call void @llvm.nvvm.prefetch.global.L2.evict.normal(ptr addrspace(1) %0)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.prefetch.L2 %global_ptr, evict_priority = evict_last : !llvm.ptr<1>
+ nvvm.prefetch.L2 %global_ptr, evict_priority = evict_normal : !llvm.ptr<1>
+ llvm.return
+}
+
+llvm.func @prefetch_L1_uniform(%gen_ptr: !llvm.ptr) {
+ // CHECK-LABEL: define void @prefetch_L1_uniform(ptr %0) {
+ // CHECK-NEXT: call void @llvm.nvvm.prefetchu.L1(ptr %0)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.prefetch.L1.uniform %gen_ptr : !llvm.ptr
+ llvm.return
+}
\ No newline at end of file
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 3d63434f310bd..734729ad3d8fd 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -248,3 +248,11 @@ llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) {
%res = nvvm.convert.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> -> i16
llvm.return
}
+
+// -----
+
+llvm.func @nvvm_prefetch_L2_with_evict_invalid_addr_space(%local_ptr: !llvm.ptr<5>) {
+ // expected-error @below {{prefetch with eviction priority requires a global pointer}}
+ nvvm.prefetch.L2 %local_ptr, evict_priority = evict_last : !llvm.ptr<5>
+ llvm.return
+}
|
@llvm/pr-subscribers-mlir Author: Srinivasa Ravi (Wolfram70) ChangesThis change adds PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prefetch-prefetchu Full diff: https://github.com/llvm/llvm-project/pull/141737.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 408537be0a5e4..311847a27a5f0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -25,6 +25,7 @@ include "mlir/Dialect/LLVMIR/LLVMTypes.td"
def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>;
def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
def LLVM_PointerShared : LLVM_PointerInAddressSpace<3>;
+def LLVM_PointerLocal : LLVM_PointerInAddressSpace<5>;
def LLVM_PointerTensor : LLVM_PointerInAddressSpace<6>;
def LLVM_PointerSharedCluster : LLVM_PointerInAddressSpace<7>;
@@ -2333,6 +2334,90 @@ def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp :
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// NVVM Prefetch Ops
+//===----------------------------------------------------------------------===//
+
+def NVVM_PrefetchL1Op : NVVM_Op<"prefetch.L1"> {
+ let description = [{
+ Brings the cache line containing the specified address into L1 cache.
+
+ Operand `ptr` can be a global, local or generic address pointer.
+ No operation is performed if `ptr` maps to a `shared` memory location.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prefetch-prefetchu)
+ }];
+ let arguments = (ins AnyTypeOf<[LLVM_PointerGlobal,
+ LLVM_PointerLocal,
+ LLVM_PointerGeneric]>:$ptr);
+ let assemblyFormat = "$ptr attr-dict `:` type($ptr)";
+
+ let extraClassDeclaration = [{
+ static llvm::Intrinsic::ID getIntrinsicID(llvm::Type *ptrType);
+ }];
+ let llvmBuilder = [{
+ auto intId = NVVM::PrefetchL1Op::getIntrinsicID($ptr->getType());
+ createIntrinsicCall(builder, intId, $ptr);
+ }];
+}
+
+def EvictLast : I32EnumAttrCase<"EvictLast", 0, "evict_last">;
+def EvictNormal : I32EnumAttrCase<"EvictNormal", 1, "evict_normal">;
+
+def EvictionPriority : I32EnumAttr<"EvictionPriority", "NVVM Eviction Priority",
+ [EvictLast, EvictNormal]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+
+def EvictionPriorityAttr : EnumAttr<NVVM_Dialect, EvictionPriority, "eviction_priority"> {
+ let assemblyFormat = "$value";
+}
+
+def NVVM_PrefetchL2Op : NVVM_Op<"prefetch.L2"> {
+ let description = [{
+ Brings the cache line containing the specified address into L2 cache.
+
+ Operand `ptr` can be a global, local or generic address pointer.
+ No operation is performed if `ptr` maps to a `shared` memory location.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prefetch-prefetchu)
+ }];
+ let arguments = (ins AnyTypeOf<[LLVM_PointerGlobal,
+ LLVM_PointerLocal,
+ LLVM_PointerGeneric]>:$ptr,
+ OptionalAttr<EvictionPriorityAttr>:$evictionPriority);
+ let assemblyFormat = "$ptr (`,` `evict_priority` `=` $evictionPriority^)? attr-dict `:` type($ptr)";
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ static llvm::Intrinsic::ID getIntrinsicID(llvm::Type *ptrType, std::optional<NVVM::EvictionPriority> evictionPriority);
+ }];
+ let llvmBuilder = [{
+ auto intId = NVVM::PrefetchL2Op::getIntrinsicID($ptr->getType(), $evictionPriority);
+ createIntrinsicCall(builder, intId, $ptr);
+ }];
+}
+
+def NVVM_PrefetchL1UniformOp : NVVM_Op<"prefetch.L1.uniform"> {
+ let description = [{
+ Brings the cache line containing the specified address into L1 uniform
+ cache.
+
+ Operand `ptr` is a generic address pointer.
+ No operation is performed if `ptr` maps to a `const`, `local`, or `shared`
+ memory location.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prefetch-prefetchu)
+ }];
+ let arguments = (ins LLVM_PointerGeneric:$ptr);
+ let assemblyFormat = "$ptr attr-dict `:` type($ptr)";
+
+ let llvmBuilder = [{
+ createIntrinsicCall(builder, llvm::Intrinsic::nvvm_prefetchu_L1, $ptr);
+ }];
+}
+
def NVVM_PrefetchTensorMapOp : NVVM_Op<"prefetch.tensormap",
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
Arguments<(ins LLVM_AnyPointer:$tmaDescriptor, PtxPredicate:$predicate)> {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 8036ea27f524f..1e0039a3a0541 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1205,6 +1205,15 @@ LogicalResult NVVM::VoteSyncOp::verify() {
return success();
}
+LogicalResult NVVM::PrefetchL2Op::verify() {
+ if (getEvictionPriority() &&
+ (llvm::cast<LLVM::LLVMPointerType>(getPtr().getType())
+ .getAddressSpace() != 1))
+ return emitOpError(
+ "prefetch with eviction priority requires a global pointer");
+ return success();
+}
+
llvm::Value *
NVVM::DotAccumulate4WayOp::getPackedArg(llvm::Value *arg,
llvm::IRBuilderBase &builder) {
@@ -1712,6 +1721,42 @@ DotAccumulate4WayOp::getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
}
}
+llvm::Intrinsic::ID PrefetchL1Op::getIntrinsicID(llvm::Type *ptrType) {
+ switch (ptrType->getPointerAddressSpace()) {
+ case 0:
+ return llvm::Intrinsic::nvvm_prefetch_L1;
+ case 1:
+ return llvm::Intrinsic::nvvm_prefetch_global_L1;
+ case 5:
+ return llvm::Intrinsic::nvvm_prefetch_local_L1;
+ default:
+ llvm_unreachable("Invalid pointer address space");
+ }
+}
+
+llvm::Intrinsic::ID PrefetchL2Op::getIntrinsicID(
+ llvm::Type *ptrType,
+ std::optional<NVVM::EvictionPriority> evictionPriority) {
+ switch (ptrType->getPointerAddressSpace()) {
+ case 0:
+ return llvm::Intrinsic::nvvm_prefetch_L2;
+ case 1:
+ if (evictionPriority) {
+ if (*evictionPriority == NVVM::EvictionPriority::EvictLast)
+ return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last;
+ else if (*evictionPriority == NVVM::EvictionPriority::EvictNormal)
+ return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal;
+ else
+ llvm_unreachable("Invalid eviction priority");
+ }
+ return llvm::Intrinsic::nvvm_prefetch_global_L2;
+ case 5:
+ return llvm::Intrinsic::nvvm_prefetch_local_L2;
+ default:
+ llvm_unreachable("Invalid pointer address space");
+ }
+}
+
//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index e8425638cc9be..704f9d28fd5ae 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -587,6 +587,29 @@ func.func @dot_accumulate_4way(%a: i32, %a_vec: vector<4xi8>, %b: i32, %b_vec: v
return
}
+// CHECK-LABEL: @prefetch
+func.func @prefetch(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr: !llvm.ptr<1>) {
+ // CHECK: nvvm.prefetch.L1 %{{.*}}
+ nvvm.prefetch.L1 %gen_ptr : !llvm.ptr<0>
+ // CHECK: nvvm.prefetch.L1 %{{.*}}
+ nvvm.prefetch.L1 %local_ptr : !llvm.ptr<5>
+ // CHECK: nvvm.prefetch.L1 %{{.*}}
+ nvvm.prefetch.L1 %global_ptr : !llvm.ptr<1>
+ // CHECK: nvvm.prefetch.L2 %{{.*}}
+ nvvm.prefetch.L2 %gen_ptr : !llvm.ptr<0>
+ // CHECK: nvvm.prefetch.L2 %{{.*}}
+ nvvm.prefetch.L2 %local_ptr : !llvm.ptr<5>
+ // CHECK: nvvm.prefetch.L2 %{{.*}}
+ nvvm.prefetch.L2 %global_ptr : !llvm.ptr<1>
+ // CHECK: nvvm.prefetch.L2 %{{.*}}
+ nvvm.prefetch.L2 %global_ptr, evict_priority = evict_last : !llvm.ptr<1>
+ // CHECK: nvvm.prefetch.L2 %{{.*}}
+ nvvm.prefetch.L2 %global_ptr, evict_priority = evict_normal : !llvm.ptr<1>
+ // CHECK: nvvm.prefetch.L1.uniform %{{.*}}
+ nvvm.prefetch.L1.uniform %gen_ptr : !llvm.ptr
+ return
+}
+
// -----
// Just check these don't emit errors.
diff --git a/mlir/test/Target/LLVMIR/nvvm/prefetch.mlir b/mlir/test/Target/LLVMIR/nvvm/prefetch.mlir
new file mode 100644
index 0000000000000..b6532ff0fbf0d
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/prefetch.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+llvm.func @prefetch_L1(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr: !llvm.ptr<1>) {
+ // CHECK-LABEL: define void @prefetch_L1(ptr %0, ptr addrspace(5) %1, ptr addrspace(1) %2) {
+ // CHECK-NEXT: call void @llvm.nvvm.prefetch.L1(ptr %0)
+ // CHECK-NEXT: call void @llvm.nvvm.prefetch.local.L1(ptr addrspace(5) %1)
+ // CHECK-NEXT: call void @llvm.nvvm.prefetch.global.L1(ptr addrspace(1) %2)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.prefetch.L1 %gen_ptr : !llvm.ptr<0>
+ nvvm.prefetch.L1 %local_ptr : !llvm.ptr<5>
+ nvvm.prefetch.L1 %global_ptr : !llvm.ptr<1>
+ llvm.return
+}
+
+llvm.func @prefetch_L2(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr: !llvm.ptr<1>) {
+ // CHECK-LABEL: define void @prefetch_L2(ptr %0, ptr addrspace(5) %1, ptr addrspace(1) %2) {
+ // CHECK-NEXT: call void @llvm.nvvm.prefetch.L2(ptr %0)
+ // CHECK-NEXT: call void @llvm.nvvm.prefetch.local.L2(ptr addrspace(5) %1)
+ // CHECK-NEXT: call void @llvm.nvvm.prefetch.global.L2(ptr addrspace(1) %2)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.prefetch.L2 %gen_ptr : !llvm.ptr<0>
+ nvvm.prefetch.L2 %local_ptr : !llvm.ptr<5>
+ nvvm.prefetch.L2 %global_ptr : !llvm.ptr<1>
+ llvm.return
+}
+
+llvm.func @prefetch_L2_eviction_priority(%global_ptr: !llvm.ptr<1>) {
+ // CHECK-LABEL: define void @prefetch_L2_eviction_priority(ptr addrspace(1) %0) {
+ // CHECK-NEXT: call void @llvm.nvvm.prefetch.global.L2.evict.last(ptr addrspace(1) %0)
+ // CHECK-NEXT: call void @llvm.nvvm.prefetch.global.L2.evict.normal(ptr addrspace(1) %0)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.prefetch.L2 %global_ptr, evict_priority = evict_last : !llvm.ptr<1>
+ nvvm.prefetch.L2 %global_ptr, evict_priority = evict_normal : !llvm.ptr<1>
+ llvm.return
+}
+
+llvm.func @prefetch_L1_uniform(%gen_ptr: !llvm.ptr) {
+ // CHECK-LABEL: define void @prefetch_L1_uniform(ptr %0) {
+ // CHECK-NEXT: call void @llvm.nvvm.prefetchu.L1(ptr %0)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.prefetch.L1.uniform %gen_ptr : !llvm.ptr
+ llvm.return
+}
\ No newline at end of file
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 3d63434f310bd..734729ad3d8fd 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -248,3 +248,11 @@ llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) {
%res = nvvm.convert.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> -> i16
llvm.return
}
+
+// -----
+
+llvm.func @nvvm_prefetch_L2_with_evict_invalid_addr_space(%local_ptr: !llvm.ptr<5>) {
+ // expected-error @below {{prefetch with eviction priority requires a global pointer}}
+ nvvm.prefetch.L2 %local_ptr, evict_priority = evict_last : !llvm.ptr<5>
+ llvm.return
+}
|
2f1697a
to
9e4d64b
Compare
9e4d64b
to
729ef12
Compare
0570d4b
to
94e46fe
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM except for a few minor nits.
9164cc4
to
ac08000
Compare
0ba44d8
to
2c37270
Compare
bb26a7a
to
4dba233
Compare
2e3858d
to
d06ef75
Compare
This change adds `prefetch.L1`, `prefetch.L2`, and `prefetch.L1.uniform` Ops to the NVVM dialect for the `prefetch` and `prefetchu` group of instructions. PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prefetch-prefetchu
d06ef75
to
6e696b7
Compare
LGTM. |
This change adds
prefetch
andprefetch.uniform
Ops to the NVVM dialect for theprefetch
andprefetchu
group of instructions.PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prefetch-prefetchu