Skip to content

Reland "[NVPTX] Cleanup/Refactoring in NVPTX AsmPrinter and RegisterInfo (NFC)" #127277

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

AlexMaclean
Copy link
Member

@AlexMaclean AlexMaclean commented Feb 14, 2025

Originally landed in #126800

This version fixes a typo in NVPTXAsmPrinter::emitFunctionParamList where .surfref was erroneously replaced with .samplerref.

@llvmbot
Copy link
Member

llvmbot commented Feb 14, 2025

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

Changes

Patch is 40.29 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/127277.diff

6 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp (+214-302)
  • (modified) llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h (+10-12)
  • (modified) llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp (+11-12)
  • (modified) llvm/lib/Target/NVPTX/NVPTXRegisterInfo.h (+3-3)
  • (modified) llvm/test/CodeGen/NVPTX/surf-read.ll (+1)
  • (modified) llvm/test/CodeGen/NVPTX/surf-write.ll (+1)
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 75d930d9f7b6f..c8e29c1da6ec4 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -27,6 +27,7 @@
 #include "cl_common_defines.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/SmallString.h"
@@ -47,6 +48,7 @@
 #include "llvm/CodeGen/TargetRegisterInfo.h"
 #include "llvm/CodeGen/ValueTypes.h"
 #include "llvm/CodeGenTypes/MachineValueType.h"
+#include "llvm/IR/Argument.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constant.h"
@@ -93,20 +95,19 @@ using namespace llvm;
 
 #define DEPOTNAME "__local_depot"
 
-/// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V
+/// discoverDependentGlobals - Return a set of GlobalVariables on which \p V
 /// depends.
 static void
-DiscoverDependentGlobals(const Value *V,
+discoverDependentGlobals(const Value *V,
                          DenseSet<const GlobalVariable *> &Globals) {
-  if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
+  if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) {
     Globals.insert(GV);
-  else {
-    if (const User *U = dyn_cast<User>(V)) {
-      for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) {
-        DiscoverDependentGlobals(U->getOperand(i), Globals);
-      }
-    }
+    return;
   }
+
+  if (const User *U = dyn_cast<User>(V))
+    for (const auto &O : U->operands())
+      discoverDependentGlobals(O, Globals);
 }
 
 /// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
@@ -127,8 +128,8 @@ VisitGlobalVariableForEmission(const GlobalVariable *GV,
 
   // Make sure we visit all dependents first
   DenseSet<const GlobalVariable *> Others;
-  for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i)
-    DiscoverDependentGlobals(GV->getOperand(i), Others);
+  for (const auto &O : GV->operands())
+    discoverDependentGlobals(O, Others);
 
   for (const GlobalVariable *GV : Others)
     VisitGlobalVariableForEmission(GV, Order, Visited, Visiting);
@@ -623,9 +624,8 @@ static bool usedInGlobalVarDef(const Constant *C) {
   if (!C)
     return false;
 
-  if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) {
+  if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C))
     return GV->getName() != "llvm.used";
-  }
 
   for (const User *U : C->users())
     if (const Constant *C = dyn_cast<Constant>(U))
@@ -635,25 +635,23 @@ static bool usedInGlobalVarDef(const Constant *C) {
   return false;
 }
 
