-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[AMDGPU] New RegBankSelect: Add Ptr32/Ptr64/Ptr128 #142602
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: users/pierre-vh/newrb-add-128b-tys
Are you sure you want to change the base?
[AMDGPU] New RegBankSelect: Add Ptr32/Ptr64/Ptr128 #142602
Conversation
Warning This pull request is not mergeable via GitHub because a downstack PR is open. Once all requirements are satisfied, merge this PR as a stack on Graphite.
This stack of pull requests is managed by Graphite. Learn more about stacking. |
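@llvm/pr-subscribers-backend-amdgpu Author: Pierre van Houtryve (Pierre-vh) Changes: There are quite a few opcodes that do not care about the exact address space (AS) of the pointer, just its size. Adding generic types for these will help reduce duplication in the rule definitions. I also moved the usual B types to use the new `isAnyPtr` helper I added to make sure they're supersets of the `Ptr` cases. Full diff: https://github.com/llvm/llvm-project/pull/142602.diff 3 Files Affected: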
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp
index 12af7233ffad6..26aa3cf36c87a 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp
@@ -605,17 +605,23 @@ LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMappingApplyID ID, LLT Ty) {
case VgprB32:
case UniInVgprB32:
if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
- Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) ||
- Ty == LLT::pointer(6, 32))
+ isAnyPtr(Ty, 32))
return Ty;
return LLT();
+ case SgprPtr32:
+ case VgprPtr32:
+ return isAnyPtr(Ty, 32) ? Ty : LLT();
+ case SgprPtr64:
+ case VgprPtr64:
+ return isAnyPtr(Ty, 64) ? Ty : LLT();
+ case SgprPtr128:
+ case VgprPtr128:
+ return isAnyPtr(Ty, 128) ? Ty : LLT();
case SgprB64:
case VgprB64:
case UniInVgprB64:
if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
- Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(0, 64) ||
- Ty == LLT::pointer(1, 64) || Ty == LLT::pointer(4, 64) ||
- (Ty.isPointer() && Ty.getAddressSpace() > AMDGPUAS::MAX_AMDGPU_ADDRESS))
+ Ty == LLT::fixed_vector(4, 16) || isAnyPtr(Ty, 64))
return Ty;
return LLT();
case SgprB96:
@@ -629,7 +635,7 @@ LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMappingApplyID ID, LLT Ty) {
case VgprB128:
case UniInVgprB128:
if (Ty == LLT::scalar(128) || Ty == LLT::fixed_vector(4, 32) ||
- Ty == LLT::fixed_vector(2, 64))
+ Ty == LLT::fixed_vector(2, 64) || isAnyPtr(Ty, 128))
return Ty;
return LLT();
case SgprB256:
@@ -668,6 +674,9 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
case SgprP5:
case SgprP6:
case SgprP8:
+ case SgprPtr32:
+ case SgprPtr64:
+ case SgprPtr128:
case SgprV2S16:
case SgprV2S32:
case SgprV4S32:
@@ -705,6 +714,9 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
case VgprP5:
case VgprP6:
case VgprP8:
+ case VgprPtr32:
+ case VgprPtr64:
+ case VgprPtr128:
case VgprV2S16:
case VgprV2S32:
case VgprV4S32:
@@ -778,12 +790,18 @@ void RegBankLegalizeHelper::applyMappingDst(
case SgprB128:
case SgprB256:
case SgprB512:
+ case SgprPtr32:
+ case SgprPtr64:
+ case SgprPtr128:
case VgprB32:
case VgprB64:
case VgprB96:
case VgprB128:
case VgprB256:
- case VgprB512: {
+ case VgprB512:
+ case VgprPtr32:
+ case VgprPtr64:
+ case VgprPtr128: {
assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
assert(RB == getRegBankFromID(MethodIDs[OpIdx]));
break;
@@ -892,7 +910,10 @@ void RegBankLegalizeHelper::applyMappingSrc(
case SgprB96:
case SgprB128:
case SgprB256:
- case SgprB512: {
+ case SgprB512:
+ case SgprPtr32:
+ case SgprPtr64:
+ case SgprPtr128: {
assert(Ty == getBTyFromID(MethodIDs[i], Ty));
assert(RB == getRegBankFromID(MethodIDs[i]));
break;
@@ -926,7 +947,10 @@ void RegBankLegalizeHelper::applyMappingSrc(
case VgprB96:
case VgprB128:
case VgprB256:
- case VgprB512: {
+ case VgprB512:
+ case VgprPtr32:
+ case VgprPtr64:
+ case VgprPtr128: {
assert(Ty == getBTyFromID(MethodIDs[i], Ty));
if (RB != VgprRB) {
auto CopyToVgpr = B.buildCopy({VgprRB, Ty}, Reg);
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp
index 08a35b9794344..b6260076731ba 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp
@@ -26,6 +26,10 @@
using namespace llvm;
using namespace AMDGPU;
+bool AMDGPU::isAnyPtr(LLT Ty, unsigned Width) {
+ return Ty.isPointer() && Ty.getSizeInBits() == Width;
+}
+
RegBankLLTMapping::RegBankLLTMapping(
std::initializer_list<RegBankLLTMappingApplyID> DstOpMappingList,
std::initializer_list<RegBankLLTMappingApplyID> SrcOpMappingList,
@@ -68,6 +72,12 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
return MRI.getType(Reg) == LLT::pointer(6, 32);
case P8:
return MRI.getType(Reg) == LLT::pointer(8, 128);
+ case Ptr32:
+ return isAnyPtr(MRI.getType(Reg), 32);
+ case Ptr64:
+ return isAnyPtr(MRI.getType(Reg), 64);
+ case Ptr128:
+ return isAnyPtr(MRI.getType(Reg), 128);
case V2S32:
return MRI.getType(Reg) == LLT::fixed_vector(2, 32);
case V4S32:
@@ -110,6 +120,12 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
return MRI.getType(Reg) == LLT::pointer(6, 32) && MUI.isUniform(Reg);
case UniP8:
return MRI.getType(Reg) == LLT::pointer(8, 128) && MUI.isUniform(Reg);
+ case UniPtr32:
+ return isAnyPtr(MRI.getType(Reg), 32) && MUI.isUniform(Reg);
+ case UniPtr64:
+ return isAnyPtr(MRI.getType(Reg), 64) && MUI.isUniform(Reg);
+ case UniPtr128:
+ return isAnyPtr(MRI.getType(Reg), 128) && MUI.isUniform(Reg);
case UniV2S16:
return MRI.getType(Reg) == LLT::fixed_vector(2, 16) && MUI.isUniform(Reg);
case UniB32:
@@ -150,6 +166,12 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
return MRI.getType(Reg) == LLT::pointer(6, 32) && MUI.isDivergent(Reg);
case DivP8:
return MRI.getType(Reg) == LLT::pointer(8, 128) && MUI.isDivergent(Reg);
+ case DivPtr32:
+ return isAnyPtr(MRI.getType(Reg), 32) && MUI.isDivergent(Reg);
+ case DivPtr64:
+ return isAnyPtr(MRI.getType(Reg), 64) && MUI.isDivergent(Reg);
+ case DivPtr128:
+ return isAnyPtr(MRI.getType(Reg), 128) && MUI.isDivergent(Reg);
case DivV2S16:
return MRI.getType(Reg) == LLT::fixed_vector(2, 16) && MUI.isDivergent(Reg);
case DivB32:
@@ -223,15 +245,14 @@ UniformityLLTOpPredicateID LLTToId(LLT Ty) {
UniformityLLTOpPredicateID LLTToBId(LLT Ty) {
if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
- (Ty.isPointer() && Ty.getSizeInBits() == 32))
+ isAnyPtr(Ty, 32))
return B32;
if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
- Ty == LLT::fixed_vector(4, 16) ||
- (Ty.isPointer() && Ty.getSizeInBits() == 64))
+ Ty == LLT::fixed_vector(4, 16) || isAnyPtr(Ty, 64))
return B64;
if (Ty == LLT::fixed_vector(3, 32))
return B96;
- if (Ty == LLT::fixed_vector(4, 32))
+ if (Ty == LLT::fixed_vector(4, 32) || isAnyPtr(Ty, 128))
return B128;
return _;
}
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h
index 14be873b6ce19..1d429f711fbf6 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h
@@ -15,6 +15,7 @@
namespace llvm {
+class LLT;
class MachineRegisterInfo;
class MachineInstr;
class GCNSubtarget;
@@ -26,6 +27,9 @@ using MachineUniformityInfo = GenericUniformityInfo<MachineSSAContext>;
namespace AMDGPU {
+/// \returns true if \p Ty is a pointer type with size \p Width.
+bool isAnyPtr(LLT Ty, unsigned Width);
+
// IDs used to build predicate for RegBankLegalizeRule. Predicate can have one
// or more IDs and each represents a check for 'uniform or divergent' + LLT or
// just LLT on register operand.
@@ -62,6 +66,9 @@ enum UniformityLLTOpPredicateID {
P5,
P6,
P8,
+ Ptr32,
+ Ptr64,
+ Ptr128,
UniP0,
UniP1,
@@ -71,6 +78,9 @@ enum UniformityLLTOpPredicateID {
UniP5,
UniP6,
UniP8,
+ UniPtr32,
+ UniPtr64,
+ UniPtr128,
DivP0,
DivP1,
@@ -80,6 +90,9 @@ enum UniformityLLTOpPredicateID {
DivP5,
DivP6,
DivP8,
+ DivPtr32,
+ DivPtr64,
+ DivPtr128,
// vectors
V2S16,
@@ -138,6 +151,9 @@ enum RegBankLLTMappingApplyID {
SgprP5,
SgprP6,
SgprP8,
+ SgprPtr32,
+ SgprPtr64,
+ SgprPtr128,
SgprV2S16,
SgprV4S32,
SgprV2S32,
@@ -161,6 +177,9 @@ enum RegBankLLTMappingApplyID {
VgprP5,
VgprP6,
VgprP8,
+ VgprPtr32,
+ VgprPtr64,
+ VgprPtr128,
VgprV2S16,
VgprV2S32,
VgprB32,
|
This one LGTM, minus patches before |
There are quite a few opcodes that do not care about the exact address space (AS) of the pointer, just its size. Adding generic types for these will help reduce duplication in the rule definitions. I also moved the usual B types to use the new `isAnyPtr` helper I added, to make sure they're supersets of the `Ptr` cases.
5805695
to
96669ee
Compare
b4da53a
to
c69258d
Compare
There are quite a few opcodes that do not care about the exact address space (AS) of the pointer, just its size.
Adding generic types for these will help reduce duplication in the rule definitions.
I also moved the usual B types to use the new `isAnyPtr` helper I added, to make sure they're supersets of the `Ptr` cases.