[NVPTX] Basic support for "grid_constant" #96125

Merged

Conversation

akshayrdeodhar
Contributor

@akshayrdeodhar akshayrdeodhar commented Jun 20, 2024

- Adds a helper function for checking whether an argument is a grid_constant.
- Adds support for cvta.param using changes from llvm#95289
- Supports escaped grid_constant pointers conservatively, by casting all uses to the generic address space with cvta.param (see the IR sketch below).
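
For context, here is a minimal LLVM IR sketch of the annotation this pass keys on and of the conservative lowering it performs. The kernel, value names, and metadata numbering below are illustrative rather than taken from the patch; per the NVVM IR spec, grid_constant annotations list argument indices starting from 1.

    ; Kernel with a byval argument marked grid_constant via nvvm.annotations.
    define void @kernel(ptr byval(i32) align 4 %in, ptr %out) {
      %v = load i32, ptr %in, align 4
      store i32 %v, ptr %out
      ret void
    }

    !nvvm.annotations = !{!0, !1}
    !0 = !{ptr @kernel, !"kernel", i32 1}
    ; grid_constant takes a node of argument indices, counted from 1.
    !1 = !{ptr @kernel, !"grid_constant", !2}
    !2 = !{i32 1}

With this patch, instead of copying %in into a local alloca, NVPTXLowerArgs rewrites uses of %in roughly as follows (this is the pattern the new test checks for):

    %in.param = addrspacecast ptr %in to ptr addrspace(101)
    %in.param.gen = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) %in.param)
    ; all former uses of %in now go through %in.param.gen

At the CUDA source level this corresponds to the __grid_constant__ kernel-parameter qualifier, which front ends lower to this annotation.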
@llvmbot
Member

llvmbot commented Jun 20, 2024

@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-backend-nvptx

Author: Akshay Deodhar (akshayrdeodhar)

Changes
  • Adds a helper function for checking whether an argument is a grid_constant.
  • Adds support for cvta.param using changes from [NVPTX] Add NVPTX intrinsics for TMA copies #95289
  • Supports escaped grid_constant pointers conservatively, by casting all uses to the generic address space with cvta.param.

Patch is 22.20 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/96125.diff

6 Files Affected:

  • (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+5)
  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+1)
  • (modified) llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp (+52-21)
  • (modified) llvm/lib/Target/NVPTX/NVPTXUtilities.cpp (+78-65)
  • (modified) llvm/lib/Target/NVPTX/NVPTXUtilities.h (+1)
  • (added) llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll (+155)
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 0a9139e0062ba..b7c828566e375 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1596,6 +1596,11 @@ def int_nvvm_ptr_gen_to_param: Intrinsic<[llvm_anyptr_ty],
                                    [IntrNoMem, IntrSpeculatable, IntrNoCallback],
                                    "llvm.nvvm.ptr.gen.to.param">;
 
+// sm70+, PTX7.7+
+def int_nvvm_ptr_param_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty],
+                                     [llvm_anyptr_ty],
+                                   [IntrNoMem, IntrSpeculatable, IntrNoCallback]>;
+
 // Move intrinsics, used in nvvm internally
 
 def int_nvvm_move_i16 : Intrinsic<[llvm_i16_ty], [llvm_i16_ty], [IntrNoMem],
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index a65170e56aa24..3e7f8d63439c8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -2475,6 +2475,7 @@ defm cvta_local  : NG_TO_G<"local", int_nvvm_ptr_local_to_gen, useShortPtrLocal>
 defm cvta_shared : NG_TO_G<"shared", int_nvvm_ptr_shared_to_gen, useShortPtrShared>;
 defm cvta_global : NG_TO_G<"global", int_nvvm_ptr_global_to_gen, False>;
 defm cvta_const  : NG_TO_G<"const", int_nvvm_ptr_constant_to_gen, useShortPtrConst>;
+defm cvta_param : NG_TO_G<"param", int_nvvm_ptr_param_to_gen, False>;
 
 defm cvta_to_local  : G_TO_NG<"local", int_nvvm_ptr_gen_to_local, useShortPtrLocal>;
 defm cvta_to_shared : G_TO_NG<"shared", int_nvvm_ptr_gen_to_shared, useShortPtrShared>;
diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
index cde02c25c4834..1116b8e6313f7 100644
--- a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
@@ -95,7 +95,9 @@
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/CodeGen/TargetPassConfig.h"
 #include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicsNVPTX.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Type.h"
 #include "llvm/InitializePasses.h"
@@ -336,8 +338,9 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
     while (!ValuesToCheck.empty()) {
       Value *V = ValuesToCheck.pop_back_val();
       if (!IsALoadChainInstr(V)) {
-        LLVM_DEBUG(dbgs() << "Need a copy of " << *Arg << " because of " << *V
-                          << "\n");
+        LLVM_DEBUG(dbgs() << "Need a "
+                          << (isParamGridConstant(*Arg) ? "cast " : "copy ")
+                          << "of " << *Arg << " because of " << *V << "\n");
         (void)Arg;
         return false;
       }
@@ -366,27 +369,55 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
     return;
   }
 
-  // Otherwise we have to create a temporary copy.
   const DataLayout &DL = Func->getParent()->getDataLayout();
   unsigned AS = DL.getAllocaAddrSpace();
-  AllocaInst *AllocA = new AllocaInst(StructType, AS, Arg->getName(), FirstInst);
-  // Set the alignment to alignment of the byval parameter. This is because,
-  // later load/stores assume that alignment, and we are going to replace
-  // the use of the byval parameter with this alloca instruction.
-  AllocA->setAlignment(Func->getParamAlign(Arg->getArgNo())
-                           .value_or(DL.getPrefTypeAlign(StructType)));
-  Arg->replaceAllUsesWith(AllocA);
-
-  Value *ArgInParam = new AddrSpaceCastInst(
-      Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
-      FirstInst);
-  // Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
-  // addrspacecast preserves alignment.  Since params are constant, this load is
-  // definitely not volatile.
-  LoadInst *LI =
-      new LoadInst(StructType, ArgInParam, Arg->getName(),
-                   /*isVolatile=*/false, AllocA->getAlign(), FirstInst);
-  new StoreInst(LI, AllocA, FirstInst);
+  if (isParamGridConstant(*Arg)) {
+    // Writes to a grid constant are undefined behaviour. We do not need a
+    // temporary copy. When a pointer might have escaped, conservatively replace
+    // all of its uses (which might include a device function call) with a cast
+    // to the generic address space.
+    // TODO: only cast byval grid constant parameters at use points that need
+    // generic address (e.g., merging parameter pointers with other address
+    // space, or escaping to call-sites, inline-asm, memory), and use the
+    // parameter address space for normal loads.
+    IRBuilder<> IRB(&Func->getEntryBlock().front());
+
+    // Cast argument to param address space
+    AddrSpaceCastInst *CastToParam =
+        cast<AddrSpaceCastInst>(IRB.CreateAddrSpaceCast(
+            Arg, IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getName() + ".param"));
+
+    // Cast param address to generic address space
+    Value *CvtToGenCall = IRB.CreateIntrinsic(
+        IRB.getPtrTy(ADDRESS_SPACE_GENERIC), Intrinsic::nvvm_ptr_param_to_gen,
+        CastToParam, nullptr, CastToParam->getName() + ".gen");
+
+    Arg->replaceAllUsesWith(CvtToGenCall);
+
+    // Do not replace Arg in the cast to param space
+    CastToParam->setOperand(0, Arg);
+  } else {
+    // Otherwise we have to create a temporary copy.
+    AllocaInst *AllocA =
+        new AllocaInst(StructType, AS, Arg->getName(), FirstInst);
+    // Set the alignment to alignment of the byval parameter. This is because,
+    // later load/stores assume that alignment, and we are going to replace
+    // the use of the byval parameter with this alloca instruction.
+    AllocA->setAlignment(Func->getParamAlign(Arg->getArgNo())
+                             .value_or(DL.getPrefTypeAlign(StructType)));
+    Arg->replaceAllUsesWith(AllocA);
+
+    Value *ArgInParam = new AddrSpaceCastInst(
+        Arg, PointerType::get(Arg->getContext(), ADDRESS_SPACE_PARAM),
+        Arg->getName(), FirstInst);
+    // Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
+    // addrspacecast preserves alignment.  Since params are constant, this load
+    // is definitely not volatile.
+    LoadInst *LI =
+        new LoadInst(StructType, ArgInParam, Arg->getName(),
+                     /*isVolatile=*/false, AllocA->getAlign(), FirstInst);
+    new StoreInst(LI, AllocA, FirstInst);
+  }
 }
 
 void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
index 3a536db1c9727..96db2079ed59f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -52,29 +52,45 @@ void clearAnnotationCache(const Module *Mod) {
   AC.Cache.erase(Mod);
 }
 
-static void cacheAnnotationFromMD(const MDNode *md, key_val_pair_t &retval) {
+static void readIntVecFromMDNode(const MDNode *MetadataNode,
+                                 std::vector<unsigned> &Vec) {
+  for (unsigned i = 0, e = MetadataNode->getNumOperands(); i != e; ++i) {
+    ConstantInt *Val =
+        mdconst::extract<ConstantInt>(MetadataNode->getOperand(i));
+    Vec.push_back(Val->getZExtValue());
+  }
+}
+
+static void cacheAnnotationFromMD(const MDNode *MetadataNode,
+                                  key_val_pair_t &retval) {
   auto &AC = getAnnotationCache();
   std::lock_guard<sys::Mutex> Guard(AC.Lock);
-  assert(md && "Invalid mdnode for annotation");
-  assert((md->getNumOperands() % 2) == 1 && "Invalid number of operands");
+  assert(MetadataNode && "Invalid mdnode for annotation");
+  assert((MetadataNode->getNumOperands() % 2) == 1 &&
+         "Invalid number of operands");
   // start index = 1, to skip the global variable key
   // increment = 2, to skip the value for each property-value pairs
-  for (unsigned i = 1, e = md->getNumOperands(); i != e; i += 2) {
+  for (unsigned i = 1, e = MetadataNode->getNumOperands(); i != e; i += 2) {
     // property
-    const MDString *prop = dyn_cast<MDString>(md->getOperand(i));
+    const MDString *prop = dyn_cast<MDString>(MetadataNode->getOperand(i));
     assert(prop && "Annotation property not a string");
+    std::string Key = prop->getString().str();
 
     // value
-    ConstantInt *Val = mdconst::dyn_extract<ConstantInt>(md->getOperand(i + 1));
-    assert(Val && "Value operand not a constant int");
-
-    std::string keyname = prop->getString().str();
-    if (retval.find(keyname) != retval.end())
-      retval[keyname].push_back(Val->getZExtValue());
-    else {
-      std::vector<unsigned> tmp;
-      tmp.push_back(Val->getZExtValue());
-      retval[keyname] = tmp;
+    if (ConstantInt *Val = mdconst::dyn_extract<ConstantInt>(
+            MetadataNode->getOperand(i + 1))) {
+      retval[Key].push_back(Val->getZExtValue());
+    } else if (MDNode *VecMd =
+                   dyn_cast<MDNode>(MetadataNode->getOperand(i + 1))) {
+      // assert: there can only exist one unique key value pair of
+      // the form (string key, MDNode node). Operands of such a node
+      // shall always be unsigned ints.
+      if (retval.find(Key) == retval.end()) {
+        readIntVecFromMDNode(VecMd, retval[Key]);
+        continue;
+      }
+    } else {
+      llvm_unreachable("Value operand not a constant int or an mdnode");
     }
   }
 }
@@ -153,9 +169,9 @@ bool findAllNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
 
 bool isTexture(const Value &val) {
   if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
-    unsigned annot;
-    if (findOneNVVMAnnotation(gv, "texture", annot)) {
-      assert((annot == 1) && "Unexpected annotation on a texture symbol");
+    unsigned Annot;
+    if (findOneNVVMAnnotation(gv, "texture", Annot)) {
+      assert((Annot == 1) && "Unexpected annotation on a texture symbol");
       return true;
     }
   }
@@ -164,70 +180,68 @@ bool isTexture(const Value &val) {
 
 bool isSurface(const Value &val) {
   if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
-    unsigned annot;
-    if (findOneNVVMAnnotation(gv, "surface", annot)) {
-      assert((annot == 1) && "Unexpected annotation on a surface symbol");
+    unsigned Annot;
+    if (findOneNVVMAnnotation(gv, "surface", Annot)) {
+      assert((Annot == 1) && "Unexpected annotation on a surface symbol");
       return true;
     }
   }
   return false;
 }
 
-bool isSampler(const Value &val) {
-  const char *AnnotationName = "sampler";
-
-  if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
-    unsigned annot;
-    if (findOneNVVMAnnotation(gv, AnnotationName, annot)) {
-      assert((annot == 1) && "Unexpected annotation on a sampler symbol");
-      return true;
-    }
-  }
-  if (const Argument *arg = dyn_cast<Argument>(&val)) {
-    const Function *func = arg->getParent();
-    std::vector<unsigned> annot;
-    if (findAllNVVMAnnotation(func, AnnotationName, annot)) {
-      if (is_contained(annot, arg->getArgNo()))
+static bool argHasNVVMAnnotation(const Value &Val,
+                                 const std::string &Annotation,
+                                 const bool StartArgIndexAtOne = false) {
+  if (const Argument *Arg = dyn_cast<Argument>(&Val)) {
+    const Function *Func = Arg->getParent();
+    std::vector<unsigned> Annot;
+    if (findAllNVVMAnnotation(Func, Annotation, Annot)) {
+      const unsigned BaseOffset = StartArgIndexAtOne ? 1 : 0;
+      if (is_contained(Annot, BaseOffset + Arg->getArgNo())) {
         return true;
+      }
     }
   }
   return false;
 }
 
-bool isImageReadOnly(const Value &val) {
-  if (const Argument *arg = dyn_cast<Argument>(&val)) {
-    const Function *func = arg->getParent();
-    std::vector<unsigned> annot;
-    if (findAllNVVMAnnotation(func, "rdoimage", annot)) {
-      if (is_contained(annot, arg->getArgNo()))
-        return true;
+bool isParamGridConstant(const Value &V) {
+  if (const Argument *Arg = dyn_cast<Argument>(&V)) {
+    std::vector<unsigned> Annot;
+    // "grid_constant" counts argument indices starting from 1
+    if (Arg->hasByValAttr() &&
+        argHasNVVMAnnotation(*Arg, "grid_constant", true)) {
+      assert(isKernelFunction(*Arg->getParent()) &&
+             "only kernel arguments can be grid_constant");
+      return true;
     }
   }
   return false;
 }
 
-bool isImageWriteOnly(const Value &val) {
-  if (const Argument *arg = dyn_cast<Argument>(&val)) {
-    const Function *func = arg->getParent();
-    std::vector<unsigned> annot;
-    if (findAllNVVMAnnotation(func, "wroimage", annot)) {
-      if (is_contained(annot, arg->getArgNo()))
-        return true;
+bool isSampler(const Value &val) {
+  const char *AnnotationName = "sampler";
+
+  if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
+    unsigned Annot;
+    if (findOneNVVMAnnotation(gv, AnnotationName, Annot)) {
+      assert((Annot == 1) && "Unexpected annotation on a sampler symbol");
+      return true;
     }
   }
-  return false;
+  return argHasNVVMAnnotation(val, AnnotationName);
+}
+
+bool isImageReadOnly(const Value &val) {
+  return argHasNVVMAnnotation(val, "rdoimage");
+}
+
+bool isImageWriteOnly(const Value &val) {
+  return argHasNVVMAnnotation(val, "wroimage");
 }
 
 bool isImageReadWrite(const Value &val) {
-  if (const Argument *arg = dyn_cast<Argument>(&val)) {
-    const Function *func = arg->getParent();
-    std::vector<unsigned> annot;
-    if (findAllNVVMAnnotation(func, "rdwrimage", annot)) {
-      if (is_contained(annot, arg->getArgNo()))
-        return true;
-    }
-  }
-  return false;
+  return argHasNVVMAnnotation(val, "rdwrimage");
 }
 
 bool isImage(const Value &val) {
@@ -236,9 +250,9 @@ bool isImage(const Value &val) {
 
 bool isManaged(const Value &val) {
   if(const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
-    unsigned annot;
-    if (findOneNVVMAnnotation(gv, "managed", annot)) {
-      assert((annot == 1) && "Unexpected annotation on a managed symbol");
+    unsigned Annot;
+    if (findOneNVVMAnnotation(gv, "managed", Annot)) {
+      assert((Annot == 1) && "Unexpected annotation on a managed symbol");
       return true;
     }
   }
@@ -323,8 +337,7 @@ bool getMaxNReg(const Function &F, unsigned &x) {
 
 bool isKernelFunction(const Function &F) {
   unsigned x = 0;
-  bool retval = findOneNVVMAnnotation(&F, "kernel", x);
-  if (!retval) {
+  if (!findOneNVVMAnnotation(&F, "kernel", x)) {
     // There is no NVVM metadata, check the calling convention
     return F.getCallingConv() == CallingConv::PTX_Kernel;
   }
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.h b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
index e020bc0f02e96..c15ff6cae1f27 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.h
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
@@ -62,6 +62,7 @@ bool getMaxClusterRank(const Function &, unsigned &);
 bool getMinCTASm(const Function &, unsigned &);
 bool getMaxNReg(const Function &, unsigned &);
 bool isKernelFunction(const Function &);
+bool isParamGridConstant(const Value &);
 
 MaybeAlign getAlign(const Function &, unsigned);
 MaybeAlign getAlign(const CallInst &, unsigned);
diff --git a/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
new file mode 100644
index 0000000000000..46f54e0e6f4d4
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
@@ -0,0 +1,155 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -S -nvptx-lower-args --mtriple nvptx64-nvidia-cuda | FileCheck %s --check-prefixes OPT
+; RUN: llc < %s -mcpu=sm_70 --mtriple nvptx64-nvidia-cuda | FileCheck %s --check-prefixes PTX
+
+define void @grid_const_int(ptr byval(i32) align 4 %input1, i32 %input2, ptr %out, i32 %n) {
+; PTX-LABEL: grid_const_int(
+; PTX-NOT:     ld.u32
+; PTX:         ld.param.{{.*}} [[R2:%.*]], [grid_const_int_param_0];
+; 
+; OPT-LABEL: define void @grid_const_int(
+; OPT-SAME: ptr byval(i32) align 4 [[INPUT1:%.*]], i32 [[INPUT2:%.*]], ptr [[OUT:%.*]], i32 [[N:%.*]]) {
+; OPT-NOT:     alloca
+; OPT:         [[INPUT11:%.*]] = addrspacecast ptr [[INPUT1]] to ptr addrspace(101)
+; OPT:         [[TMP:%.*]] = load i32, ptr addrspace(101) [[INPUT11]], align 4
+;
+  %tmp = load i32, ptr %input1, align 4
+  %add = add i32 %tmp, %input2
+  store i32 %add, ptr %out
+  ret void
+}
+
+%struct.s = type { i32, i32 }
+
+define void @grid_const_struct(ptr byval(%struct.s) align 4 %input, ptr %out){
+; PTX-LABEL: grid_const_struct(
+; PTX:       {
+; PTX-NOT:     ld.u32
+; PTX:         ld.param.{{.*}} [[R1:%.*]], [grid_const_struct_param_0];
+; PTX:         ld.param.{{.*}} [[R2:%.*]], [grid_const_struct_param_0+4];
+;
+; OPT-LABEL: define void @grid_const_struct(
+; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 [[INPUT:%.*]], ptr [[OUT:%.*]]) {
+; OPT-NOT:     alloca
+; OPT:         [[INPUT1:%.*]] = addrspacecast ptr [[INPUT]] to ptr addrspace(101)
+; OPT:         [[GEP13:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr addrspace(101) [[INPUT1]], i32 0, i32 0
+; OPT:         [[GEP22:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr addrspace(101) [[INPUT1]], i32 0, i32 1
+; OPT:         [[TMP1:%.*]] = load i32, ptr addrspace(101) [[GEP13]], align 4
+; OPT:         [[TMP2:%.*]] = load i32, ptr addrspace(101) [[GEP22]], align 4
+;
+  %gep1 = getelementptr inbounds %struct.s, ptr %input, i32 0, i32 0
+  %gep2 = getelementptr inbounds %struct.s, ptr %input, i32 0, i32 1
+  %int1 = load i32, ptr %gep1
+  %int2 = load i32, ptr %gep2
+  %add = add i32 %int1, %int2
+  store i32 %add, ptr %out
+  ret void
+}
+
+define void @grid_const_escape(ptr byval(%struct.s) align 4 %input) {
+; PTX-LABEL: grid_const_escape(
+; PTX:       {
+; PTX-NOT:     .local
+; PTX:         cvta.param.{{.*}}
+; OPT-LABEL: define void @grid_const_escape(
+; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 [[INPUT:%.*]]) {
+; OPT-NOT:     alloca [[STRUCT_S]]
+; OPT:         [[INPUT_PARAM:%.*]] = addrspacecast ptr [[INPUT]] to ptr addrspace(101)
+; OPT:         [[INPUT_PARAM_GEN:%.*]] = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) [[INPUT_PARAM]])
+; OPT:         [[CALL:%.*]] = call i32 @escape(ptr [[INPUT_PARAM_GEN]])
+;
+  %call = call i32 @escape(ptr %input)
+  ret void
+}
+
+define void @multiple_grid_const_escape(ptr byval(%struct.s) align 4 %input, i32 %a, ptr byval(i32) align 4 %b) {
+; PTX-LABEL: multiple_grid_const_escape(
+; PTX:         mov.{{.*}} [[RD1:%.*]], multiple_grid_const_escape_param_0;
+; PTX:         mov.{{.*}} [[RD2:%.*]], multiple_grid_const_escape_param_2;
+; PTX:         mov.{{.*}} [[RD3:%.*]], [[RD2]];
+; PTX:         cvta.param.{{.*}} [[RD4:%.*]], [[RD3]];
+; PTX:         mov.u64 [[RD5:%.*]], [[RD1]];
+; PTX:         cvta.param.{{.*}} [[RD6:%.*]], [[RD5]];
+; PTX:         {
+; PTX:         st.param.b64 [param0+0], [[RD6]];
+; PTX:         st.param.b64 [param2+0], [[RD4]];
+;
+; OPT-LABEL: define void @multiple_grid_const_escape(
+; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 [[INPUT:%.*]], i32 [[A:%.*]], ptr byval(i32) align 4 [[B:%.*]]) {
+; OPT-NOT:     alloca i32
+; OPT:         [[B_PARAM:%.*]] = addrspacecast ptr [[B]] to ptr addrspace(101)
+; OPT:         [[B_PARAM_GEN:%.*]] = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) [[B_PARAM]])
+; OPT-NOT:     alloca [[STRUCT_S]]
+; OPT:         [[INPUT_PARAM:%.*]] = addrspacecast ptr [[INPUT]] to ptr addrspace(101)
+; OPT:         [[INPUT_PARAM_GEN:%.*]] = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) [[INPUT_PARAM]])
+; OPT:         [[CALL:%.*]] = call i32 @escape3(ptr [[INPUT_PARAM_GEN]], ptr {{.*}}, ptr [[B_PARAM_GEN]])
+;
+  %a.addr = alloca i32, align 4
+  store i32 %a, ptr %a.addr, align 4
+  %call = call i32 @escape3(ptr %input, ptr %a.addr, ptr %b)
+  ret void
+}
+
+define void @grid_const_memory_escape(ptr byval(%struct.s) align 4 %input, ptr %addr) {
+; PTX-LABEL: grid_const_memory_escape(
+; PTX-NOT:     .local
+; PTX:         mov.b64 [[RD1:%.*]], grid_const_memory_escape_param_0;
+; PTX:         cvta.param.u64 [[RD3:%.*]], [[RD2:%.*]];
+; PTX:         st.global.u64 [[[RD4:%.*]]], [[RD3]];
+;
+; OPT-LABEL: define void @grid_const_memory_escape(
+; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 [[INPUT:%.*]], ptr [[ADDR:%.*]]) {
+; OPT-NOT:     alloca [[STRUCT_S]]
+; OPT:         [[INPUT_PARAM:%.*]] = addrspacecast ptr [[INPUT]] to ptr addrspace(101)
+; OPT:         [[INPUT_PARAM_GEN:%....
[truncated]
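
At the PTX level, cvta.param converts a .param-space address into a generic address. Pieced together from the FileCheck lines above (register names illustrative), the memory-escape case lowers to roughly:

    mov.b64         %rd1, grid_const_memory_escape_param_0; // param-space address of the byval argument
    cvta.param.u64  %rd2, %rd1;                             // convert to a generic address
    st.global.u64   [%rd3], %rd2;                           // the generic pointer may now escape to memory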

@akshayrdeodhar akshayrdeodhar changed the title from 'Add support for "grid_constant" in NVPTXLowerArgs.' to 'Basic support for "grid_constant" in NVPTXLowerArgs.' Jun 20, 2024
@akshayrdeodhar akshayrdeodhar requested a review from jlebar June 20, 2024 00:42
@akshayrdeodhar
Contributor Author

akshayrdeodhar commented Jun 20, 2024

@apaszke - unable to add you as a reviewer, please take a look

@akshayrdeodhar akshayrdeodhar changed the title from 'Basic support for "grid_constant" in NVPTXLowerArgs.' to '[NVPTX] Basic support for "grid_constant"' Jun 20, 2024
@akshayrdeodhar akshayrdeodhar requested a review from Artem-B June 20, 2024 06:29
@harinvidia

@Artem-B, @jlebar - Can one of you take a quick look and approve? Thanks.

Member

@jlebar jlebar left a comment


Thanks!

@akshayrdeodhar
Contributor Author

Thanks for the review!

@akshayrdeodhar akshayrdeodhar merged commit 687d6fb into llvm:main Jun 24, 2024
5 of 6 checks passed
@akshayrdeodhar akshayrdeodhar deleted the upstream/nvptx-grid-constant-support branch June 25, 2024 21:06
akshayrdeodhar added a commit that referenced this pull request Jun 30, 2024
- Supports escaped grid_constant pointers less conservatively: casts uses inside Calls, PtrToInts, and Stores where the pointer is a _value operand_ to the generic address space immediately before the escape, while keeping other uses in the param address space (sketched below)

- Related to: #96125
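
A rough before/after of that refinement, as hypothetical IR (names illustrative): direct loads keep the param address space, and the conversion to generic happens only immediately before each escape.

    %in.param = addrspacecast ptr %in to ptr addrspace(101)
    ; a direct load stays in the param address space:
    %v = load i32, ptr addrspace(101) %in.param, align 4
    ; a pointer escaping through a call is converted right before the escape:
    %in.gen = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) %in.param)
    %r = call i32 @escape(ptr %in.gen)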
lravenclaw pushed a commit to lravenclaw/llvm-project that referenced this pull request Jul 3, 2024
- Supports escaped grid_constant pointers less conservatively: casts uses inside Calls, PtrToInts, and Stores where the pointer is a _value operand_ to the generic address space immediately before the escape, while keeping other uses in the param address space

- Related to: llvm#96125
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024
- Adds a helper function for checking whether an argument is a
[grid_constant](https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#supported-properties).
- Adds support for cvta.param using changes from
llvm#95289
- Supports escaped grid_constant pointers conservatively, by casting all
uses to the generic address space with cvta.param.