Skip to content

[AMDGPU] New RegBankSelect: Add Ptr32/Ptr64/Ptr128 #142602

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: users/pierre-vh/newrb-add-128b-tys
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 33 additions & 9 deletions llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -595,17 +595,23 @@ LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMappingApplyID ID, LLT Ty) {
case VgprB32:
case UniInVgprB32:
if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) ||
Ty == LLT::pointer(6, 32))
isAnyPtr(Ty, 32))
return Ty;
return LLT();
case SgprPtr32:
case VgprPtr32:
return isAnyPtr(Ty, 32) ? Ty : LLT();
case SgprPtr64:
case VgprPtr64:
return isAnyPtr(Ty, 64) ? Ty : LLT();
case SgprPtr128:
case VgprPtr128:
return isAnyPtr(Ty, 128) ? Ty : LLT();
case SgprB64:
case VgprB64:
case UniInVgprB64:
if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(0, 64) ||
Ty == LLT::pointer(1, 64) || Ty == LLT::pointer(4, 64) ||
(Ty.isPointer() && Ty.getAddressSpace() > AMDGPUAS::MAX_AMDGPU_ADDRESS))
Ty == LLT::fixed_vector(4, 16) || isAnyPtr(Ty, 64))
return Ty;
return LLT();
case SgprB96:
Expand All @@ -619,7 +625,7 @@ LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMappingApplyID ID, LLT Ty) {
case VgprB128:
case UniInVgprB128:
if (Ty == LLT::scalar(128) || Ty == LLT::fixed_vector(4, 32) ||
Ty == LLT::fixed_vector(2, 64))
Ty == LLT::fixed_vector(2, 64) || isAnyPtr(Ty, 128))
return Ty;
return LLT();
case SgprB256:
Expand Down Expand Up @@ -654,6 +660,9 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
case SgprP3:
case SgprP4:
case SgprP5:
case SgprPtr32:
case SgprPtr64:
case SgprPtr128:
case SgprV2S16:
case SgprV2S32:
case SgprV4S32:
Expand Down Expand Up @@ -688,6 +697,9 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
case VgprP3:
case VgprP4:
case VgprP5:
case VgprPtr32:
case VgprPtr64:
case VgprPtr128:
case VgprV2S16:
case VgprV2S32:
case VgprV4S32:
Expand Down Expand Up @@ -754,12 +766,18 @@ void RegBankLegalizeHelper::applyMappingDst(
case SgprB128:
case SgprB256:
case SgprB512:
case SgprPtr32:
case SgprPtr64:
case SgprPtr128:
case VgprB32:
case VgprB64:
case VgprB96:
case VgprB128:
case VgprB256:
case VgprB512: {
case VgprB512:
case VgprPtr32:
case VgprPtr64:
case VgprPtr128: {
assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
assert(RB == getRegBankFromID(MethodIDs[OpIdx]));
break;
Expand Down Expand Up @@ -864,7 +882,10 @@ void RegBankLegalizeHelper::applyMappingSrc(
case SgprB96:
case SgprB128:
case SgprB256:
case SgprB512: {
case SgprB512:
case SgprPtr32:
case SgprPtr64:
case SgprPtr128: {
assert(Ty == getBTyFromID(MethodIDs[i], Ty));
assert(RB == getRegBankFromID(MethodIDs[i]));
break;
Expand Down Expand Up @@ -895,7 +916,10 @@ void RegBankLegalizeHelper::applyMappingSrc(
case VgprB96:
case VgprB128:
case VgprB256:
case VgprB512: {
case VgprB512:
case VgprPtr32:
case VgprPtr64:
case VgprPtr128: {
assert(Ty == getBTyFromID(MethodIDs[i], Ty));
if (RB != VgprRB) {
auto CopyToVgpr = B.buildCopy({VgprRB, Ty}, Reg);
Expand Down
29 changes: 25 additions & 4 deletions llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
using namespace llvm;
using namespace AMDGPU;

/// Check whether \p Ty is a pointer type whose size in bits equals \p Width.
/// Only the pointer-ness and the bit width are inspected here; the address
/// space is not consulted.
bool AMDGPU::isAnyPtr(LLT Ty, unsigned Width) {
  // Non-pointer types never match, whatever their width.
  if (!Ty.isPointer())
    return false;
  return Ty.getSizeInBits() == Width;
}

RegBankLLTMapping::RegBankLLTMapping(
std::initializer_list<RegBankLLTMappingApplyID> DstOpMappingList,
std::initializer_list<RegBankLLTMappingApplyID> SrcOpMappingList,
Expand Down Expand Up @@ -62,6 +66,12 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
return MRI.getType(Reg) == LLT::pointer(4, 64);
case P5:
return MRI.getType(Reg) == LLT::pointer(5, 32);
case Ptr32:
return isAnyPtr(MRI.getType(Reg), 32);
case Ptr64:
return isAnyPtr(MRI.getType(Reg), 64);
case Ptr128:
return isAnyPtr(MRI.getType(Reg), 128);
case V2S32:
return MRI.getType(Reg) == LLT::fixed_vector(2, 32);
case V4S32:
Expand Down Expand Up @@ -98,6 +108,12 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
return MRI.getType(Reg) == LLT::pointer(4, 64) && MUI.isUniform(Reg);
case UniP5:
return MRI.getType(Reg) == LLT::pointer(5, 32) && MUI.isUniform(Reg);
case UniPtr32:
return isAnyPtr(MRI.getType(Reg), 32) && MUI.isUniform(Reg);
case UniPtr64:
return isAnyPtr(MRI.getType(Reg), 64) && MUI.isUniform(Reg);
case UniPtr128:
return isAnyPtr(MRI.getType(Reg), 128) && MUI.isUniform(Reg);
case UniV2S16:
return MRI.getType(Reg) == LLT::fixed_vector(2, 16) && MUI.isUniform(Reg);
case UniB32:
Expand Down Expand Up @@ -132,6 +148,12 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
return MRI.getType(Reg) == LLT::pointer(4, 64) && MUI.isDivergent(Reg);
case DivP5:
return MRI.getType(Reg) == LLT::pointer(5, 32) && MUI.isDivergent(Reg);
case DivPtr32:
return isAnyPtr(MRI.getType(Reg), 32) && MUI.isDivergent(Reg);
case DivPtr64:
return isAnyPtr(MRI.getType(Reg), 64) && MUI.isDivergent(Reg);
case DivPtr128:
return isAnyPtr(MRI.getType(Reg), 128) && MUI.isDivergent(Reg);
case DivV2S16:
return MRI.getType(Reg) == LLT::fixed_vector(2, 16) && MUI.isDivergent(Reg);
case DivB32:
Expand Down Expand Up @@ -205,15 +227,14 @@ UniformityLLTOpPredicateID LLTToId(LLT Ty) {

UniformityLLTOpPredicateID LLTToBId(LLT Ty) {
if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
(Ty.isPointer() && Ty.getSizeInBits() == 32))
isAnyPtr(Ty, 32))
return B32;
if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
Ty == LLT::fixed_vector(4, 16) ||
(Ty.isPointer() && Ty.getSizeInBits() == 64))
Ty == LLT::fixed_vector(4, 16) || isAnyPtr(Ty, 64))
return B64;
if (Ty == LLT::fixed_vector(3, 32))
return B96;
if (Ty == LLT::fixed_vector(4, 32))
if (Ty == LLT::fixed_vector(4, 32) || isAnyPtr(Ty, 128))
return B128;
return _;
}
Expand Down
19 changes: 19 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

namespace llvm {

class LLT;
class MachineRegisterInfo;
class MachineInstr;
class GCNSubtarget;
Expand All @@ -26,6 +27,9 @@ using MachineUniformityInfo = GenericUniformityInfo<MachineSSAContext>;

namespace AMDGPU {

/// \returns true if \p Ty is a pointer type whose size in bits is \p Width.
bool isAnyPtr(LLT Ty, unsigned Width);

// IDs used to build a predicate for a RegBankLegalizeRule. A predicate can
// have one or more IDs, each of which represents a check for 'uniform or
// divergent' + LLT, or just LLT, on a register operand.
Expand Down Expand Up @@ -59,18 +63,27 @@ enum UniformityLLTOpPredicateID {
P3,
P4,
P5,
Ptr32,
Ptr64,
Ptr128,

UniP0,
UniP1,
UniP3,
UniP4,
UniP5,
UniPtr32,
UniPtr64,
UniPtr128,

DivP0,
DivP1,
DivP3,
DivP4,
DivP5,
DivPtr32,
DivPtr64,
DivPtr128,

// vectors
V2S16,
Expand Down Expand Up @@ -125,6 +138,9 @@ enum RegBankLLTMappingApplyID {
SgprP3,
SgprP4,
SgprP5,
SgprPtr32,
SgprPtr64,
SgprPtr128,
SgprV2S16,
SgprV4S32,
SgprV2S32,
Expand All @@ -145,6 +161,9 @@ enum RegBankLLTMappingApplyID {
VgprP3,
VgprP4,
VgprP5,
VgprPtr32,
VgprPtr64,
VgprPtr128,
VgprV2S16,
VgprV2S32,
VgprB32,
Expand Down
Loading