Skip to content

[AMDGPU] New RegBankSelect: Add Ptr32/Ptr64/Ptr128 #142602

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: users/pierre-vh/newrb-add-128b-tys
Choose a base branch
from

Conversation

Pierre-vh
Copy link
Contributor

There's quite a few opcodes that do not care about the exact AS of the pointer, just its size.
Adding generic types for these will help reduce duplication in the rule definitions.

I also moved the usual B types to use the new isAnyPtr helper I added to make sure they're supersets of the Ptr cases

Copy link
Contributor Author

Pierre-vh commented Jun 3, 2025

Warning

This pull request is not mergeable via GitHub because a downstack PR is open. Once all requirements are satisfied, merge this PR as a stack on Graphite.
Learn more

This stack of pull requests is managed by Graphite. Learn more about stacking.

@llvmbot
Copy link
Member

llvmbot commented Jun 3, 2025

@llvm/pr-subscribers-backend-amdgpu

Author: Pierre van Houtryve (Pierre-vh)

Changes

There's quite a few opcodes that do not care about the exact AS of the pointer, just its size.
Adding generic types for these will help reduce duplication in the rule definitions.

I also moved the usual B types to use the new isAnyPtr helper I added to make sure they're supersets of the Ptr cases


Full diff: https://github.com/llvm/llvm-project/pull/142602.diff

3 Files Affected:

  • (modified) llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp (+33-9)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp (+25-4)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h (+19)
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp
index 12af7233ffad6..26aa3cf36c87a 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp
@@ -605,17 +605,23 @@ LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMappingApplyID ID, LLT Ty) {
   case VgprB32:
   case UniInVgprB32:
     if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
-        Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) ||
-        Ty == LLT::pointer(6, 32))
+        isAnyPtr(Ty, 32))
       return Ty;
     return LLT();
+  case SgprPtr32:
+  case VgprPtr32:
+    return isAnyPtr(Ty, 32) ? Ty : LLT();
+  case SgprPtr64:
+  case VgprPtr64:
+    return isAnyPtr(Ty, 64) ? Ty : LLT();
+  case SgprPtr128:
+  case VgprPtr128:
+    return isAnyPtr(Ty, 128) ? Ty : LLT();
   case SgprB64:
   case VgprB64:
   case UniInVgprB64:
     if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
-        Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(0, 64) ||
-        Ty == LLT::pointer(1, 64) || Ty == LLT::pointer(4, 64) ||
-        (Ty.isPointer() && Ty.getAddressSpace() > AMDGPUAS::MAX_AMDGPU_ADDRESS))
+        Ty == LLT::fixed_vector(4, 16) || isAnyPtr(Ty, 64))
       return Ty;
     return LLT();
   case SgprB96:
@@ -629,7 +635,7 @@ LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMappingApplyID ID, LLT Ty) {
   case VgprB128:
   case UniInVgprB128:
     if (Ty == LLT::scalar(128) || Ty == LLT::fixed_vector(4, 32) ||
-        Ty == LLT::fixed_vector(2, 64))
+        Ty == LLT::fixed_vector(2, 64) || isAnyPtr(Ty, 128))
       return Ty;
     return LLT();
   case SgprB256:
@@ -668,6 +674,9 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
   case SgprP5:
   case SgprP6:
   case SgprP8:
+  case SgprPtr32:
+  case SgprPtr64:
+  case SgprPtr128:
   case SgprV2S16:
   case SgprV2S32:
   case SgprV4S32:
@@ -705,6 +714,9 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
   case VgprP5:
   case VgprP6:
   case VgprP8:
+  case VgprPtr32:
+  case VgprPtr64:
+  case VgprPtr128:
   case VgprV2S16:
   case VgprV2S32:
   case VgprV4S32:
