Skip to content

Commit c69258d

Browse files
committed
[AMDGPU] New RegBankSelect: Add Ptr32/Ptr64/Ptr128
There are quite a few opcodes that do not care about the exact address space (AS) of the pointer, just its size. Adding generic types for these will help reduce duplication in the rule definitions. I also moved the usual B types to use the new `isAnyPtr` helper I added, to make sure they're supersets of the `Ptr` cases.
1 parent 96669ee commit c69258d

File tree

3 files changed

+77
-13
lines changed

3 files changed

+77
-13
lines changed

llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -595,17 +595,23 @@ LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMappingApplyID ID, LLT Ty) {
595595
case VgprB32:
596596
case UniInVgprB32:
597597
if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
598-
Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) ||
599-
Ty == LLT::pointer(6, 32))
598+
isAnyPtr(Ty, 32))
600599
return Ty;
601600
return LLT();
601+
case SgprPtr32:
602+
case VgprPtr32:
603+
return isAnyPtr(Ty, 32) ? Ty : LLT();
604+
case SgprPtr64:
605+
case VgprPtr64:
606+
return isAnyPtr(Ty, 64) ? Ty : LLT();
607+
case SgprPtr128:
608+
case VgprPtr128:
609+
return isAnyPtr(Ty, 128) ? Ty : LLT();
602610
case SgprB64:
603611
case VgprB64:
604612
case UniInVgprB64:
605613
if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
606-
Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(0, 64) ||
607-
Ty == LLT::pointer(1, 64) || Ty == LLT::pointer(4, 64) ||
608-
(Ty.isPointer() && Ty.getAddressSpace() > AMDGPUAS::MAX_AMDGPU_ADDRESS))
614+
Ty == LLT::fixed_vector(4, 16) || isAnyPtr(Ty, 64))
609615
return Ty;
610616
return LLT();
611617
case SgprB96:
@@ -619,7 +625,7 @@ LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMappingApplyID ID, LLT Ty) {
619625
case VgprB128:
620626
case UniInVgprB128:
621627
if (Ty == LLT::scalar(128) || Ty == LLT::fixed_vector(4, 32) ||
622-
Ty == LLT::fixed_vector(2, 64))
628+
Ty == LLT::fixed_vector(2, 64) || isAnyPtr(Ty, 128))
623629
return Ty;
624630
return LLT();
625631
case SgprB256:
@@ -654,6 +660,9 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
654660
case SgprP3:
655661
case SgprP4:
656662
case SgprP5:
663+
case SgprPtr32:
664+
case SgprPtr64:
665+
case SgprPtr128:
657666
case SgprV2S16:
658667
case SgprV2S32:
659668
case SgprV4S32:
@@ -688,6 +697,9 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
688697
case VgprP3:
689698
case VgprP4:
690699
case VgprP5:
700+
case VgprPtr32:
701+
case VgprPtr64:
702+
case VgprPtr128:
691703
case VgprV2S16:
692704
case VgprV2S32:
693705
case VgprV4S32:
@@ -754,12 +766,18 @@ void RegBankLegalizeHelper::applyMappingDst(
754766
case SgprB128:
755767
case SgprB256:
756768
case SgprB512:
769+
case SgprPtr32:
770+
case SgprPtr64:
771+
case SgprPtr128:
757772
case VgprB32:
758773
case VgprB64:
759774
case VgprB96:
760775
case VgprB128:
761776
case VgprB256:
762-
case VgprB512: {
777+
case VgprB512:
778+
case VgprPtr32:
779+
case VgprPtr64:
780+
case VgprPtr128: {
763781
assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
764782
assert(RB == getRegBankFromID(MethodIDs[OpIdx]));
765783
break;
@@ -864,7 +882,10 @@ void RegBankLegalizeHelper::applyMappingSrc(
864882
case SgprB96:
865883
case SgprB128:
866884
case SgprB256:
867-
case SgprB512: {
885+
case SgprB512:
886+
case SgprPtr32:
887+
case SgprPtr64:
888+
case SgprPtr128: {
868889
assert(Ty == getBTyFromID(MethodIDs[i], Ty));
869890
assert(RB == getRegBankFromID(MethodIDs[i]));
870891
break;
@@ -895,7 +916,10 @@ void RegBankLegalizeHelper::applyMappingSrc(
895916
case VgprB96:
896917
case VgprB128:
897918
case VgprB256:
898-
case VgprB512: {
919+
case VgprB512:
920+
case VgprPtr32:
921+
case VgprPtr64:
922+
case VgprPtr128: {
899923
assert(Ty == getBTyFromID(MethodIDs[i], Ty));
900924
if (RB != VgprRB) {
901925
auto CopyToVgpr = B.buildCopy({VgprRB, Ty}, Reg);

llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626
using namespace llvm;
2727
using namespace AMDGPU;
2828

29+
bool AMDGPU::isAnyPtr(LLT Ty, unsigned Width) {
30+
return Ty.isPointer() && Ty.getSizeInBits() == Width;
31+
}
32+
2933
RegBankLLTMapping::RegBankLLTMapping(
3034
std::initializer_list<RegBankLLTMappingApplyID> DstOpMappingList,
3135
std::initializer_list<RegBankLLTMappingApplyID> SrcOpMappingList,
@@ -62,6 +66,12 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
6266
return MRI.getType(Reg) == LLT::pointer(4, 64);
6367
case P5:
6468
return MRI.getType(Reg) == LLT::pointer(5, 32);
69+
case Ptr32:
70+
return isAnyPtr(MRI.getType(Reg), 32);
71+
case Ptr64:
72+
return isAnyPtr(MRI.getType(Reg), 64);
73+
case Ptr128:
74+
return isAnyPtr(MRI.getType(Reg), 128);
6575
case V2S32:
6676
return MRI.getType(Reg) == LLT::fixed_vector(2, 32);
6777
case V4S32:
@@ -98,6 +108,12 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
98108
return MRI.getType(Reg) == LLT::pointer(4, 64) && MUI.isUniform(Reg);
99109
case UniP5:
100110
return MRI.getType(Reg) == LLT::pointer(5, 32) && MUI.isUniform(Reg);
111+
case UniPtr32:
112+
return isAnyPtr(MRI.getType(Reg), 32) && MUI.isUniform(Reg);
113+
case UniPtr64:
114+
return isAnyPtr(MRI.getType(Reg), 64) && MUI.isUniform(Reg);
115+
case UniPtr128:
116+
return isAnyPtr(MRI.getType(Reg), 128) && MUI.isUniform(Reg);
101117
case UniV2S16:
102118
return MRI.getType(Reg) == LLT::fixed_vector(2, 16) && MUI.isUniform(Reg);
103119
case UniB32:
@@ -132,6 +148,12 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
132148
return MRI.getType(Reg) == LLT::pointer(4, 64) && MUI.isDivergent(Reg);
133149
case DivP5:
134150
return MRI.getType(Reg) == LLT::pointer(5, 32) && MUI.isDivergent(Reg);
151+
case DivPtr32:
152+
return isAnyPtr(MRI.getType(Reg), 32) && MUI.isDivergent(Reg);
153+
case DivPtr64:
154+
return isAnyPtr(MRI.getType(Reg), 64) && MUI.isDivergent(Reg);
155+
case DivPtr128:
156+
return isAnyPtr(MRI.getType(Reg), 128) && MUI.isDivergent(Reg);
135157
case DivV2S16:
136158
return MRI.getType(Reg) == LLT::fixed_vector(2, 16) && MUI.isDivergent(Reg);
137159
case DivB32:
@@ -205,15 +227,14 @@ UniformityLLTOpPredicateID LLTToId(LLT Ty) {
205227

206228
UniformityLLTOpPredicateID LLTToBId(LLT Ty) {
207229
if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
208-
(Ty.isPointer() && Ty.getSizeInBits() == 32))
230+
isAnyPtr(Ty, 32))
209231
return B32;
210232
if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
211-
Ty == LLT::fixed_vector(4, 16) ||
212-
(Ty.isPointer() && Ty.getSizeInBits() == 64))
233+
Ty == LLT::fixed_vector(4, 16) || isAnyPtr(Ty, 64))
213234
return B64;
214235
if (Ty == LLT::fixed_vector(3, 32))
215236
return B96;
216-
if (Ty == LLT::fixed_vector(4, 32))
237+
if (Ty == LLT::fixed_vector(4, 32) || isAnyPtr(Ty, 128))
217238
return B128;
218239
return _;
219240
}

llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
namespace llvm {
1717

18+
class LLT;
1819
class MachineRegisterInfo;
1920
class MachineInstr;
2021
class GCNSubtarget;
@@ -26,6 +27,9 @@ using MachineUniformityInfo = GenericUniformityInfo<MachineSSAContext>;
2627

2728
namespace AMDGPU {
2829

30+
/// \returns true if \p Ty is a pointer type with size \p Width.
31+
bool isAnyPtr(LLT Ty, unsigned Width);
32+
2933
// IDs used to build predicate for RegBankLegalizeRule. Predicate can have one
3034
// or more IDs and each represents a check for 'uniform or divergent' + LLT or
3135
// just LLT on register operand.
@@ -59,18 +63,27 @@ enum UniformityLLTOpPredicateID {
5963
P3,
6064
P4,
6165
P5,
66+
Ptr32,
67+
Ptr64,
68+
Ptr128,
6269

6370
UniP0,
6471
UniP1,
6572
UniP3,
6673
UniP4,
6774
UniP5,
75+
UniPtr32,
76+
UniPtr64,
77+
UniPtr128,
6878

6979
DivP0,
7080
DivP1,
7181
DivP3,
7282
DivP4,
7383
DivP5,
84+
DivPtr32,
85+
DivPtr64,
86+
DivPtr128,
7487

7588
// vectors
7689
V2S16,
@@ -125,6 +138,9 @@ enum RegBankLLTMappingApplyID {
125138
SgprP3,
126139
SgprP4,
127140
SgprP5,
141+
SgprPtr32,
142+
SgprPtr64,
143+
SgprPtr128,
128144
SgprV2S16,
129145
SgprV4S32,
130146
SgprV2S32,
@@ -145,6 +161,9 @@ enum RegBankLLTMappingApplyID {
145161
VgprP3,
146162
VgprP4,
147163
VgprP5,
164+
VgprPtr32,
165+
VgprPtr64,
166+
VgprPtr128,
148167
VgprV2S16,
149168
VgprV2S32,
150169
VgprB32,

0 commit comments

Comments
 (0)