[NVPTX] Improve kernel byval parameter lowering #136008

Merged

Conversation

@AlexMaclean (Member) commented Apr 16, 2025

This change introduces a new pattern for lowering kernel byval parameters in NVPTXLowerArgs. Each byval argument is wrapped in a call to a new intrinsic, @llvm.nvvm.internal.addrspace.wrap. This intrinsic lowers to no instructions and is removed during operation legalization in SDAG. However, it allows us to change the address space of the arguments to 101, reflecting the fact that they will occupy this space when lowered by LowerFormalArguments in NVPTXISelLowering. Optionally, if a generic pointer to a param is needed, a standard addrspacecast is used. This approach offers several advantages (a minimal IR sketch follows the list below):

  • Exposes addrspace optimizations: By using a standard addrspacecast back to generic space, we allow InferAS to optimize this instruction, potentially sinking it through control flow or optimizing it in other ways that NVPTXLowerArgs does not support. This is demonstrated in several existing tests.
  • Clearer, more consistent semantics: Previously, an addrspacecast from generic to param space was implicitly a no-op. This is problematic because it is not reciprocal with the inverse cast, violating LLVM semantics. Further, it is confusing given the existence of cvta.to.param. After this change, the cast lowers to this instruction.
  • Allows for the removal of all nvvm.ptr.* intrinsics: In a follow-up change, the nvvm.ptr.gen.to.param and nvvm.ptr.param.to.gen intrinsics may be removed.
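
To make the new pattern concrete, here is a minimal before/after sketch for a hypothetical kernel (the kernel and value names are illustrative; the wrapped form mirrors the updated tests in this patch, which still use the intrinsic's original noop.addrspacecast spelling):

Before NVPTXLowerArgs:

  define ptx_kernel void @kernel(ptr byval(i32) align 4 %in, ptr %out) {
    %v = load i32, ptr %in, align 4
    store i32 %v, ptr %out, align 4
    ret void
  }

After NVPTXLowerArgs:

  define ptx_kernel void @kernel(ptr byval(i32) align 4 %in, ptr %out) {
    ; wrap the byval pointer so its actual address space (101, param) is explicit
    %in.param = call ptr addrspace(101) @llvm.nvvm.internal.noop.addrspacecast.p101.p0(ptr %in)
    ; the load is now typed as addrspace(101) and can be selected directly as ld.param
    %v = load i32, ptr addrspace(101) %in.param, align 4
    store i32 %v, ptr %out, align 4
    ret void
  }

If a generic pointer to the param is needed (for example, when the pointer escapes to a callee), the pass instead follows the wrap with a standard addrspacecast back to the generic space, which ISel now selects as cvta.param.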

@llvmbot (Member) commented Apr 16, 2025

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

Changes

This change introduces a new pattern for lowering kernel byval parameters in NVPTXLowerArgs. Each byval argument is wrapped in a call to a new intrinsic, @llvm.nvvm.internal.noop.addrspacecast. This intrinsic lowers to no instructions and is removed during operation legalization in SDAG. However, it allows us to change the address space of the arguments to 101, reflecting the fact that they will occupy this space when lowered by LowerFormalArguments in NVPTXISelLowering. Optionally, if a generic pointer to a param is needed, a standard addrspacecast is used. This approach offers several advantages:

  • Exposes addrspace optimizations: By using a standard addrspacecast back to generic space, we allow InferAS to optimize this instruction, potentially sinking it through control flow or optimizing it in other ways that NVPTXLowerArgs does not support. This is demonstrated in several existing tests.
  • Clearer, more consistent semantics: Previously, an addrspacecast from generic to param space was implicitly a no-op. This is problematic because it is not reciprocal with the inverse cast, violating LLVM semantics. Further, it is confusing given the existence of cvta.to.param. After this change, the cast lowers to this instruction.
  • Allows for the removal of all nvvm.ptr.* intrinsics: In a follow-up change, the nvvm.ptr.gen.to.param and nvvm.ptr.param.to.gen intrinsics may be removed.

Patch is 79.60 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136008.diff

12 Files Affected:

  • (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+15)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (+4-1)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+24)
  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+4-12)
  • (modified) llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp (+43-46)
  • (modified) llvm/lib/Target/NVPTX/NVPTXUtilities.cpp (+24-9)
  • (modified) llvm/lib/Target/NVPTX/NVPTXUtilities.h (+1-1)
  • (modified) llvm/test/CodeGen/NVPTX/bug21465.ll (+1-1)
  • (modified) llvm/test/CodeGen/NVPTX/forward-ld-param.ll (+1-1)
  • (modified) llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll (+107-94)
  • (modified) llvm/test/CodeGen/NVPTX/lower-args.ll (+2-2)
  • (modified) llvm/test/CodeGen/NVPTX/lower-byval-args.ll (+161-137)
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 4aeb1d8a2779e..5d89b0ae2b484 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1878,6 +1878,21 @@ def int_nvvm_ptr_param_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty],
                                    [IntrNoMem, IntrSpeculatable, IntrNoCallback],
                                    "llvm.nvvm.ptr.param.to.gen">;
 
+// Represents an explicit hole in the LLVM IR type system. It may be inserted by
+// the compiler in cases where a pointer is of the wrong type. In the backend
+// this intrinsic will be folded away and not equate to any instruction. It
+// should not be used by any frontend and should only be considered well defined
+// when added in the following cases:
+//
+//  - NVPTXLowerArgs: When wrapping a byval pointer argument to a kernel
+//    function to convert the address space from generic (0) to param (101).
+//    This accounts for the fact that the parameter symbols will occupy this
+//    space when lowered during ISel.
+//
+def int_nvvm_internal_noop_addrspacecast :
+  DefaultAttrsIntrinsic<[llvm_anyptr_ty], [llvm_anyptr_ty],
+                        [IntrNoMem, IntrSpeculatable, NoUndef<ArgIndex<0>>, NoUndef<RetIndex>]>;
+
 // Move intrinsics, used in nvvm internally
 
 def int_nvvm_move_i16 : Intrinsic<[llvm_i16_ty], [llvm_i16_ty], [IntrNoMem],
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index ec1f969494cd1..486c7c815435a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -985,6 +985,9 @@ void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) {
     case ADDRESS_SPACE_LOCAL:
       Opc = TM.is64Bit() ? NVPTX::cvta_local_64 : NVPTX::cvta_local;
       break;
+    case ADDRESS_SPACE_PARAM:
+      Opc = TM.is64Bit() ? NVPTX::cvta_param_64 : NVPTX::cvta_param;
+      break;
     }
     ReplaceNode(N, CurDAG->getMachineNode(Opc, DL, N->getValueType(0), Src));
     return;
@@ -1008,7 +1011,7 @@ void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) {
       Opc = TM.is64Bit() ? NVPTX::cvta_to_local_64 : NVPTX::cvta_to_local;
       break;
     case ADDRESS_SPACE_PARAM:
-      Opc = TM.is64Bit() ? NVPTX::IMOV64r : NVPTX::IMOV32r;
+      Opc = TM.is64Bit() ? NVPTX::cvta_to_param_64 : NVPTX::cvta_to_param;
       break;
     }
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 9bde2a976e164..166785a79ec4c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1017,6 +1017,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
                      {MVT::v2i32, MVT::v4i32, MVT::v8i32, MVT::v16i32,
                       MVT::v32i32, MVT::v64i32, MVT::v128i32},
                      Custom);
+
+  setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
 }
 
 const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
@@ -1434,6 +1436,17 @@ static MachinePointerInfo refinePtrAS(SDValue &Ptr, SelectionDAG &DAG,
 
     return MachinePointerInfo(ADDRESS_SPACE_LOCAL);
   }
+
+  // Peel off an addrspacecast to generic and load directly from the specific
+  // address space.
+  if (Ptr->getOpcode() == ISD::ADDRSPACECAST) {
+    const auto *ASC = cast<AddrSpaceCastSDNode>(Ptr);
+    if (ASC->getDestAddressSpace() == ADDRESS_SPACE_GENERIC) {
+      Ptr = ASC->getOperand(0);
+      return MachinePointerInfo(ASC->getSrcAddressSpace());
+    }
+  }
+
   return MachinePointerInfo();
 }
 
@@ -2754,6 +2767,15 @@ static SDValue LowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG) {
   return Op;
 }
 
+static SDValue lowerIntrinsicWOChain(SDValue Op, SelectionDAG &DAG) {
+  switch (Op->getConstantOperandVal(0)) {
+  default:
+    return Op;
+  case Intrinsic::nvvm_internal_noop_addrspacecast:
+    return Op.getOperand(1);
+  }
+}
+
 // In PTX 64-bit CTLZ and CTPOP are supported, but they return a 32-bit value.
 // Lower these into a node returning the correct type which is zero-extended
 // back to the correct size.
@@ -2863,6 +2885,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerGlobalAddress(Op, DAG);
   case ISD::INTRINSIC_W_CHAIN:
     return Op;
+  case ISD::INTRINSIC_WO_CHAIN:
+    return lowerIntrinsicWOChain(Op, DAG);
   case ISD::INTRINSIC_VOID:
     return LowerIntrinsicVoid(Op, DAG);
   case ISD::BUILD_VECTOR:
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 8528ff702f236..266f379607690 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -2335,18 +2335,10 @@ multiclass G_TO_NG<string Str> {
           "cvta.to." # Str # ".u64 \t$result, $src;", []>;
 }
 
-defm cvta_local  : NG_TO_G<"local">;
-defm cvta_shared : NG_TO_G<"shared">;
-defm cvta_global : NG_TO_G<"global">;
-defm cvta_const  : NG_TO_G<"const">;
-
-defm cvta_to_local  : G_TO_NG<"local">;
-defm cvta_to_shared : G_TO_NG<"shared">;
-defm cvta_to_global : G_TO_NG<"global">;
-defm cvta_to_const  : G_TO_NG<"const">;
-
-// nvvm.ptr.param.to.gen
-defm cvta_param : NG_TO_G<"param">;
+foreach space = ["local", "shared", "global", "const", "param"] in {
+  defm cvta_#space : NG_TO_G<space>;
+  defm cvta_to_#space : G_TO_NG<space>;
+}
 
 def : Pat<(int_nvvm_ptr_param_to_gen i32:$src),
           (cvta_param $src)>;
diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
index 6452fa05947dd..770914fcc2f28 100644
--- a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
@@ -265,18 +265,9 @@ static void convertToParamAS(Use *OldUse, Value *Param, bool HasCvtaParam,
     if (HasCvtaParam) {
       auto GetParamAddrCastToGeneric =
           [](Value *Addr, Instruction *OriginalUser) -> Value * {
-        PointerType *ReturnTy =
-            PointerType::get(OriginalUser->getContext(), ADDRESS_SPACE_GENERIC);
-        Function *CvtToGen = Intrinsic::getOrInsertDeclaration(
-            OriginalUser->getModule(), Intrinsic::nvvm_ptr_param_to_gen,
-            {ReturnTy, PointerType::get(OriginalUser->getContext(),
-                                        ADDRESS_SPACE_PARAM)});
-
-        // Cast param address to generic address space
-        Value *CvtToGenCall =
-            CallInst::Create(CvtToGen, Addr, Addr->getName() + ".gen",
-                             OriginalUser->getIterator());
-        return CvtToGenCall;
+        IRBuilder<> IRB(OriginalUser);
+        Type *GenTy = IRB.getPtrTy(ADDRESS_SPACE_GENERIC);
+        return IRB.CreateAddrSpaceCast(Addr, GenTy, Addr->getName() + ".gen");
       };
       auto *ParamInGenericAS =
           GetParamAddrCastToGeneric(I.NewParam, I.OldInstruction);
@@ -515,23 +506,24 @@ void copyByValParam(Function &F, Argument &Arg) {
   BasicBlock::iterator FirstInst = F.getEntryBlock().begin();
   Type *StructType = Arg.getParamByValType();
   const DataLayout &DL = F.getDataLayout();
-  AllocaInst *AllocA = new AllocaInst(StructType, DL.getAllocaAddrSpace(),
-                                      Arg.getName(), FirstInst);
+  IRBuilder<> IRB(&*FirstInst);
+  AllocaInst *AllocA = IRB.CreateAlloca(StructType, nullptr, Arg.getName());
   // Set the alignment to alignment of the byval parameter. This is because,
   // later load/stores assume that alignment, and we are going to replace
   // the use of the byval parameter with this alloca instruction.
-  AllocA->setAlignment(F.getParamAlign(Arg.getArgNo())
-                           .value_or(DL.getPrefTypeAlign(StructType)));
+  AllocA->setAlignment(
+      Arg.getParamAlign().value_or(DL.getPrefTypeAlign(StructType)));
   Arg.replaceAllUsesWith(AllocA);
 
-  Value *ArgInParam = new AddrSpaceCastInst(
-      &Arg, PointerType::get(Arg.getContext(), ADDRESS_SPACE_PARAM),
-      Arg.getName(), FirstInst);
+  Value *ArgInParam =
+      IRB.CreateIntrinsic(Intrinsic::nvvm_internal_noop_addrspacecast,
+                          {IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg.getType()},
+                          &Arg, {}, Arg.getName());
+
   // Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
   // addrspacecast preserves alignment.  Since params are constant, this load
   // is definitely not volatile.
   const auto ArgSize = *AllocA->getAllocationSize(DL);
-  IRBuilder<> IRB(&*FirstInst);
   IRB.CreateMemCpy(AllocA, AllocA->getAlign(), ArgInParam, AllocA->getAlign(),
                    ArgSize);
 }
@@ -539,9 +531,9 @@ void copyByValParam(Function &F, Argument &Arg) {
 
 static void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg) {
   Function *Func = Arg->getParent();
-  bool HasCvtaParam =
-      TM.getSubtargetImpl(*Func)->hasCvtaParam() && isKernelFunction(*Func);
-  bool IsGridConstant = HasCvtaParam && isParamGridConstant(*Arg);
+  assert(isKernelFunction(*Func));
+  const bool HasCvtaParam = TM.getSubtargetImpl(*Func)->hasCvtaParam();
+  const bool IsGridConstant = HasCvtaParam && isParamGridConstant(*Arg);
   const DataLayout &DL = Func->getDataLayout();
   BasicBlock::iterator FirstInst = Func->getEntryBlock().begin();
   Type *StructType = Arg->getParamByValType();
@@ -558,9 +550,11 @@ static void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg) {
     for (Use &U : Arg->uses())
       UsesToUpdate.push_back(&U);
 
-    Value *ArgInParamAS = new AddrSpaceCastInst(
-        Arg, PointerType::get(StructType->getContext(), ADDRESS_SPACE_PARAM),
-        Arg->getName(), FirstInst);
+    IRBuilder<> IRB(&*FirstInst);
+    Value *ArgInParamAS = IRB.CreateIntrinsic(
+        Intrinsic::nvvm_internal_noop_addrspacecast,
+        {IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getType()}, {Arg});
+
     for (Use *U : UsesToUpdate)
       convertToParamAS(U, ArgInParamAS, HasCvtaParam, IsGridConstant);
     LLVM_DEBUG(dbgs() << "No need to copy or cast " << *Arg << "\n");
@@ -578,30 +572,31 @@ static void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg) {
   // However, we're still not allowed to write to it. If the user specified
   // `__grid_constant__` for the argument, we'll consider escaped pointer as
   // read-only.
-  if (HasCvtaParam && (ArgUseIsReadOnly || IsGridConstant)) {
+  if (IsGridConstant || (HasCvtaParam && ArgUseIsReadOnly)) {
     LLVM_DEBUG(dbgs() << "Using non-copy pointer to " << *Arg << "\n");
     // Replace all argument pointer uses (which might include a device function
     // call) with a cast to the generic address space using cvta.param
     // instruction, which avoids a local copy.
     IRBuilder<> IRB(&Func->getEntryBlock().front());
 
-    // Cast argument to param address space
-    auto *CastToParam = cast<AddrSpaceCastInst>(IRB.CreateAddrSpaceCast(
-        Arg, IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getName() + ".param"));
+    // Cast argument to param address space. Because the backend will emit the
+    // argument already in the param address space, we need to use the noop
+    // intrinsic. This has the added benefit of preventing other optimizations
+    // from folding away this pair of addrspacecasts.
+    auto *ParamSpaceArg =
+        IRB.CreateIntrinsic(Intrinsic::nvvm_internal_noop_addrspacecast,
+                            {IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getType()},
+                            Arg, {}, Arg->getName() + ".param");
 
-    // Cast param address to generic address space. We do not use an
-    // addrspacecast to generic here, because, LLVM considers `Arg` to be in the
-    // generic address space, and a `generic -> param` cast followed by a `param
-    // -> generic` cast will be folded away. The `param -> generic` intrinsic
-    // will be correctly lowered to `cvta.param`.
-    Value *CvtToGenCall = IRB.CreateIntrinsic(
-        IRB.getPtrTy(ADDRESS_SPACE_GENERIC), Intrinsic::nvvm_ptr_param_to_gen,
-        CastToParam, nullptr, CastToParam->getName() + ".gen");
+    // Cast param address to generic address space.
+    Value *GenericArg = IRB.CreateAddrSpaceCast(
+        ParamSpaceArg, IRB.getPtrTy(ADDRESS_SPACE_GENERIC),
+        Arg->getName() + ".gen");
 
-    Arg->replaceAllUsesWith(CvtToGenCall);
+    Arg->replaceAllUsesWith(GenericArg);
 
     // Do not replace Arg in the cast to param space
-    CastToParam->setOperand(0, Arg);
+    ParamSpaceArg->setOperand(0, Arg);
   } else
     copyByValParam(*Func, *Arg);
 }
@@ -715,12 +710,14 @@ static bool copyFunctionByValArgs(Function &F) {
   LLVM_DEBUG(dbgs() << "Creating a copy of byval args of " << F.getName()
                     << "\n");
   bool Changed = false;
-  for (Argument &Arg : F.args())
-    if (Arg.getType()->isPointerTy() && Arg.hasByValAttr() &&
-        !(isParamGridConstant(Arg) && isKernelFunction(F))) {
-      copyByValParam(F, Arg);
-      Changed = true;
-    }
+  if (isKernelFunction(F)) {
+    for (Argument &Arg : F.args())
+      if (Arg.getType()->isPointerTy() && Arg.hasByValAttr() &&
+          !isParamGridConstant(Arg)) {
+        copyByValParam(F, Arg);
+        Changed = true;
+      }
+  }
   return Changed;
 }
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
index 3d9d2ae372080..0cbebc6995c9a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -16,11 +16,13 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/IR/Argument.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Support/Alignment.h"
+#include "llvm/Support/ModRef.h"
 #include "llvm/Support/Mutex.h"
 #include <cstdint>
 #include <cstring>
@@ -228,17 +230,30 @@ static std::optional<uint64_t> getVectorProduct(ArrayRef<unsigned> V) {
   return std::accumulate(V.begin(), V.end(), 1, std::multiplies<uint64_t>{});
 }
 
-bool isParamGridConstant(const Value &V) {
-  if (const Argument *Arg = dyn_cast<Argument>(&V)) {
-    // "grid_constant" counts argument indices starting from 1
-    if (Arg->hasByValAttr() &&
-        argHasNVVMAnnotation(*Arg, "grid_constant",
-                             /*StartArgIndexAtOne*/ true)) {
-      assert(isKernelFunction(*Arg->getParent()) &&
-             "only kernel arguments can be grid_constant");
+bool isParamGridConstant(const Argument &Arg) {
+  assert(isKernelFunction(*Arg.getParent()) &&
+         "only kernel arguments can be grid_constant");
+
+  if (!Arg.hasByValAttr())
+    return false;
+
+  // Lowering an argument as a grid_constant violates the byval semantics (and
+  // the C++ API) by reusing the same memory location for the argument across
+  // multiple threads. If the memory behind an argument is only read and its
+  // address is not captured (its address is not compared with any value), then
+  // the tweak of the C++ API and byval semantics is unobservable by the program
+  // and we can lower the arg as a grid_constant.
+  if (Arg.onlyReadsMemory()) {
+    const auto CI = Arg.getAttributes().getCaptureInfo();
+    if (!capturesAddress(CI) && !capturesFullProvenance(CI))
       return true;
-    }
   }
+
+  // "grid_constant" counts argument indices starting from 1
+  if (argHasNVVMAnnotation(Arg, "grid_constant",
+                           /*StartArgIndexAtOne*/ true))
+    return true;
+
   return false;
 }
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.h b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
index 9283b398a9c14..9adbb645deb4a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.h
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
@@ -63,7 +63,7 @@ inline bool isKernelFunction(const Function &F) {
   return F.getCallingConv() == CallingConv::PTX_Kernel;
 }
 
-bool isParamGridConstant(const Value &);
+bool isParamGridConstant(const Argument &);
 
 inline MaybeAlign getAlign(const Function &F, unsigned Index) {
   return F.getAttributes().getAttributes(Index).getStackAlignment();
diff --git a/llvm/test/CodeGen/NVPTX/bug21465.ll b/llvm/test/CodeGen/NVPTX/bug21465.ll
index 76300e3cfdc5b..21ec05c70e3ad 100644
--- a/llvm/test/CodeGen/NVPTX/bug21465.ll
+++ b/llvm/test/CodeGen/NVPTX/bug21465.ll
@@ -12,7 +12,7 @@ define ptx_kernel void @_Z11TakesStruct1SPi(ptr byval(%struct.S) nocapture reado
 entry:
 ; CHECK-LABEL: @_Z11TakesStruct1SPi
 ; PTX-LABEL: .visible .entry _Z11TakesStruct1SPi(
-; CHECK: addrspacecast ptr %input to ptr addrspace(101)
+; CHECK: call ptr addrspace(101) @llvm.nvvm.internal.noop.addrspacecast.p101.p0(ptr %input)
   %b = getelementptr inbounds %struct.S, ptr %input, i64 0, i32 1
   %0 = load i32, ptr %b, align 4
 ; PTX-NOT: ld.param.u32 {{%r[0-9]+}}, [{{%rd[0-9]+}}]
diff --git a/llvm/test/CodeGen/NVPTX/forward-ld-param.ll b/llvm/test/CodeGen/NVPTX/forward-ld-param.ll
index 6d9710e6d2272..80ae8aac39115 100644
--- a/llvm/test/CodeGen/NVPTX/forward-ld-param.ll
+++ b/llvm/test/CodeGen/NVPTX/forward-ld-param.ll
@@ -65,7 +65,7 @@ define void @test_ld_param_byval(ptr byval(i32) %a) {
 ; CHECK-LABEL: test_ld_param_byval(
 ; CHECK:       {
 ; CHECK-NEXT:    .reg .b32 %r<2>;
-; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ld.param.u32 %r1, [test_ld_param_byval_param_0];
diff --git a/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
index 836a7d78a0cc5..46535a7a91c28 100644
--- a/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
@@ -72,7 +72,7 @@ define ptx_kernel void @grid_const_int(ptr byval(i32) align 4 %input1, i32 %inpu
 ; PTX-NEXT:    ret;
 ; OPT-LABEL: define ptx_kernel void @grid_const_int(
 ; OPT-SAME: ptr byval(i32) align 4 [[INPUT1:%.*]], i32 [[INPUT2:%.*]], ptr [[OUT:%.*]], i32 [[N:%.*]]) #[[ATTR0]] {
-; OPT-NEXT:    [[INPUT11:%.*]] = addrspacecast ptr [[INPUT1]] to ptr addrspace(101)
+; OPT-NEXT:    [[INPUT11:%.*]] = call ptr addrspace(101) @llvm.nvvm.internal.noop.addrspacecast.p101.p0(ptr [[INPUT1]])
 ; OPT-NEXT:    [[TMP:%.*]] = load i32, ptr addrspace(101) [[INPUT11]], align 4
 ; OPT-NEXT:    [[ADD:%.*]] = add i32 [[TMP]], [[INPUT2]]
 ; OPT-NEXT:    store i32 [[ADD]], ptr [[OUT]], align 4
@@ -101,7 +101,7 @@ define ptx_kernel void @grid_const_struct(ptr byval(%struct.s) align 4 %input, p
 ; PTX-NEXT:    ret;
 ; OPT-LABEL: define ptx_kernel void @grid_const_struct(
 ; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 [[INPUT:%.*]], ptr [[OUT:%.*]]) #[[ATTR0]] {
-; OPT-NEXT:    [[INPUT1:%.*]] = addrspacecast ptr [[INPUT]] to ptr addrspace(101)
+; OPT-NEXT:    [[INPUT1:%.*]] = call ptr addrspace(101) @llvm.nvvm.internal.noop.addrspacecast.p101.p0(ptr [[INPUT]])
 ; OPT-NEXT:    [[GEP13:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr addrspace(101) [[INPUT1]], i32 0, i32 0
 ; OPT-NEXT:    [[GEP22:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr addrspace(101) [[INPUT1]], i32 0, i32 1
 ; OPT-NEXT:    [[TMP1:%.*]] = load i32, ptr addrspace(101) [[GEP13]], align 4
@@ -122,16 +122,15 @@ define ptx_kernel void @grid_const_escape(ptr byval(%struct.s) align 4 %input) {
 ; PTX-LABEL: grid_const_escape(
 ; PTX:       {
 ; PTX-NEXT:    .reg .b32 %r<3>;
-; PTX-NEXT:    .reg .b64 %rd<5>;
+; PTX-NEXT:    .reg .b64 %rd<4>;
 ; PTX-EMPTY:
 ; PTX-NEXT:  // %bb.0:
 ; PTX-NEXT:    mov.b64 %rd2, grid_const_escape_param_0;
-; PTX-NEXT:    mov.b64 %rd3, %rd2;
-; PTX-NEXT:    cvta.param.u64 %rd4, %rd3;
+; PTX-NEXT:    cvta.param.u64 %rd3, %rd2;
 ; PTX-NEXT:    mov.b64 %rd1, escape;
 ; PTX-NEXT:    { // callseq 0, 0
 ; PTX-NEXT:    .param .b64 param0;
-; PTX-NEXT:    st.param.b64 [param0], %rd4;
+; PTX-NEXT:    st.param.b64 [param0], %rd3;
 ; PTX-NEXT:    .param .b32 retval0;
 ; PTX-NEXT:    prototype_0 : .callprototype (.param .b32 _) _ (.param .b64 _);
 ; PTX-NEXT:    call (retval0),
@@ -145,8 +144,8 @@ define ptx_kernel void @grid_const_escape(ptr byval(%struct.s) align 4 %input) {
 ; PTX-NEXT:    ret;
 ; OPT-LABEL: define ptx_kernel void @grid_const_escape(
 ; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 [[INPUT:%.*]]) #[[ATTR0]] {
-; OPT-NEXT:    [[INPUT_PARAM:%.*]] = addrspacecast ptr [[INPUT]] to ptr addrspace(101)
-; OPT-NEXT:    [[INPUT_PARAM_GEN:%.*]] = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) [[INPUT_PARAM]])
...
[truncated]

@llvmbot (Member) commented Apr 16, 2025

@llvm/pr-subscribers-llvm-ir


@AlexMaclean force-pushed the dev/amaclean/upstream/better-gc-params branch from 0206b0e to 106b0a0 on April 16, 2025 at 20:52
@Artem-B (Member) left a comment

Nice improvement, overall.

@@ -12,7 +12,7 @@ define ptx_kernel void @_Z11TakesStruct1SPi(ptr byval(%struct.S) nocapture reado
entry:
; CHECK-LABEL: @_Z11TakesStruct1SPi
; PTX-LABEL: .visible .entry _Z11TakesStruct1SPi(
; CHECK: addrspacecast ptr %input to ptr addrspace(101)
; CHECK: call ptr addrspace(101) @llvm.nvvm.internal.noop.addrspacecast.p101.p0(ptr %input)
Member

Naming nit: I think noop.addrspacecast does not quite match what we want to do here. It's not that we're actually casting the pointer. It's more that we're letting LLVM know the actual AS of the argument which happens to be represented by a generic pointer.

So, technically it does look like a no-op cast (changes AS, does not actually do anything). I think "wrapped" that you used in the description would be a better name here. How about llvm.nvvm.internal.addrspace.wrap ?

Member Author

Sure, I don't have very strong feelings on the name either way. I've updated to use this name.
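
For reference, with the rename applied (and assuming the overloaded pointer-type mangling stays the same), the wrapped form from the tests would read:

  %input11 = call ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr %input1)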

Comment on lines 145 to 151
; SM_60-NEXT: [[S1:%.*]] = alloca [[STRUCT_S]], align 4
; SM_60-NEXT: [[S2:%.*]] = call ptr addrspace(101) @llvm.nvvm.internal.noop.addrspacecast.p101.p0(ptr [[S]])
; SM_60-NEXT: call void @llvm.memcpy.p0.p101.i64(ptr align 4 [[S1]], ptr addrspace(101) align 4 [[S2]], i64 8, i1 false)
; SM_60-NEXT: [[B:%.*]] = getelementptr inbounds nuw i8, ptr [[S1]], i64 4
; SM_60-NEXT: [[ASC:%.*]] = addrspacecast ptr [[B]] to ptr addrspace(101)
; SM_60-NEXT: [[ASC0:%.*]] = addrspacecast ptr addrspace(101) [[ASC]] to ptr
; SM_60-NEXT: [[I:%.*]] = load i32, ptr [[ASC0]], align 4
Member

We're doing something really wrong here:

; SM_60-NEXT:    [[S1:%.*]] = alloca [[STRUCT_S]], align 4
...
; SM_60-NEXT:    [[B:%.*]] = getelementptr inbounds nuw i8, ptr [[S1]], i64 4
; SM_60-NEXT:    [[ASC:%.*]] = addrspacecast ptr [[B]] to ptr addrspace(101)
; SM_60-NEXT:    [[ASC0:%.*]] = addrspacecast ptr addrspace(101) [[ASC]] to ptr

S1 is an alloca in the local AS, yet we're still casting it to param AS. The only reason we get away with that is that we cast it back to generic right away, and/or that the alloca + memcpy get eliminated so we do end up pointing to the original location in the param space. But the IR as captured by the test is wrong.

Granted, it's been wrong before your patch, too, but this looks rather scary. If something ends up holding onto ASC and we end up trying to access it via ld.param, we'd be in trouble.

Member Author

I'm not sure exactly what is going on here. The problematic addrspacecast instructions are in the source of this test; they aren't being inserted by NVPTXLowerArgs. I'd lean towards removing the casts in this test, as I don't think there is a situation in which we'd expect to see them coming into NVPTXLowerArgs. Do you remember anything about your motivation for this test in d0615a9?

@Artem-B (Member) commented Apr 16, 2025

The broken test was indeed introduced by me, but quite a bit more recently, in b7b28e7.

The test itself does appear to be OK on the surface, but things break down when nvptx-lower-args inserts a copy into an alloca, which breaks the original IR's assumption that it accesses a byval argument.

The question is whether the test was correct to assume that there will be no local copy.
Considering that we are only reading from the byval argument and the test is named read_only_gep_asc0, it seems to be reasonable. So why did we end up with an alloca then?

Edit: graphene somehow posted a stale version of my comment. Edited to update it.

Member Author

The question is whether the test was correct to assume that there will be no local copy.
Considering that we are only reading from the byval argument and the test is named read_only_gep_asc0, it seems to be reasonable. So why did we end up with an alloca then?

Choosing to lower a parameter that hasn't been explicitly marked as "grid_constant" as a grid-constant seems like an optimization decision NVPTXLowerArgs may freely make based on its own set of heuristics. I think it is not a good idea for any other pass or front-end to make assumptions about whether or not this will happen (fortunately, as far as I know, there are no cases of this today). In this case, the addrspacecast instructions are extra dangerous because cvta.param isn't supported on sm_60, so we're reliant on other optimizations to remove them in order to compile correctly.

In this specific test it looks like we're not treating this pointer as grid-constant because we treat the cast back to generic as an escape:

if (ASC.getDestAddressSpace() != ADDRESS_SPACE_PARAM)
  return PI.setEscapedAndAborted(&ASC);
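
Restated against the quoted test IR, the pair looks roughly like this (names taken from the SM_60 check lines above); the second cast is the one the check flags as an escape:

  %asc  = addrspacecast ptr %b to ptr addrspace(101)    ; generic -> param: accepted by the check
  %asc0 = addrspacecast ptr addrspace(101) %asc to ptr  ; param -> generic: hits setEscapedAndAborted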

It might make sense to allow this case as well. But I don't know if I fully understand all the ramifications or real world motivation for that so I'd prefer to leave it to the side for now, as this patch is already somewhat large.

Member

Thank you for double checking the reason for us still making a copy.

I agree that the test is somewhat questionable: the IR should not have been casting the argument pointer to param AS and back to generic, and it may indeed be broken in each of these directions.

In this case, the addrspacecast instructions are extra dangerous because cvta.param isn't supported on sm_60

Correct. That test indeed works only because both the alloca and the ASCs got eliminated, since the code is trivial.
I guess that was partially the idea for the test: that byval lowering can see through additional ASCs. I goofed up by adding a supposedly no-op, but actually invalid, ASC to generic and back, and got lucky that it just happened to give me the result I expected, so I did not pay attention to what happened in between the test IR and the PTX.

For now we can remove or comment out this test.

It might make sense to allow this case as well.

Yup. We still do want the read-only access checks to be able to see through ASCs. The code you've pointed to was basically intended to allow generic->param ASCs added by the pass inferring address spaces, and we basically considered everything else to be an escape. Now that ASCs from param to generic do exist in some cases, considering them acceptable during read-only-ness tracking would make sense, too. Whether you want to incorporate it into this patch or leave it out is up to you.

Member Author

Sounds good. I've commented out the test for now and added a TODO to indicate we should follow up and figure out how to improve our support for addrspacecast within this pass.

@AlexMaclean force-pushed the dev/amaclean/upstream/better-gc-params branch from 445a104 to de0c5aa on April 17, 2025 at 22:25
@akshayrdeodhar (Contributor) left a comment

LGTM

@justinfargnoli (Contributor) left a comment

Discussed offline. Approved.

@AlexMaclean AlexMaclean merged commit 56910a8 into llvm:main Apr 21, 2025
11 checks passed
AlexMaclean added a commit that referenced this pull request Apr 25, 2025
After #136008 these intrinsics are no longer inserted by the
compiler and can be upgraded to addrspacecasts.