@@ -778,12 +790,18 @@ void RegBankLegalizeHelper::applyMappingDst(
     case SgprB128:
     case SgprB256:
     case SgprB512:
+    case SgprPtr32:
+    case SgprPtr64:
+    case SgprPtr128:
     case VgprB32:
     case VgprB64:
     case VgprB96:
     case VgprB128:
     case VgprB256:
-    case VgprB512: {
+    case VgprB512:
+    case VgprPtr32:
+    case VgprPtr64:
+    case VgprPtr128: {
       assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
       assert(RB == getRegBankFromID(MethodIDs[OpIdx]));
       break;
@@ -892,7 +910,10 @@ void RegBankLegalizeHelper::applyMappingSrc(
     case SgprB96:
     case SgprB128:
     case SgprB256:
-    case SgprB512: {
+    case SgprB512:
+    case SgprPtr32:
+    case SgprPtr64:
+    case SgprPtr128: {
       assert(Ty == getBTyFromID(MethodIDs[i], Ty));
       assert(RB == getRegBankFromID(MethodIDs[i]));
       break;
@@ -926,7 +947,10 @@ void RegBankLegalizeHelper::applyMappingSrc(
     case VgprB96:
     case VgprB128:
     case VgprB256:
-    case VgprB512: {
+    case VgprB512:
+    case VgprPtr32:
+    case VgprPtr64:
+    case VgprPtr128: {
       assert(Ty == getBTyFromID(MethodIDs[i], Ty));
       if (RB != VgprRB) {
         auto CopyToVgpr = B.buildCopy({VgprRB, Ty}, Reg);
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp
index 08a35b9794344..b6260076731ba 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp
@@ -26,6 +26,10 @@
 using namespace llvm;
 using namespace AMDGPU;
 
+bool AMDGPU::isAnyPtr(LLT Ty, unsigned Width) {
+  return Ty.isPointer() && Ty.getSizeInBits() == Width;
+}
+
 RegBankLLTMapping::RegBankLLTMapping(
     std::initializer_list<RegBankLLTMappingApplyID> DstOpMappingList,
     std::initializer_list<RegBankLLTMappingApplyID> SrcOpMappingList,
@@ -68,6 +72,12 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
     return MRI.getType(Reg) == LLT::pointer(6, 32);
   case P8:
     return MRI.getType(Reg) == LLT::pointer(8, 128);
+  case Ptr32:
+    return isAnyPtr(MRI.getType(Reg), 32);
+  case Ptr64:
+    return isAnyPtr(MRI.getType(Reg), 64);
+  case Ptr128:
+    return isAnyPtr(MRI.getType(Reg), 128);
   case V2S32:
     return MRI.getType(Reg) == LLT::fixed_vector(2, 32);
   case V4S32:
@@ -110,6 +120,12 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
     return MRI.getType(Reg) == LLT::pointer(6, 32) && MUI.isUniform(Reg);
   case UniP8:
     return MRI.getType(Reg) == LLT::pointer(8, 128) && MUI.isUniform(Reg);
+  case UniPtr32:
+    return isAnyPtr(MRI.getType(Reg), 32) && MUI.isUniform(Reg);
+  case UniPtr64:
+    return isAnyPtr(MRI.getType(Reg), 64) && MUI.isUniform(Reg);
+  case UniPtr128:
+    return isAnyPtr(MRI.getType(Reg), 128) && MUI.isUniform(Reg);
   case UniV2S16:
     return MRI.getType(Reg) == LLT::fixed_vector(2, 16) && MUI.isUniform(Reg);
   case UniB32:
@@ -150,6 +166,12 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
     return MRI.getType(Reg) == LLT::pointer(6, 32) && MUI.isDivergent(Reg);
   case DivP8:
     return MRI.getType(Reg) == LLT::pointer(8, 128) && MUI.isDivergent(Reg);
+  case DivPtr32:
+    return isAnyPtr(MRI.getType(Reg), 32) && MUI.isDivergent(Reg);
+  case DivPtr64:
+    return isAnyPtr(MRI.getType(Reg), 64) && MUI.isDivergent(Reg);
+  case DivPtr128:
+    return isAnyPtr(MRI.getType(Reg), 128) && MUI.isDivergent(Reg);
   case DivV2S16:
     return MRI.getType(Reg) == LLT::fixed_vector(2, 16) && MUI.isDivergent(Reg);
   case DivB32:
@@ -223,15 +245,14 @@ UniformityLLTOpPredicateID LLTToId(LLT Ty) {
 
 UniformityLLTOpPredicateID LLTToBId(LLT Ty) {
   if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
-      (Ty.isPointer() && Ty.getSizeInBits() == 32))
+      isAnyPtr(Ty, 32))
     return B32;
   if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
-      Ty == LLT::fixed_vector(4, 16) ||
-      (Ty.isPointer() && Ty.getSizeInBits() == 64))
+      Ty == LLT::fixed_vector(4, 16) || isAnyPtr(Ty, 64))
     return B64;
   if (Ty == LLT::fixed_vector(3, 32))
     return B96;
-  if (Ty == LLT::fixed_vector(4, 32))
+  if (Ty == LLT::fixed_vector(4, 32) || isAnyPtr(Ty, 128))
     return B128;
   return _;
 }
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h
index 14be873b6ce19..1d429f711fbf6 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h
@@ -15,6 +15,7 @@
 
 namespace llvm {
 
+class LLT;
 class MachineRegisterInfo;
 class MachineInstr;
 class GCNSubtarget;
@@ -26,6 +27,9 @@ using MachineUniformityInfo = GenericUniformityInfo<MachineSSAContext>;
 
 namespace AMDGPU {
 
+/// \returns true if \p Ty is a pointer type with size \p Width.
+bool isAnyPtr(LLT Ty, unsigned Width);
+
 // IDs used to build predicate for RegBankLegalizeRule. Predicate can have one
 // or more IDs and each represents a check for 'uniform or divergent' + LLT or
 // just LLT on register operand.
@@ -62,6 +66,9 @@ enum UniformityLLTOpPredicateID {
   P5,
   P6,
   P8,
+  Ptr32,
+  Ptr64,
+  Ptr128,
 
   UniP0,
   UniP1,
@@ -71,6 +78,9 @@ enum UniformityLLTOpPredicateID {
   UniP5,
   UniP6,
   UniP8,
+  UniPtr32,
+  UniPtr64,
+  UniPtr128,
 
   DivP0,
   DivP1,
@@ -80,6 +90,9 @@ enum UniformityLLTOpPredicateID {
   DivP5,
   DivP6,
   DivP8,
+  DivPtr32,
+  DivPtr64,
+  DivPtr128,
 
   // vectors
   V2S16,
@@ -138,6 +151,9 @@ enum RegBankLLTMappingApplyID {
   SgprP5,
   SgprP6,
   SgprP8,
+  SgprPtr32,
+  SgprPtr64,
+  SgprPtr128,
   SgprV2S16,
   SgprV4S32,
   SgprV2S32,
@@ -161,6 +177,9 @@ enum RegBankLLTMappingApplyID {
   VgprP5,
   VgprP6,
   VgprP8,
+  VgprPtr32,
+  VgprPtr64,
+  VgprPtr128,
   VgprV2S16,
   VgprV2S32,
   VgprB32,

@petar-avramovic
Copy link
Collaborator

This one LGTM, minus patches before

There's quite a few opcodes that do not care about the exact AS of the pointer, just its size.
Adding generic types for these will help reduce duplication in the rule definitions.

I also moved the usual B types to use the new `isAnyPtr` helper I added to make sure they're supersets of the `Ptr` cases
@Pierre-vh Pierre-vh force-pushed the users/pierre-vh/newrb-add-128b-tys branch from 5805695 to 96669ee Compare June 4, 2025 08:09
@Pierre-vh Pierre-vh force-pushed the users/pierre-vh/newrb-anyptr-types branch from b4da53a to c69258d Compare June 4, 2025 08:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants