Skip to content

[AMDGPU] Implement vop3p complex pattern optimization for gisel #130234

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 49 commits into from
Apr 18, 2025
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
556f7ff
Implement vop3p complex pattern optimization for gisel
Shoreshen Mar 7, 2025
58464a3
fix lit file
Shoreshen Mar 7, 2025
25f7db0
Merge branch 'main' into gisel-vop3p
Shoreshen Mar 10, 2025
daae1ae
fix comments
Shoreshen Mar 10, 2025
afa6448
Merge branch 'main' into gisel-vop3p
Shoreshen Mar 11, 2025
04a5d4c
Merge branch 'main' into gisel-vop3p
Shoreshen Mar 12, 2025
2e587f5
fix comments
Shoreshen Mar 12, 2025
c6c4b3e
fix comments
Shoreshen Mar 12, 2025
a289297
Merge branch 'main' into gisel-vop3p
Shoreshen Mar 13, 2025
dd106c7
Merge branch 'main' into gisel-vop3p
Shoreshen Mar 17, 2025
6378180
fix comments and test case
Shoreshen Mar 17, 2025
b0feaff
fix comments
Shoreshen Mar 18, 2025
53370d8
fix conflict
Shoreshen Mar 18, 2025
3f178d2
fix lit
Shoreshen Mar 18, 2025
09abc3d
Merge branch 'main' into gisel-vop3p
Shoreshen Mar 18, 2025
79b8992
fix comments
Shoreshen Mar 18, 2025
136da47
fix comments
Shoreshen Mar 18, 2025
61b4df7
Merge branch 'main' into gisel-vop3p
Shoreshen Mar 19, 2025
97e6742
Merge branch 'main' into gisel-vop3p
Shoreshen Mar 20, 2025
d79ac03
Merge branch 'main' into gisel-vop3p
Shoreshen Mar 21, 2025
fc7c927
Block for root type other than 2 x Type
Shoreshen Mar 24, 2025
9f3a54f
Merge branch 'main' into gisel-vop3p
Shoreshen Mar 24, 2025
bc51bf4
fix comments
Shoreshen Mar 24, 2025
cafa3d1
Merge branch 'main' into gisel-vop3p
Shoreshen Mar 25, 2025
a5c5017
Merge branch 'main' into gisel-vop3p
Shoreshen Mar 26, 2025
47840d7
fix comments
Shoreshen Mar 26, 2025
6fe4147
fix comments
Shoreshen Mar 26, 2025
3b7f377
fix comments
Shoreshen Mar 26, 2025
d7de92f
fix lit
Shoreshen Mar 26, 2025
45ed994
avoid global variable
Shoreshen Mar 26, 2025
2f83470
Merge branch 'main' into gisel-vop3p
Shoreshen Mar 27, 2025
d651640
Merge branch 'main' into gisel-vop3p
Shoreshen Mar 28, 2025
a792c1d
fix comments
Shoreshen Mar 28, 2025
9ac58f9
Merge branch 'main' into gisel-vop3p
Shoreshen Mar 28, 2025
0d59649
Merge branch 'main' into gisel-vop3p
Shoreshen Mar 31, 2025
8390425
Merge branch 'main' into gisel-vop3p
Shoreshen Apr 1, 2025
276e41b
Merge branch 'main' into gisel-vop3p
Shoreshen Apr 2, 2025
1cb1651
Merge branch 'main' into gisel-vop3p
Shoreshen Apr 2, 2025
c2eeedd
Merge branch 'main' into gisel-vop3p
Shoreshen Apr 3, 2025
dc65247
Merge branch 'main' into gisel-vop3p
Shoreshen Apr 7, 2025
6cd21e6
Merge branch 'main' into gisel-vop3p
Shoreshen Apr 8, 2025
ee28947
Merge branch 'main' into gisel-vop3p
Shoreshen Apr 9, 2025
797055d
Merge branch 'main' into gisel-vop3p
Shoreshen Apr 11, 2025
e544665
Merge branch 'main' into gisel-vop3p
Shoreshen Apr 11, 2025
0eac2e9
fix comments and case changes
Shoreshen Apr 11, 2025
e328d7a
Merge branch 'main' into gisel-vop3p
Shoreshen Apr 14, 2025
c1680f3
Merge branch 'main' into gisel-vop3p
Shoreshen Apr 15, 2025
d9dc316
Merge branch 'main' into gisel-vop3p
Shoreshen Apr 16, 2025
223dc11
Merge branch 'main' into gisel-vop3p
Shoreshen Apr 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
367 changes: 339 additions & 28 deletions llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4293,44 +4293,350 @@ AMDGPUInstructionSelector::selectVOP3NoMods(MachineOperand &Root) const {
}};
}

std::pair<Register, unsigned>
AMDGPUInstructionSelector::selectVOP3PModsImpl(
Register Src, const MachineRegisterInfo &MRI, bool IsDOT) const {
unsigned Mods = 0;
MachineInstr *MI = MRI.getVRegDef(Src);
// Relationship between the operand currently reached while walking a def
// chain and the value the walk started from. The *_NEG variants record that an
// odd number of G_FNEGs has been crossed on the way.
enum srcStatus {
// The operand is the tracked value itself.
IS_SAME,
// The tracked value is the upper half of the operand (reached by looking
// through a logical shift right by half the width).
IS_UPPER_HALF,
// The tracked value is the lower half of the operand (reached by looking
// through a truncation to half the width, or a shift left by half the width).
IS_LOWER_HALF,
// Negated counterpart of IS_SAME.
IS_NEG,
// Negated counterpart of IS_UPPER_HALF.
IS_UPPER_HALF_NEG,
// Negated counterpart of IS_LOWER_HALF.
IS_LOWER_HALF_NEG
};

/// Returns true when \p MI is a G_TRUNC whose source is exactly twice as wide
/// (in bits) as its result.
static bool isTruncHalf(const MachineInstr *MI,
                        const MachineRegisterInfo &MRI) {
  if (MI->getOpcode() != AMDGPU::G_TRUNC)
    return false;

  const unsigned ResultBits =
      MRI.getType(MI->getOperand(0).getReg()).getSizeInBits();
  const unsigned InputBits =
      MRI.getType(MI->getOperand(1).getReg()).getSizeInBits();
  return InputBits == ResultBits * 2;
}

if (MI->getOpcode() == AMDGPU::G_FNEG &&
// It's possible to see an f32 fneg here, but unlikely.
// TODO: Treat f32 fneg as only high bit.
MRI.getType(Src) == LLT::fixed_vector(2, 16)) {
Mods ^= (SISrcMods::NEG | SISrcMods::NEG_HI);
Src = MI->getOperand(1).getReg();
MI = MRI.getVRegDef(Src);
/// Returns true when \p MI is a G_LSHR by a constant amount equal to half the
/// bit width of the shifted source.
static bool isLshrHalf(const MachineInstr *MI, const MachineRegisterInfo &MRI) {
  if (MI->getOpcode() != AMDGPU::G_LSHR)
    return false;

  Register ShiftSrc;
  std::optional<ValueAndVReg> ShiftAmt;
  // Only a shift by a known constant can be classified.
  const bool HasConstAmount =
      mi_match(MI->getOperand(0).getReg(), MRI,
               m_GLShr(m_Reg(ShiftSrc), m_GCst(ShiftAmt)));
  if (!HasConstAmount)
    return false;

  const unsigned InputBits =
      MRI.getType(MI->getOperand(1).getReg()).getSizeInBits();
  const unsigned Amount = ShiftAmt->Value.getZExtValue();
  return InputBits == Amount * 2;
}

// TODO: Handle G_FSUB 0 as fneg
/// Returns true when \p MI is a G_SHL by a constant amount equal to half the
/// bit width of the shifted source.
static bool isShlHalf(const MachineInstr *MI, const MachineRegisterInfo &MRI) {
  if (MI->getOpcode() != AMDGPU::G_SHL)
    return false;

  Register ShiftSrc;
  std::optional<ValueAndVReg> ShiftAmt;
  // Only a shift by a known constant can be classified.
  const bool HasConstAmount =
      mi_match(MI->getOperand(0).getReg(), MRI,
               m_GShl(m_Reg(ShiftSrc), m_GCst(ShiftAmt)));
  if (!HasConstAmount)
    return false;

  const unsigned InputBits =
      MRI.getType(MI->getOperand(1).getReg()).getSizeInBits();
  const unsigned Amount = ShiftAmt->Value.getZExtValue();
  return InputBits == Amount * 2;
}

// TODO: Match op_sel through g_build_vector_trunc and g_shuffle_vector.
(void)IsDOT; // DOTs do not use OPSEL on gfx942+, check ST.hasDOTOpSelHazard()
/// Stores {\p Op, \p stat} into \p curr and returns true if \p Op is either a
/// register operand that is not a physical register, or an immediate of any
/// kind (Imm / CImm / FPImm). Otherwise leaves \p curr untouched and returns
/// false.
static bool retOpStat(const MachineOperand *Op, srcStatus stat,
                      std::pair<const MachineOperand *, srcStatus> &curr) {
  const bool IsNonPhysicalReg = Op->isReg() && !Op->getReg().isPhysical();
  const bool IsAnyImmediate = Op->isImm() || Op->isCImm() || Op->isFPImm();
  if (!IsNonPhysicalReg && !IsAnyImmediate)
    return false;

  curr = {Op, stat};
  return true;
}

/// Returns the status obtained by toggling the "negated" property of \p S:
/// each plain status maps to its *_NEG counterpart and vice versa.
srcStatus getNegStatus(srcStatus S) {
  // Plain -> negated.
  if (S == IS_SAME)
    return IS_NEG;
  if (S == IS_UPPER_HALF)
    return IS_UPPER_HALF_NEG;
  if (S == IS_LOWER_HALF)
    return IS_LOWER_HALF_NEG;

  // Negated -> plain.
  if (S == IS_NEG)
    return IS_SAME;
  if (S == IS_UPPER_HALF_NEG)
    return IS_UPPER_HALF;
  if (S == IS_LOWER_HALF_NEG)
    return IS_LOWER_HALF;

  llvm_unreachable("unexpected srcStatus");
}

/// Advances \p curr one step up the def chain.
///
/// \p curr holds the operand reached so far and how it relates to the value
/// the walk started from. If the instruction defining (or owning, for a def
/// operand) the current register can be looked through, \p curr is replaced by
/// that instruction's operand 1 together with the updated status and true is
/// returned. Returns false, leaving \p curr unchanged, when the current
/// operand is not a register, no defining instruction is found, the
/// instruction is not one that can be looked through in the current status, or
/// the new operand is rejected by retOpStat.
static bool calcNextStatus(std::pair<const MachineOperand *, srcStatus> &curr,
                           const MachineRegisterInfo &MRI) {
  const MachineOperand *Cur = curr.first;
  if (!Cur->isReg())
    return false;

  // A def operand already belongs to its defining instruction; for a use we
  // have to look the definition up.
  const MachineInstr *MI =
      Cur->isDef() ? Cur->getParent() : MRI.getVRegDef(Cur->getReg());
  if (!MI)
    return false;

  // Operand 1 is only touched once an opcode has been recognised, so it is
  // fetched lazily here rather than up front.
  auto StepTo = [&](srcStatus Next) {
    return retOpStat(&MI->getOperand(1), Next, curr);
  };

  // Opcodes that are handled the same way regardless of the current status.
  switch (MI->getOpcode()) {
  case AMDGPU::G_BITCAST:
  case AMDGPU::G_CONSTANT:
  case AMDGPU::G_FCONSTANT:
  case AMDGPU::COPY:
    // Status is carried over unchanged.
    return StepTo(curr.second);
  case AMDGPU::G_FNEG:
    // Crossing an fneg toggles between a status and its *_NEG counterpart.
    return StepTo(getNegStatus(curr.second));
  default:
    break;
  }

  // Opcodes whose handling depends on the current status.
  const srcStatus Stat = curr.second;
  const bool Negated = Stat == IS_NEG || Stat == IS_UPPER_HALF_NEG ||
                       Stat == IS_LOWER_HALF_NEG;
  switch (Stat) {
  case IS_SAME:
  case IS_NEG:
    // The whole value is the lower half of a half-truncation's source.
    if (isTruncHalf(MI, MRI))
      return StepTo(Negated ? IS_LOWER_HALF_NEG : IS_LOWER_HALF);
    break;
  case IS_UPPER_HALF:
  case IS_UPPER_HALF_NEG:
    // Upper half of (x << half) is the lower half of x.
    if (isShlHalf(MI, MRI))
      return StepTo(Negated ? IS_LOWER_HALF_NEG : IS_LOWER_HALF);
    break;
  case IS_LOWER_HALF:
  case IS_LOWER_HALF_NEG:
    // Lower half of (x >> half) is the upper half of x.
    if (isLshrHalf(MI, MRI))
      return StepTo(Negated ? IS_UPPER_HALF_NEG : IS_UPPER_HALF);
    break;
  }
  return false;
}

/// Walks the def chain starting at \p Op with calcNextStatus and collects the
/// states reached.
///
/// \param Op                 operand to start from (initial status IS_SAME).
/// \param MRI                register info used to look up definitions.
/// \param onlyLastSameOrNeg  when false, every state reached after the first
///                           successful step is returned, in walk order (the
///                           initial state itself is not included). When true,
///                           exactly one entry is returned: the furthest state
///                           along the chain whose status is still IS_SAME or
///                           IS_NEG, which is the initial state if no step
///                           qualifies.
/// \param maxDepth           bound on the walk; the loop condition permits up
///                           to maxDepth + 1 successful steps.
SmallVector<std::pair<const MachineOperand *, srcStatus>>
getSrcStats(const MachineOperand *Op, const MachineRegisterInfo &MRI,
            bool onlyLastSameOrNeg = false, int maxDepth = 6) {
  int depth = 0;
  std::pair<const MachineOperand *, srcStatus> curr = {Op, IS_SAME};
  // calcNextStatus overwrites curr before we get to inspect the new status, so
  // the last acceptable state must be remembered separately. Returning curr
  // after the break would hand the caller an operand whose status is a half
  // (or negated half) of the value, not the value itself.
  std::pair<const MachineOperand *, srcStatus> lastSameOrNeg = curr;
  SmallVector<std::pair<const MachineOperand *, srcStatus>> statList;

  while (depth <= maxDepth && calcNextStatus(curr, MRI)) {
    depth++;
    if (onlyLastSameOrNeg) {
      if (curr.second != IS_SAME && curr.second != IS_NEG)
        break;
      lastSameOrNeg = curr;
    } else {
      statList.push_back(curr);
    }
  }
  if (onlyLastSameOrNeg)
    statList.push_back(lastSameOrNeg);
  return statList;
}

/// Returns true when \p Op is a floating-point immediate that \p TII reports
/// as an inline constant. Every other operand kind yields false.
static bool isInlinableConstant(const MachineOperand &Op,
                                const SIInstrInfo &TII) {
  return Op.isFPImm() && TII.isInlineConstant(Op.getFPImm()->getValueAPF());
}

/// Returns true when the registers of \p Op1 and \p Op2 have types of equal
/// size in bits. Both operands must be register operands.
static bool isSameBitWidth(const MachineOperand *Op1, const MachineOperand *Op2,
                           const MachineRegisterInfo &MRI) {
  const unsigned FirstBits = MRI.getType(Op1->getReg()).getSizeInBits();
  const unsigned SecondBits = MRI.getType(Op2->getReg()).getSizeInBits();
  return FirstBits == SecondBits;
}

static bool isSameOperand(const MachineOperand *Op1,
const MachineOperand *Op2) {
if (Op1->isReg()) {
if (Op2->isReg()) {
return Op1->getReg() == Op2->getReg();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought you can directly use Op1->isIdenticalTo(*Op2) for all?

Copy link
Contributor Author

@Shoreshen Shoreshen Mar 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @shiltian , direct use of isIdenticalTo will differentiate use/def of register and subreg.

For example, even when both operands refer to the same virtual register, it returns false if one of them is a def of its instruction and the other is a use.

}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is ignoring subregister uses, but you shouldn't encounter them either

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @arsenm , to be honest I'm not sure about the subreg. From the isIdenticalTo function, this is comparing SubReg_TargetFlags member of MachineOperand class.

In the AMD backend's TargetOperandFlags enumeration, however, the flags seem to be address-related, I guess.

Should I incorporate the subreg comparison? Can you explain what SubReg_TargetFlags is used for in the AMD backend?

Thanks a lot

return false;
}
return Op1->isIdenticalTo(*Op2);
}

/// Decides whether both lanes of a two-element build vector can be read from
/// the single operand \p newOp, and if so folds the required source modifiers
/// into \p Mods.
///
/// \param HiStat  how the high lane relates to \p newOp.
/// \param LoStat  how the low lane relates to \p newOp.
/// \param Mods    in/out modifier bits; only modified when true is returned.
///                Negation bits are toggled (xor) so they compose with any
///                negation already recorded by the caller.
/// \param newOp   candidate replacement operand (register or immediate).
/// \param RootOp  the operand being selected; a register \p newOp must have
///                the same bit width as it.
/// \returns true if \p newOp may be used with the updated \p Mods.
static bool validToPack(srcStatus HiStat, srcStatus LoStat, unsigned int &Mods,
                        const MachineOperand *newOp,
                        const MachineOperand *RootOp, const SIInstrInfo &TII,
                        const MachineRegisterInfo &MRI) {
  if (newOp->isReg()) {
    if (!isSameBitWidth(newOp, RootOp, MRI))
      return false;

    // High lane: OP_SEL_1 selects the upper half, NEG_HI negates the lane.
    // A plain IS_LOWER_HALF needs no bits.
    if (HiStat == IS_UPPER_HALF_NEG) {
      Mods ^= SISrcMods::NEG_HI;
      Mods |= SISrcMods::OP_SEL_1;
    } else if (HiStat == IS_UPPER_HALF) {
      Mods |= SISrcMods::OP_SEL_1;
    } else if (HiStat == IS_LOWER_HALF_NEG) {
      Mods ^= SISrcMods::NEG_HI;
    }

    // Low lane: OP_SEL_0 selects the upper half, NEG negates the lane.
    if (LoStat == IS_UPPER_HALF_NEG) {
      Mods ^= SISrcMods::NEG;
      Mods |= SISrcMods::OP_SEL_0;
    } else if (LoStat == IS_UPPER_HALF) {
      Mods |= SISrcMods::OP_SEL_0;
    } else if (LoStat == IS_LOWER_HALF_NEG) {
      // Previously this branch re-tested IS_UPPER_HALF_NEG and was therefore
      // unreachable, so a negated lower half in the low lane lost its
      // negation.
      Mods ^= SISrcMods::NEG;
    }
    return true;
  }

  // Immediate operand: only usable when both lanes are the constant itself
  // (possibly negated) and the constant is inlinable.
  const bool HiIsWhole = HiStat == IS_SAME || HiStat == IS_NEG;
  const bool LoIsWhole = LoStat == IS_SAME || LoStat == IS_NEG;
  if (HiIsWhole && LoIsWhole && isInlinableConstant(*newOp, TII)) {
    if (HiStat == IS_NEG)
      Mods ^= SISrcMods::NEG_HI;
    if (LoStat == IS_NEG)
      Mods ^= SISrcMods::NEG;
    // opsel = opsel_hi = 0, since the upper half and lower half are both the
    // same as the target inlinable constant.
    return true;
  }
  return false;
}

std::pair<const MachineOperand *, unsigned>
AMDGPUInstructionSelector::selectVOP3PModsImpl(const MachineOperand *Op,
const MachineRegisterInfo &MRI,
bool IsDOT) const {
unsigned Mods = 0;
const MachineOperand *RootOp = Op;
std::pair<const MachineOperand *, srcStatus> stat =
getSrcStats(Op, MRI, true)[0];
if (!stat.first->isReg()) {
Mods |= SISrcMods::OP_SEL_1;
return {Op, Mods};
}
if (stat.second == IS_NEG) {
Mods ^= (SISrcMods::NEG | SISrcMods::NEG_HI);
}
Op = stat.first;
MachineInstr *MI = MRI.getVRegDef(Op->getReg());
if (MI->getOpcode() == AMDGPU::G_BUILD_VECTOR && MI->getNumOperands() == 3 &&
(!IsDOT || !Subtarget->hasDOTOpSelHazard())) {
SmallVector<std::pair<const MachineOperand *, srcStatus>> statList_Hi;
SmallVector<std::pair<const MachineOperand *, srcStatus>> statList_Lo;
statList_Hi = getSrcStats(&MI->getOperand(2), MRI);
if (statList_Hi.size() != 0) {
statList_Lo = getSrcStats(&MI->getOperand(1), MRI);
if (statList_Lo.size() != 0) {
for (int i = statList_Hi.size() - 1; i >= 0; i--) {
for (int j = statList_Lo.size() - 1; j >= 0; j--) {
if (isSameOperand(statList_Hi[i].first, statList_Lo[j].first)) {
if (validToPack(statList_Hi[i].second, statList_Lo[j].second,
Mods, statList_Hi[i].first, RootOp, TII, MRI)) {
return {statList_Hi[i].first, Mods};
}
}
}
}
}
}
}
// Packed instructions do not have abs modifiers.
Mods |= SISrcMods::OP_SEL_1;

return std::pair(Src, Mods);
return {Op, Mods};
}

int64_t getAllKindImm(const MachineOperand *Op) {
switch (Op->getType()) {
case MachineOperand::MachineOperandType::MO_Immediate:
return Op->getImm();
case MachineOperand::MachineOperandType::MO_CImmediate:
return Op->getCImm()->getSExtValue();
case MachineOperand::MachineOperandType::MO_FPImmediate:
return Op->getFPImm()->getValueAPF().bitcastToAPInt().getSExtValue();
}
llvm_unreachable("not an imm type");
}

/// Returns true when the register of \p Op is assigned to the register bank
/// whose ID is \p RBNo.
///
/// \p Op must be a register operand. A register for which no bank can be
/// determined is treated as not matching, instead of dereferencing the null
/// bank pointer.
bool checkRB(const MachineOperand *Op, int RBNo,
             const AMDGPURegisterBankInfo &RBI, const MachineRegisterInfo &MRI,
             const TargetRegisterInfo &TRI) {
  const RegisterBank *RB = RBI.getRegBank(Op->getReg(), MRI, TRI);
  return RB && RB->getID() == RBNo;
}

/// Picks the operand that should actually be emitted in place of \p RootOp
/// once \p newOp has been chosen as its source.
///
/// - If \p RootOp is in the SGPR bank or \p newOp is in the VGPR bank,
///   \p newOp is returned unchanged.
/// - Otherwise, if \p RootOp is defined by a COPY of \p newOp, \p RootOp
///   itself is returned.
/// - Otherwise a new COPY of \p newOp is inserted immediately before the
///   instruction defining \p RootOp, into a fresh virtual register of the
///   class constrained for \p RootOp, and the COPY's def operand is returned.
const MachineOperand *
getVReg(const MachineOperand *newOp, const MachineOperand *RootOp,
        const AMDGPURegisterBankInfo &RBI, MachineRegisterInfo &MRI,
        const TargetRegisterInfo &TRI, const SIInstrInfo &TII) {
  // RootOp can only be VGPR or SGPR (some hand written cases such as
  // inst-select-ashr.v2s16.mir::ashr_v2s16_vs).
  const bool RootIsSGPR =
      checkRB(RootOp, AMDGPU::SGPRRegBankID, RBI, MRI, TRI);
  if (RootIsSGPR || checkRB(newOp, AMDGPU::VGPRRegBankID, RBI, MRI, TRI))
    return newOp;

  // RootOp is VGPR and newOp is not. If RootOp is already `COPY newOp`, the
  // existing copy can be reused.
  MachineInstr *RootDef = MRI.getVRegDef(RootOp->getReg());
  const bool RootCopiesNewOp =
      RootDef->getOpcode() == AMDGPU::COPY &&
      isSameOperand(newOp, &RootDef->getOperand(1));
  if (RootCopiesNewOp)
    return RootOp;

  // Otherwise materialise a copy of newOp in a register of RootOp's class.
  const TargetRegisterClass *RootRC =
      TRI.getConstrainedRegClassForOperand(*RootOp, MRI);
  Register CopyDst = MRI.createVirtualRegister(RootRC);
  MachineInstrBuilder Copy =
      BuildMI(*RootDef->getParent(), RootDef, RootDef->getDebugLoc(),
              TII.get(AMDGPU::COPY), CopyDst)
          .addReg(newOp->getReg());

  // Hand back the def operand of the freshly built COPY.
  return &Copy->getOperand(0);
}

InstructionSelector::ComplexRendererFns
AMDGPUInstructionSelector::selectVOP3PMods(MachineOperand &Root) const {
MachineRegisterInfo &MRI
= Root.getParent()->getParent()->getParent()->getRegInfo();

Register Src;
unsigned Mods;
std::tie(Src, Mods) = selectVOP3PModsImpl(Root.getReg(), MRI);

std::pair<const MachineOperand *, unsigned> res =
selectVOP3PModsImpl(&Root, MRI);
if (!(res.first->isReg())) {
return {{
[=](MachineInstrBuilder &MIB) { MIB.addImm(getAllKindImm(res.first)); },
[=](MachineInstrBuilder &MIB) { MIB.addImm(res.second); } // src_mods
}};
}
res.first = getVReg(res.first, &Root, RBI, MRI, TRI, TII);
return {{
[=](MachineInstrBuilder &MIB) { MIB.addReg(Src); },
[=](MachineInstrBuilder &MIB) { MIB.addImm(Mods); } // src_mods
[=](MachineInstrBuilder &MIB) { MIB.addReg(res.first->getReg()); },
[=](MachineInstrBuilder &MIB) { MIB.addImm(res.second); } // src_mods
}};
}

Expand All @@ -4339,13 +4645,18 @@ AMDGPUInstructionSelector::selectVOP3PModsDOT(MachineOperand &Root) const {
MachineRegisterInfo &MRI
= Root.getParent()->getParent()->getParent()->getRegInfo();

Register Src;
unsigned Mods;
std::tie(Src, Mods) = selectVOP3PModsImpl(Root.getReg(), MRI, true);

std::pair<const MachineOperand *, unsigned> res =
selectVOP3PModsImpl(&Root, MRI, true);
if (!(res.first->isReg())) {
return {{
[=](MachineInstrBuilder &MIB) { MIB.addImm(getAllKindImm(res.first)); },
[=](MachineInstrBuilder &MIB) { MIB.addImm(res.second); } // src_mods
}};
}
res.first = getVReg(res.first, &Root, RBI, MRI, TRI, TII);
return {{
[=](MachineInstrBuilder &MIB) { MIB.addReg(Src); },
[=](MachineInstrBuilder &MIB) { MIB.addImm(Mods); } // src_mods
[=](MachineInstrBuilder &MIB) { MIB.addReg(res.first->getReg()); },
[=](MachineInstrBuilder &MIB) { MIB.addImm(res.second); } // src_mods
}};
}

Expand Down
Loading
Loading