-static bool usedInOneFunc(const User *U, Function const *&oneFunc) {
-  if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(U)) {
-    if (othergv->getName() == "llvm.used")
+static bool usedInOneFunc(const User *U, Function const *&OneFunc) {
+  if (const GlobalVariable *OtherGV = dyn_cast<GlobalVariable>(U))
+    if (OtherGV->getName() == "llvm.used")
       return true;
-  }
 
-  if (const Instruction *instr = dyn_cast<Instruction>(U)) {
-    if (instr->getParent() && instr->getParent()->getParent()) {
-      const Function *curFunc = instr->getParent()->getParent();
-      if (oneFunc && (curFunc != oneFunc))
+  if (const Instruction *I = dyn_cast<Instruction>(U)) {
+    if (const Function *CurFunc = I->getFunction()) {
+      if (OneFunc && (CurFunc != OneFunc))
         return false;
-      oneFunc = curFunc;
+      OneFunc = CurFunc;
       return true;
-    } else
-      return false;
+    }
+    return false;
   }
 
   for (const User *UU : U->users())
-    if (!usedInOneFunc(UU, oneFunc))
+    if (!usedInOneFunc(UU, OneFunc))
       return false;
 
   return true;
@@ -666,16 +664,15 @@ static bool usedInOneFunc(const User *U, Function const *&oneFunc) {
  * 2. Does it have local linkage?
  * 3. Is the global variable referenced only in one function?
  */
-static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) {
-  if (!gv->hasLocalLinkage())
+static bool canDemoteGlobalVar(const GlobalVariable *GV, Function const *&f) {
+  if (!GV->hasLocalLinkage())
     return false;
-  PointerType *Pty = gv->getType();
-  if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED)
+  if (GV->getAddressSpace() != ADDRESS_SPACE_SHARED)
     return false;
 
   const Function *oneFunc = nullptr;
 
-  bool flag = usedInOneFunc(gv, oneFunc);
+  bool flag = usedInOneFunc(GV, oneFunc);
   if (!flag)
     return false;
   if (!oneFunc)
@@ -685,27 +682,22 @@ static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) {
 }
 
 static bool useFuncSeen(const Constant *C,
-                        DenseMap<const Function *, bool> &seenMap) {
+                        const SmallPtrSetImpl<const Function *> &SeenSet) {
   for (const User *U : C->users()) {
     if (const Constant *cu = dyn_cast<Constant>(U)) {
-      if (useFuncSeen(cu, seenMap))
+      if (useFuncSeen(cu, SeenSet))
         return true;
     } else if (const Instruction *I = dyn_cast<Instruction>(U)) {
-      const BasicBlock *bb = I->getParent();
-      if (!bb)
-        continue;
-      const Function *caller = bb->getParent();
-      if (!caller)
-        continue;
-      if (seenMap.contains(caller))
-        return true;
+      if (const Function *Caller = I->getFunction())
+        if (SeenSet.contains(Caller))
+          return true;
     }
   }
   return false;
 }
 
 void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
-  DenseMap<const Function *, bool> seenMap;
+  SmallPtrSet<const Function *, 32> SeenSet;
   for (const Function &F : M) {
     if (F.getAttributes().hasFnAttr("nvptx-libcall-callee")) {
       emitDeclaration(&F, O);
@@ -731,7 +723,7 @@ void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
         }
         // Emit a declaration of this function if the function that
         // uses this constant expr has already been seen.
-        if (useFuncSeen(C, seenMap)) {
+        if (useFuncSeen(C, SeenSet)) {
           emitDeclaration(&F, O);
           break;
         }
@@ -739,23 +731,19 @@ void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
 
       if (!isa<Instruction>(U))
         continue;
-      const Instruction *instr = cast<Instruction>(U);
-      const BasicBlock *bb = instr->getParent();
-      if (!bb)
-        continue;
-      const Function *caller = bb->getParent();
-      if (!caller)
+      const Function *Caller = cast<Instruction>(U)->getFunction();
+      if (!Caller)
         continue;
 
       // If a caller has already been seen, then the caller is
       // appearing in the module before the callee. so print out
       // a declaration for the callee.
-      if (seenMap.contains(caller)) {
+      if (SeenSet.contains(Caller)) {
         emitDeclaration(&F, O);
         break;
       }
     }
-    seenMap[&F] = true;
+    SeenSet.insert(&F);
   }
   for (const GlobalAlias &GA : M.aliases())
     emitAliasDeclaration(&GA, O);
@@ -818,7 +806,7 @@ void NVPTXAsmPrinter::emitGlobals(const Module &M) {
 
   // Print out module-level global variables in proper order
   for (const GlobalVariable *GV : Globals)
-    printModuleLevelGV(GV, OS2, /*processDemoted=*/false, STI);
+    printModuleLevelGV(GV, OS2, /*ProcessDemoted=*/false, STI);
 
   OS2 << '\n';
 
@@ -839,16 +827,14 @@ void NVPTXAsmPrinter::emitGlobalAlias(const Module &M, const GlobalAlias &GA) {
 
 void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
                                  const NVPTXSubtarget &STI) {
-  O << "//\n";
-  O << "// Generated by LLVM NVPTX Back-End\n";
-  O << "//\n";
-  O << "\n";
+  const unsigned PTXVersion = STI.getPTXVersion();
 
-  unsigned PTXVersion = STI.getPTXVersion();
-  O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n";
-
-  O << ".target ";
-  O << STI.getTargetName();
+  O << "//\n"
+       "// Generated by LLVM NVPTX Back-End\n"
+       "//\n"
+       "\n"
+    << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n"
+    << ".target " << STI.getTargetName();
 
   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
   if (NTM.getDrvInterface() == NVPTX::NVCL)
@@ -871,16 +857,9 @@ void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
   if (HasFullDebugInfo)
     O << ", debug";
 
-  O << "\n";
-
-  O << ".address_size ";
-  if (NTM.is64Bit())
-    O << "64";
-  else
-    O << "32";
-  O << "\n";
-
-  O << "\n";
+  O << "\n"
+    << ".address_size " << (NTM.is64Bit() ? "64" : "32") << "\n"
+    << "\n";
 }
 
 bool NVPTXAsmPrinter::doFinalization(Module &M) {
@@ -928,41 +907,28 @@ void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
                                            raw_ostream &O) {
   if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
     if (V->hasExternalLinkage()) {
-      if (isa<GlobalVariable>(V)) {
-        const GlobalVariable *GVar = cast<GlobalVariable>(V);
-        if (GVar) {
-          if (GVar->hasInitializer())
-            O << ".visible ";
-          else
-            O << ".extern ";
-        }
-      } else if (V->isDeclaration())
+      if (const auto *GVar = dyn_cast<GlobalVariable>(V))
+        O << (GVar->hasInitializer() ? ".visible " : ".extern ");
+      else if (V->isDeclaration())
         O << ".extern ";
       else
         O << ".visible ";
     } else if (V->hasAppendingLinkage()) {
-      std::string msg;
-      msg.append("Error: ");
-      msg.append("Symbol ");
-      if (V->hasName())
-        msg.append(std::string(V->getName()));
-      msg.append("has unsupported appending linkage type");
-      llvm_unreachable(msg.c_str());
-    } else if (!V->hasInternalLinkage() &&
-               !V->hasPrivateLinkage()) {
+      report_fatal_error("Symbol '" + (V->hasName() ? V->getName() : "") +
+                         "' has unsupported appending linkage type");
+    } else if (!V->hasInternalLinkage() && !V->hasPrivateLinkage()) {
       O << ".weak ";
     }
   }
 }
 
 void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
-                                         raw_ostream &O, bool processDemoted,
+                                         raw_ostream &O, bool ProcessDemoted,
                                          const NVPTXSubtarget &STI) {
   // Skip meta data
-  if (GVar->hasSection()) {
+  if (GVar->hasSection())
     if (GVar->getSection() == "llvm.metadata")
       return;
-  }
 
   // Skip LLVM intrinsic global variables
   if (GVar->getName().starts_with("llvm.") ||
@@ -1069,20 +1035,20 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
   }
 
   if (GVar->hasPrivateLinkage()) {
-    if (strncmp(GVar->getName().data(), "unrollpragma", 12) == 0)
+    if (GVar->getName().starts_with("unrollpragma"))
       return;
 
     // FIXME - need better way (e.g. Metadata) to avoid generating this global
-    if (strncmp(GVar->getName().data(), "filename", 8) == 0)
+    if (GVar->getName().starts_with("filename"))
       return;
     if (GVar->use_empty())
       return;
   }
 
-  const Function *demotedFunc = nullptr;
-  if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) {
+  const Function *DemotedFunc = nullptr;
+  if (!ProcessDemoted && canDemoteGlobalVar(GVar, DemotedFunc)) {
     O << "// " << GVar->getName() << " has been demoted\n";
-    localDecls[demotedFunc].push_back(GVar);
+    localDecls[DemotedFunc].push_back(GVar);
     return;
   }
 
@@ -1090,17 +1056,14 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
   emitPTXAddressSpace(GVar->getAddressSpace(), O);
 
   if (isManaged(*GVar)) {
-    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
+    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30)
       report_fatal_error(
           ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
-    }
     O << " .attribute(.managed)";
   }
 
-  if (MaybeAlign A = GVar->getAlign())
-    O << " .align " << A->value();
-  else
-    O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();
+  O << " .align "
+    << GVar->getAlign().value_or(DL.getPrefTypeAlign(ETy)).value();
 
   if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
       (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
@@ -1137,8 +1100,6 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
       }
     }
   } else {
-    uint64_t ElementSize = 0;
-
     // Although PTX has direct support for struct type and array type and
     // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for
     // targets that support these high level field accesses. Structs, arrays
@@ -1147,8 +1108,8 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
     case Type::IntegerTyID: // Integers larger than 64 bits
     case Type::StructTyID:
     case Type::ArrayTyID:
-    case Type::FixedVectorTyID:
-      ElementSize = DL.getTypeStoreSize(ETy);
+    case Type::FixedVectorTyID: {
+      const uint64_t ElementSize = DL.getTypeStoreSize(ETy);
       // Ptx allows variable initilization only for constant and
       // global state spaces.
       if (((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
@@ -1159,7 +1120,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
           AggBuffer aggBuffer(ElementSize, *this);
           bufferAggregateConstant(Initializer, &aggBuffer);
           if (aggBuffer.numSymbols()) {
-            unsigned int ptrSize = MAI->getCodePointerSize();
+            const unsigned int ptrSize = MAI->getCodePointerSize();
             if (ElementSize % ptrSize ||
                 !aggBuffer.allSymbolsAligned(ptrSize)) {
               // Print in bytes and use the mask() operator for pointers.
@@ -1190,22 +1151,17 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
         } else {
           O << " .b8 ";
           getSymbol(GVar)->print(O, MAI);
-          if (ElementSize) {
-            O << "[";
-            O << ElementSize;
-            O << "]";
-          }
+          if (ElementSize)
+            O << "[" << ElementSize << "]";
         }
       } else {
         O << " .b8 ";
         getSymbol(GVar)->print(O, MAI);
-        if (ElementSize) {
-          O << "[";
-          O << ElementSize;
-          O << "]";
-        }
+        if (ElementSize)
+          O << "[" << ElementSize << "]";
       }
       break;
+    }
     default:
       llvm_unreachable("type not supported yet");
     }
@@ -1229,7 +1185,7 @@ void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) {
       Name->print(os, AP.MAI);
     }
   } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) {
-    const MCExpr *Expr = AP.lowerConstantForGV(cast<Constant>(CExpr), false);
+    const MCExpr *Expr = AP.lowerConstantForGV(CExpr, false);
     AP.printMCExpr(*Expr, os);
   } else
     llvm_unreachable("symbol type unknown");
