
[NVPTX] Add NVPTX intrinsics for TMA copies #95289


Closed

Conversation

apaszke
Member

@apaszke apaszke commented Jun 12, 2024

This is necessary to be able to pass TMA descriptors in through byval kernel parameters without having `NVPTXLowerArgs` insert an extra copy. While the descriptors can also be passed in through global memory, passing them as kernel parameters is the recommended approach and is the one used by CUTLASS.

I think the code in this PR should be ready, but it is obviously missing tests; I'd welcome pointers to where those should go. Until now I have only tested the code through my own compiler, and it has worked well so far.
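
To make the intended usage concrete, here is a minimal sketch (not taken from this patch) of the kind of IR this enables: the kernel receives the 128-byte TMA descriptor as a byval parameter and hands its address directly to one of the new intrinsics, so `NVPTXLowerArgs` can keep it in param space instead of materializing a local copy. The struct type, the alignment, and the overload suffix on the intrinsic name are illustrative guesses.

```llvm
; Hypothetical usage sketch; the descriptor type, the alignment, and the
; .p0.p3 overload suffix are assumptions, not code from the patch.
%CUtensorMap = type { [128 x i8] }

declare void @llvm.nvvm.cp.async.bulk.tensor.2d.shared_cluster.global.tile.mbarrier_complete_tx_bytes.p0.p3(
    ptr addrspace(3), ptr, i32, i32, ptr addrspace(3))

define ptx_kernel void @tma_load_2d(ptr byval(%CUtensorMap) align 64 %desc,
                                    ptr addrspace(3) %smem,
                                    ptr addrspace(3) %mbar,
                                    i32 %x, i32 %y) {
  ; The byval descriptor pointer feeds the TMA copy intrinsic directly; with
  ; this patch NVPTXLowerArgs rewrites it to the param address space and emits
  ; a cvta.param instead of copying the 128-byte struct to local memory.
  call void @llvm.nvvm.cp.async.bulk.tensor.2d.shared_cluster.global.tile.mbarrier_complete_tx_bytes.p0.p3(
      ptr addrspace(3) %smem, ptr %desc, i32 %x, i32 %y, ptr addrspace(3) %mbar)
  ret void
}
```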

@llvmbot
Member

llvmbot commented Jun 12, 2024

@llvm/pr-subscribers-backend-nvptx

@llvm/pr-subscribers-llvm-ir

Author: Adam Paszke (apaszke)

Changes

This is necessary to be able to pass TMA descriptors in through byval kernel parameters without having NVPTXLowerArgs insert an extra copy. While the descriptors can also be passed in through global memory, passing them as kernel parameters is the recommended approach and is the one used by CUTLASS.

I think the code in this PR should be ready, but it is obviously missing tests; I'd welcome pointers to where those should go. Until now I have only tested the code through my own compiler, and it has worked well so far.


Full diff: https://github.com/llvm/llvm-project/pull/95289.diff

3 Files Affected:

  • (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+24)
  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+28)
  • (modified) llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp (+78-13)
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 0a9139e0062ba..a210a208d01c0 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1448,6 +1448,26 @@ defm int_nvvm_cp_async_ca_shared_global_8 : CP_ASYNC_SHARED_GLOBAL<"8", "ca">;
 defm int_nvvm_cp_async_ca_shared_global_16 : CP_ASYNC_SHARED_GLOBAL<"16", "ca">;
 defm int_nvvm_cp_async_cg_shared_global_16 : CP_ASYNC_SHARED_GLOBAL<"16", "cg">;
 
+// TODO(apaszke): Multicast TMA loads
+foreach dim = [1, 2, 3, 4, 5] in {
+  def int_nvvm_cp_async_bulk_tensor_ # dim # d_shared_cluster_global_tile_mbarrier_complete_tx_bytes :
+    Intrinsic<
+      [],
+      [llvm_shared_ptr_ty, llvm_anyptr_ty] # !listsplat(llvm_i32_ty, dim) # [llvm_anyptr_ty],
+      [IntrArgMemOnly, IntrNoCallback,
+       NoAlias<ArgIndex<0>>, NoAlias<ArgIndex<1>>, NoAlias<ArgIndex<!add(2,  dim)>>,
+       WriteOnly<ArgIndex<0>>, ReadOnly<ArgIndex<1>>],
+      "llvm.nvvm.cp.async.bulk.tensor." # dim # "d.shared_cluster.global.tile.mbarrier_complete_tx_bytes">;
+  def int_nvvm_cp_async_bulk_tensor_ # dim # d_global_shared_cta_tile_bulk_group :
+    Intrinsic<
+      [],
+      [llvm_anyptr_ty] # !listsplat(llvm_i32_ty, dim) # [llvm_shared_ptr_ty],
+      [IntrNoCallback,
+       NoAlias<ArgIndex<0>>, NoAlias<ArgIndex<!add(1, dim)>>,
+       ReadOnly<ArgIndex<0>>, ReadOnly<ArgIndex<!add(1, dim)>>],
+      "llvm.nvvm.cp.async.bulk.tensor." # dim # "d.global.shared_cta.tile.bulk_group">;
+}
+
 def int_nvvm_cp_async_commit_group :
     ClangBuiltin<"__nvvm_cp_async_commit_group">,
     Intrinsic<[],[],[]>;
@@ -1595,6 +1615,10 @@ def int_nvvm_ptr_gen_to_param: Intrinsic<[llvm_anyptr_ty],
                                      [llvm_anyptr_ty],
                                    [IntrNoMem, IntrSpeculatable, IntrNoCallback],
                                    "llvm.nvvm.ptr.gen.to.param">;
+def int_nvvm_ptr_param_to_gen: Intrinsic<[llvm_anyptr_ty],
+                                     [llvm_anyptr_ty],
+                                   [IntrNoMem, IntrSpeculatable, IntrNoCallback],
+                                   "llvm.nvvm.ptr.param.to.gen">;
 
 // Move intrinsics, used in nvvm internally
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 440af085cb8e9..e2a565defb95b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -403,6 +403,33 @@ defm CP_ASYNC_CG_SHARED_GLOBAL_16 :
   CP_ASYNC_SHARED_GLOBAL_I<"cg", "16", int_nvvm_cp_async_cg_shared_global_16,
                                        int_nvvm_cp_async_cg_shared_global_16_s>;
 
+foreach dim = [1, 2, 3, 4, 5] in {
+  defvar idx_ptx = !interleave(!foreach(i, !range(dim), "$idx" # i), ", ");
+  defvar idx_dag = !dag(ins, !listsplat(Int32Regs, dim), !foreach(i, !range(dim), "idx" # i));
+  defvar intrinsic_g2s = !cast<Intrinsic>("int_nvvm_cp_async_bulk_tensor_" # dim # "d_shared_cluster_global_tile_mbarrier_complete_tx_bytes");
+  def CP_ASYNC_BULK_TENSOR_ # dim # D_SHARED_CLUSTER_GLOBAL_TILE_MBARRIER_COMPLETE_TX_BYTES_64 :
+    NVPTXInst<
+      (outs),
+      !con((ins Int64Regs:$dst, Int64Regs:$desc), idx_dag, (ins Int64Regs:$mbar)),
+      "cp.async.bulk.tensor." # dim # "d.shared::cluster.global.tile.mbarrier::complete_tx::bytes [$dst], [$desc, {{" # idx_ptx # "}}], [$mbar];",
+      [!con((intrinsic_g2s Int64Regs:$dst, Int64Regs:$desc),
+            !setdagop(idx_dag, intrinsic_g2s),
+            (intrinsic_g2s Int64Regs:$mbar))]
+    >,
+    Requires<[hasPTX<80>, hasSM<90>]>;
+  defvar intrinsic_s2g = !cast<Intrinsic>("int_nvvm_cp_async_bulk_tensor_" # dim # "d_global_shared_cta_tile_bulk_group");
+  def CP_ASYNC_BULK_TENSOR_ # dim # D_GLOBAL_SHARED_CTA_TILE_BULK_GROUP_64 :
+    NVPTXInst<
+      (outs),
+      !con((ins Int64Regs:$desc), idx_dag, (ins Int64Regs:$dst)),
+      "cp.async.bulk.tensor." # dim # "d.global.shared::cta.tile.bulk_group [$desc, {{" # idx_ptx # "}}], [$dst];",
+      [!con((intrinsic_s2g Int64Regs:$desc),
+            !setdagop(idx_dag, intrinsic_s2g),
+            (intrinsic_s2g Int64Regs:$dst))]
+    >,
+    Requires<[hasPTX<80>, hasSM<90>]>;
+}
+
 def CP_ASYNC_COMMIT_GROUP :
   NVPTXInst<(outs), (ins), "cp.async.commit_group;", [(int_nvvm_cp_async_commit_group)]>,
   Requires<[hasPTX<70>, hasSM<80>]>;
@@ -2475,6 +2502,7 @@ defm cvta_local  : NG_TO_G<"local", int_nvvm_ptr_local_to_gen, useShortPtrLocal>
 defm cvta_shared : NG_TO_G<"shared", int_nvvm_ptr_shared_to_gen, useShortPtrShared>;
 defm cvta_global : NG_TO_G<"global", int_nvvm_ptr_global_to_gen, False>;
 defm cvta_const  : NG_TO_G<"const", int_nvvm_ptr_constant_to_gen, useShortPtrConst>;
+defm cvta_param  : NG_TO_G<"param", int_nvvm_ptr_param_to_gen, False>;
 
 defm cvta_to_local  : G_TO_NG<"local", int_nvvm_ptr_gen_to_local, useShortPtrLocal>;
 defm cvta_to_shared : G_TO_NG<"shared", int_nvvm_ptr_gen_to_shared, useShortPtrShared>;
diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
index cde02c25c4834..06eb2ba848762 100644
--- a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
@@ -94,12 +94,17 @@
 #include "NVPTXUtilities.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/CodeGen/TargetPassConfig.h"
+#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IntrinsicsNVPTX.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Type.h"
+#include "llvm/IR/Use.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
+#include <cassert>
 #include <numeric>
 #include <queue>
 
@@ -146,6 +151,28 @@ INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
 INITIALIZE_PASS_END(NVPTXLowerArgs, "nvptx-lower-args",
                     "Lower arguments (NVPTX)", false, false)
 
+static std::optional<int> tmaDescriptorOperandIndex(Instruction *I) {
+  if (auto *II = dyn_cast<IntrinsicInst>(I)) {
+    switch (II->getIntrinsicID()) {
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_1d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_2d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_3d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_4d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_5d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
+      return 1;
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_1d_global_shared_cta_tile_bulk_group:
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_2d_global_shared_cta_tile_bulk_group:
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_3d_global_shared_cta_tile_bulk_group:
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_4d_global_shared_cta_tile_bulk_group:
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_5d_global_shared_cta_tile_bulk_group:
+      return 0;
+    default:
+      return std::nullopt;
+    }
+  }
+  return std::nullopt;
+}
+
 // =============================================================================
 // If the function had a byval struct ptr arg, say foo(%struct.x* byval %d),
 // and we can't guarantee that the only accesses are loads,
@@ -166,14 +193,15 @@ INITIALIZE_PASS_END(NVPTXLowerArgs, "nvptx-lower-args",
 
 // Replaces the \p OldUser instruction with the same in parameter AS.
 // Only Load and GEP are supported.
-static void convertToParamAS(Value *OldUser, Value *Param) {
+static void convertToParamAS(Value *OldUser, Value *OldParam, Value *NewParam) {
   Instruction *I = dyn_cast<Instruction>(OldUser);
   assert(I && "OldUser must be an instruction");
   struct IP {
     Instruction *OldInstruction;
+    Value *OldParam;
     Value *NewParam;
   };
-  SmallVector<IP> ItemsToConvert = {{I, Param}};
+  SmallVector<IP> ItemsToConvert = {{I, OldParam, NewParam}};
   SmallVector<Instruction *> InstructionsToDelete;
 
   auto CloneInstInParamAS = [](const IP &I) -> Value * {
@@ -200,6 +228,28 @@ static void convertToParamAS(Value *OldUser, Value *Param) {
       // Just pass through the argument, the old ASC is no longer needed.
       return I.NewParam;
     }
+    if (auto *II = dyn_cast<IntrinsicInst>(I.OldInstruction)) {
+      // Assert that this is a TMA intrinsic.
+      assert(tmaDescriptorOperandIndex(II).has_value());
+      assert(I.OldInstruction->getOperand(*tmaDescriptorOperandIndex(II)) ==
+             I.OldParam);
+      // TMA descriptors can remain in param memory space, but need to be passed
+      // in the generic address space.
+      Type *ParamPtr = PointerType::get(II->getContext(), ADDRESS_SPACE_PARAM);
+      Type *GenericPtr =
+          PointerType::get(II->getContext(), ADDRESS_SPACE_GENERIC);
+      FunctionType *cast_func_ty =
+          FunctionType::get(GenericPtr, {ParamPtr}, false);
+      Module *M = I.OldInstruction->getModule();
+      FunctionCallee func =
+          M->getOrInsertFunction(getName(llvm::Intrinsic::nvvm_ptr_param_to_gen,
+                                         {GenericPtr, ParamPtr}, M),
+                                 cast_func_ty);
+      Instruction *NewInGeneric =
+          CallInst::Create(func, {I.NewParam}, "", II->getIterator());
+      II->replaceUsesOfWith(I.OldParam, NewInGeneric);
+      return II;
+    }
     llvm_unreachable("Unsupported instruction");
   };
 
@@ -212,7 +262,8 @@ static void convertToParamAS(Value *OldUser, Value *Param) {
       // be converted and the instruction itself to be deleted. We can't delete
       // the old instruction yet, because it's still in use by a load somewhere.
       for (Value *V : I.OldInstruction->users())
-        ItemsToConvert.push_back({cast<Instruction>(V), NewInst});
+        ItemsToConvert.push_back(
+            {cast<Instruction>(V), I.OldInstruction, NewInst});
 
       InstructionsToDelete.push_back(I.OldInstruction);
     }
@@ -300,9 +351,13 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
         Worklist.push({I, Ctx.Offset + Offset});
         continue;
       }
+      if (auto *II = dyn_cast<IntrinsicInst>(CurUser)) {
+        assert(tmaDescriptorOperandIndex(II).has_value());
+        continue;
+      }
 
       llvm_unreachable("All users must be one of: load, "
-                       "bitcast, getelementptr.");
+                       "bitcast, getelementptr, TMA intrinsic.");
     }
   }
 
@@ -321,8 +376,11 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
   assert(StructType && "Missing byval type");
 
   auto IsALoadChain = [&](Value *Start) {
-    SmallVector<Value *, 16> ValuesToCheck = {Start};
-    auto IsALoadChainInstr = [](Value *V) -> bool {
+    SmallVector<Use*, 16> UsesToCheck;
+    for (Use& u : Start->uses())
+      UsesToCheck.push_back(&u);
+    auto IsSupportedUse = [](Use *U) -> bool {
+      Value *V = U->get();
       if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V))
         return true;
       // ASC to param space are OK, too -- we'll just strip them.
@@ -330,19 +388,26 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
         if (ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM)
           return true;
       }
+      // TMA descriptors passed to TMA intrinsics are OK, too.
+      if (auto *II = dyn_cast<IntrinsicInst>(V)) {
+        auto OI = tmaDescriptorOperandIndex(II);
+        return OI.has_value() && *OI == U->getOperandNo();
+      }
       return false;
     };
 
-    while (!ValuesToCheck.empty()) {
-      Value *V = ValuesToCheck.pop_back_val();
-      if (!IsALoadChainInstr(V)) {
-        LLVM_DEBUG(dbgs() << "Need a copy of " << *Arg << " because of " << *V
+    while (!UsesToCheck.empty()) {
+      Use* U = UsesToCheck.pop_back_val();
+      if (!IsSupportedUse(U)) {
+        LLVM_DEBUG(dbgs() << "Need a copy of " << *Arg << " because of " << U
                           << "\n");
         (void)Arg;
         return false;
       }
-      if (!isa<LoadInst>(V))
-        llvm::append_range(ValuesToCheck, V->users());
+      if (!isa<LoadInst>(U)) {
+        for (Use& u : U->getUser()->uses())
+          UsesToCheck.push_back(&u);
+      }
     }
     return true;
   };
@@ -355,7 +420,7 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
         Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
         FirstInst);
     for (Value *V : UsersToUpdate)
-      convertToParamAS(V, ArgInParamAS);
+      convertToParamAS(V, Arg, ArgInParamAS);
     LLVM_DEBUG(dbgs() << "No need to copy " << *Arg << "\n");
 
     const auto *TLI =


⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff 3254f31a66263ea9647c9547f1531c3123444fcd 48328b66827d33454f3a40e78adaaa16a8654612 -- llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
View the diff from clang-format here.
diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
index 06eb2ba848..d34b417d79 100644
--- a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
@@ -154,17 +154,27 @@ INITIALIZE_PASS_END(NVPTXLowerArgs, "nvptx-lower-args",
 static std::optional<int> tmaDescriptorOperandIndex(Instruction *I) {
   if (auto *II = dyn_cast<IntrinsicInst>(I)) {
     switch (II->getIntrinsicID()) {
-    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_1d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
-    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_2d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
-    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_3d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
-    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_4d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
-    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_5d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
+    case llvm::Intrinsic::
+        nvvm_cp_async_bulk_tensor_1d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
+    case llvm::Intrinsic::
+        nvvm_cp_async_bulk_tensor_2d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
+    case llvm::Intrinsic::
+        nvvm_cp_async_bulk_tensor_3d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
+    case llvm::Intrinsic::
+        nvvm_cp_async_bulk_tensor_4d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
+    case llvm::Intrinsic::
+        nvvm_cp_async_bulk_tensor_5d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
       return 1;
-    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_1d_global_shared_cta_tile_bulk_group:
-    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_2d_global_shared_cta_tile_bulk_group:
-    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_3d_global_shared_cta_tile_bulk_group:
-    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_4d_global_shared_cta_tile_bulk_group:
-    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_5d_global_shared_cta_tile_bulk_group:
+    case llvm::Intrinsic::
+        nvvm_cp_async_bulk_tensor_1d_global_shared_cta_tile_bulk_group:
+    case llvm::Intrinsic::
+        nvvm_cp_async_bulk_tensor_2d_global_shared_cta_tile_bulk_group:
+    case llvm::Intrinsic::
+        nvvm_cp_async_bulk_tensor_3d_global_shared_cta_tile_bulk_group:
+    case llvm::Intrinsic::
+        nvvm_cp_async_bulk_tensor_4d_global_shared_cta_tile_bulk_group:
+    case llvm::Intrinsic::
+        nvvm_cp_async_bulk_tensor_5d_global_shared_cta_tile_bulk_group:
       return 0;
     default:
       return std::nullopt;
@@ -376,8 +386,8 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
   assert(StructType && "Missing byval type");
 
   auto IsALoadChain = [&](Value *Start) {
-    SmallVector<Use*, 16> UsesToCheck;
-    for (Use& u : Start->uses())
+    SmallVector<Use *, 16> UsesToCheck;
+    for (Use &u : Start->uses())
       UsesToCheck.push_back(&u);
     auto IsSupportedUse = [](Use *U) -> bool {
       Value *V = U->get();
@@ -397,7 +407,7 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
     };
 
     while (!UsesToCheck.empty()) {
-      Use* U = UsesToCheck.pop_back_val();
+      Use *U = UsesToCheck.pop_back_val();
       if (!IsSupportedUse(U)) {
         LLVM_DEBUG(dbgs() << "Need a copy of " << *Arg << " because of " << U
                           << "\n");
@@ -405,7 +415,7 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
         return false;
       }
       if (!isa<LoadInst>(U)) {
-        for (Use& u : U->getUser()->uses())
+        for (Use &u : U->getUser()->uses())
           UsesToCheck.push_back(&u);
       }
     }

@joker-eph
Collaborator

joker-eph commented Jun 13, 2024

It seems to me like this is partial support for grid_constant, but instead of applying generally to any byval parameter, the support is hard-coded for a special case (TMA).
The general solution is likely not out of reach: something like applying the lowering logic in NVPTXLowerArgs.cpp to any byval struct identified with the right property (as documented in the NVVM IR doc here: https://docs.nvidia.com/cuda/nvvm-ir-spec/#supported-properties).

(MLIR already has the ability to emit the right LLVM IR for grid_constant TMA descriptors; there are even tests in the repo for this. We just need to match the property now.)
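
For reference, this is roughly how that property is encoded, going by the NVVM IR spec (a sketch only; the kernel and types here are made up, not from this PR or the MLIR tests):

```llvm
; Illustrative only: grid_constant is recorded in !nvvm.annotations as a node
; holding the 1-based indices of the annotated parameters (per the NVVM IR spec).
%CUtensorMap = type { [128 x i8] }

define void @kernel(ptr byval(%CUtensorMap) align 64 %desc) {
  ret void
}

!nvvm.annotations = !{!0, !1}
!0 = !{ptr @kernel, !"kernel", i32 1}
!1 = !{ptr @kernel, !"grid_constant", !2}
!2 = !{i32 1} ; parameter #1 (%desc) is a grid constant
```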

@apaszke
Member Author

apaszke commented Jun 13, 2024

grid_constant is a new custom NVIDIA extension, and in my view you're welcome to upstream it as well, but I don't see why that should block a PR that improves the current implementation, especially since this is the only way to sensibly pass TMA descriptors as kernel args, which is a pattern you, as a company, recommend. Happy to close this if you think the upstreaming will happen soon, but it's blocking for us and I can't wait indefinitely.

@justinfargnoli
Contributor

I'd welcome pointers to where I should add those.

llvm/test/CodeGen/NVPTX/*

e.g. [NVPTX] Infer AS of pointers passed to kernels as integers.
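
A hypothetical skeleton for such a test (it would live under llvm/test/CodeGen/NVPTX/; the RUN line, overload suffix, and CHECK patterns below are my guesses, not an existing test):

```llvm
; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 -mattr=+ptx80 | FileCheck %s

declare void @llvm.nvvm.cp.async.bulk.tensor.1d.shared_cluster.global.tile.mbarrier_complete_tx_bytes.p0.p3(
    ptr addrspace(3), ptr, i32, ptr addrspace(3))

; CHECK-LABEL: test_tma_load_1d
define ptx_kernel void @test_tma_load_1d(ptr %desc, ptr addrspace(3) %dst,
                                         ptr addrspace(3) %mbar, i32 %x) {
  ; CHECK: cp.async.bulk.tensor.1d.shared::cluster.global.tile.mbarrier::complete_tx::bytes
  call void @llvm.nvvm.cp.async.bulk.tensor.1d.shared_cluster.global.tile.mbarrier_complete_tx_bytes.p0.p3(
      ptr addrspace(3) %dst, ptr %desc, i32 %x, ptr addrspace(3) %mbar)
  ret void
}
```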

@joker-eph
Collaborator

I apologize if my comment looked like a strong objection here; that wasn't the intent. Also, I'm used to commenting on LLVM as a random contributor, and my current employment has no bearing on my comment (in particular, it shouldn't carry more weight in the review here!). I'm not in a position to approve or reject this PR.

@jlebar and @Artem-B should evaluate what's reasonable for NVPTX in LLVM, the support of the NVVM IR spec, etc.

I just shared what I know about the NVVM IR spec, and squinting here, matching the metadata for kernel parameters and applying the same treatment you're already doing didn't seem out of reach (and would connect to the current MLIR emission of the general NVVM IR grid_constant metadata here: https://github.com/llvm/llvm-project/blob/main/mlir/test/Target/LLVMIR/nvvmir.mlir#L560-L575, which was added specifically for TMA as well).

@jlebar
Member

jlebar commented Jun 13, 2024

@joker-eph would you be willing to check internally whether NVIDIA has code they're willing to upstream that's more in line with what you're looking for? It sounds like everyone would prefer that if we had it.

akshayrdeodhar added a commit to akshayrdeodhar/llvm-project that referenced this pull request Jun 20, 2024
- Adds a helper function for checking whether an argument is a grid_constant.
- Adds support for cvta.param using changes from llvm#95289
- Supports escaped grid_constant pointers conservatively, by casting all uses to the generic address space with cvta.param.
akshayrdeodhar added a commit that referenced this pull request Jun 24, 2024
- Adds a helper function for checking whether an argument is a
[grid_constant](https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#supported-properties).
- Adds support for cvta.param using changes from
#95289
- Supports escaped grid_constant pointers conservatively, by casting all
uses to the generic address space with cvta.param.
@apaszke
Member Author

apaszke commented Jun 26, 2024

Closing since #96125 is merged now.

@apaszke apaszke closed this Jun 26, 2024
frasercrmck pushed a commit to frasercrmck/llvm that referenced this pull request Jun 27, 2024
- Adds a helper function for checking whether an argument is a
[grid_constant](https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#supported-properties).
- Adds support for cvta.param using changes from
llvm/llvm-project#95289
- Supports escaped grid_constant pointers conservatively, by casting all
uses to the generic address space with cvta.param.
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024
- Adds a helper function for checking whether an argument is a
[grid_constant](https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#supported-properties).
- Adds support for cvta.param using changes from
llvm#95289
- Supports escaped grid_constant pointers conservatively, by casting all
uses to the generic address space with cvta.param.