@@ -1298,18 +1254,18 @@ void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) {
   }
 }
 
-void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) {
-  auto It = localDecls.find(f);
+void NVPTXAsmPrinter::emitDemotedVars(const Function *F, raw_ostream &O) {
+  auto It = localDecls.find(F);
   if (It == localDecls.end())
     return;
 
-  std::vector<const GlobalVariable *> &gvars = It->second;
+  ArrayRef<const GlobalVariable *> GVars = It->second;
 
   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
   const NVPTXSubtarget &STI =
       *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
 
-  for (const GlobalVariable *GV : gvars) {
+  for (const GlobalVariable *GV : GVars) {
     O << "\t// demoted variable\n\t";
     printModuleLevelGV(GV, O, /*processDemoted=*/true, STI);
   }
@@ -1344,13 +1300,11 @@ NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
     unsigned NumBits = cast<IntegerType>(Ty)->getBitWidth();
     if (NumBits == 1)
       return "pred";
-    else if (NumBits <= 64) {
+    if (NumBits <= 64) {
       std::string name = "u";
       return name + utostr(NumBits);
-    } else {
-      llvm_unreachable("Integer too large");
-      break;
     }
+    llvm_unreachable("Integer too large");
     break;
   }
   case Type::BFloatTyID:
@@ -1393,16 +1347,14 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
   O << ".";
   emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O);
   if (isManaged(*GVar)) {
-    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
+    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30)
       report_fatal_error(
           ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
-    }
+
     O << " .attribute(.managed)";
   }
-  if (MaybeAlign A = GVar->getAlign())
-    O << " .align " << A->value();
-  else
-    O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();
+  O << " .align "
+    << GVar->getAlign().value_or(DL.getPrefTypeAlign(ETy)).value();
 
   // Special case for i128
   if (ETy->isIntegerTy(128)) {
@@ -1413,9 +1365,7 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
   }
 
   if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {
-    O << " .";
-    O << getPTXFundamentalTypeStr(ETy);
-    O << " ";
+    O << " ." << getPTXFundamentalTypeStr(ETy) << " ";
     getSymbol(GVar)->print(O, MAI);
     return;
   }
@@ -1446,16 +1396,13 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
 
 void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
   const DataLayout &DL = getDataLayout();
-  const AttributeList &PAL = F->getAttributes();
   const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
   const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());
   const NVPTXMachineFunctionInfo *MFI =
       MF ? MF->getInfo<NVPTXMachineFunctionInfo>() : nullptr;
 
-  Function::const_arg_iterator I, E;
-  unsigned paramIndex = 0;
-  bool first = true;
-  bool isKernelFunc = isKernelFunction(*F);
+  bool IsFirst = true;
+  const bool IsKernelFunc = isKernelFunction(*F);
 
   if (F->arg_empty() && !F->isVarArg()) {
     O << "()";
@@ -1464,161 +1411,143 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
 
   O << "(\n";
 
-  for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
-    Type *Ty = I->getType();
+  for (const Argument &Arg : F->args()) {
+    Type *Ty = Arg.getType();
+    const std::string ParamSym = TLI->getParamName(F, Arg.getArgNo());
 
-    if (!first)
+    if (!IsFirst)
       O << ",\n";
 
-    first = false;
+    IsFirst = false;
 
     // Handle image/sampler parameters
-    if (isKernelFunc) {
-      if (isSampler(*I) || isImage(*I)) {
-        std::string ParamSym;
-        raw_string_ostream ParamStr(ParamSym);
-        ParamStr << F->getName() << "_param_" << paramIndex;
-        ParamStr.flush();
-        bool EmitImagePtr = !MFI || !MFI->checkImageHandleSymbol(ParamSym);
-        if (isImage(*I)) {
-          if (isImageWriteOnly(*I) || isImageReadWrite(*I)) {
-            if (EmitImagePtr)
-              O << "\t.param .u64 .ptr .surfref ";
-            else
-              O << "\t.param .surfref ";
-            O << TLI->getParamName(F, paramIndex);
-          }
-          else { // Default image is read_only
-            if (EmitImagePtr)
-              O << "\t.param .u64 .ptr .texref ";
-            else
-              O << "\t.param .texref ";
-            O << TLI->getParamName(F, paramIndex);
-          }
-        } else {
-          if (EmitImagePtr)
-            O << "\t.param .u64 .ptr .samplerref ";
-          else
-            O << "\t.param .samplerref ";
-          O << TLI->getParamName(F, paramIndex);
-        }
+    if (IsKernelFunc) {
+      const bool IsSampler = isSampler(Arg);
+      const bool IsTexture = !IsSampler && isImageReadOnly(Arg);
+      const bool IsSurface = !IsSampler && !IsTexture &&
+                             (isImageReadWrite(Arg) || isImageWriteOnly(Arg));
+      if (IsSampler || IsTexture || IsSurface) {
+        const bool EmitImgPtr = !MFI || !MFI->checkImageHandleSymbol(ParamSym);
+        O << "\t.param ";
+        if (EmitImgPtr)
+          O << ".u64 .ptr ";
+
+        if (IsSampler)...
[truncated]

Copy link
Contributor

@justinfargnoli justinfargnoli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you provide a link to the original PR that introduced the failure, and populate the description with the reason the original PR failed and why this version will pass?

@AlexMaclean AlexMaclean merged commit 34cf04b into llvm:main Feb 15, 2025
10 checks passed
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025
…nfo (NFC)" (llvm#127277)

Originally landed in llvm#126800

This version fixes a typo in NVPTXAsmPrinter::emitFunctionParamList
where .surfref was erroneously replaced with .samplerref.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants