Skip to content

Commit 97f782d

Browse files
committed
AMDGPU: Handle folding vector splats of inline split f64 inline immediates
Recognize a reg_sequence with 32-bit elements that produce a 64-bit splat value. This enables folding f64 constants into mfma operands
1 parent 1a8e2fb commit 97f782d

File tree

2 files changed

+76
-68
lines changed

2 files changed

+76
-68
lines changed

llvm/lib/Target/AMDGPU/SIFoldOperands.cpp

Lines changed: 70 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -227,12 +227,12 @@ class SIFoldOperandsImpl {
227227
getRegSeqInit(SmallVectorImpl<std::pair<MachineOperand *, unsigned>> &Defs,
228228
Register UseReg) const;
229229

230-
std::pair<MachineOperand *, const TargetRegisterClass *>
230+
std::pair<int64_t, const TargetRegisterClass *>
231231
isRegSeqSplat(MachineInstr &RegSeg) const;
232232

233-
MachineOperand *tryFoldRegSeqSplat(MachineInstr *UseMI, unsigned UseOpIdx,
234-
MachineOperand *SplatVal,
235-
const TargetRegisterClass *SplatRC) const;
233+
bool tryFoldRegSeqSplat(MachineInstr *UseMI, unsigned UseOpIdx,
234+
int64_t SplatVal,
235+
const TargetRegisterClass *SplatRC) const;
236236

237237
bool tryToFoldACImm(const FoldableDef &OpToFold, MachineInstr *UseMI,
238238
unsigned UseOpIdx,
@@ -966,15 +966,15 @@ const TargetRegisterClass *SIFoldOperandsImpl::getRegSeqInit(
966966
return getRegSeqInit(*Def, Defs);
967967
}
968968

969-
std::pair<MachineOperand *, const TargetRegisterClass *>
969+
std::pair<int64_t, const TargetRegisterClass *>
970970
SIFoldOperandsImpl::isRegSeqSplat(MachineInstr &RegSeq) const {
971971
SmallVector<std::pair<MachineOperand *, unsigned>, 32> Defs;
972972
const TargetRegisterClass *SrcRC = getRegSeqInit(RegSeq, Defs);
973973
if (!SrcRC)
974974
return {};
975975

976-
// TODO: Recognize 64-bit splats broken into 32-bit pieces (i.e. recognize
977-
// every other other element is 0 for 64-bit immediates)
976+
bool TryToMatchSplat64 = false;
977+
978978
int64_t Imm;
979979
for (unsigned I = 0, E = Defs.size(); I != E; ++I) {
980980
const MachineOperand *Op = Defs[I].first;
@@ -986,38 +986,75 @@ SIFoldOperandsImpl::isRegSeqSplat(MachineInstr &RegSeq) const {
986986
Imm = SubImm;
987987
continue;
988988
}
989-
if (Imm != SubImm)
989+
990+
if (Imm != SubImm) {
991+
if (I == 1 && (E & 1) == 0) {
992+
// If we have an even number of inputs, there's a chance this is a
993+
// 64-bit element splat broken into 32-bit pieces.
994+
TryToMatchSplat64 = true;
995+
break;
996+
}
997+
990998
return {}; // Can only fold splat constants
999+
}
1000+
}
1001+
1002+
if (!TryToMatchSplat64)
1003+
return {Defs[0].first->getImm(), SrcRC};
1004+
1005+
// Fallback to recognizing 64-bit splats broken into 32-bit pieces
1006+
// (i.e. recognize every other other element is 0 for 64-bit immediates)
1007+
int64_t SplatVal64;
1008+
for (unsigned I = 0, E = Defs.size(); I != E; I += 2) {
1009+
const MachineOperand *Op0 = Defs[I].first;
1010+
const MachineOperand *Op1 = Defs[I + 1].first;
1011+
1012+
if (!Op0->isImm() || !Op1->isImm())
1013+
return {};
1014+
1015+
unsigned SubReg0 = Defs[I].second;
1016+
unsigned SubReg1 = Defs[I + 1].second;
1017+
1018+
// Assume we're going to generally encounter reg_sequences with sorted
1019+
// subreg indexes, so reject any that aren't consecutive.
1020+
if (TRI->getChannelFromSubReg(SubReg0) + 1 !=
1021+
TRI->getChannelFromSubReg(SubReg1))
1022+
return {};
1023+
1024+
int64_t MergedVal = Make_64(Op1->getImm(), Op0->getImm());
1025+
if (I == 0)
1026+
SplatVal64 = MergedVal;
1027+
else if (SplatVal64 != MergedVal)
1028+
return {};
9911029
}
9921030

993-
return {Defs[0].first, SrcRC};
1031+
const TargetRegisterClass *RC64 = TRI->getSubRegisterClass(
1032+
MRI->getRegClass(RegSeq.getOperand(0).getReg()), AMDGPU::sub0_sub1);
1033+
1034+
return {SplatVal64, RC64};
9941035
}
9951036

996-
MachineOperand *SIFoldOperandsImpl::tryFoldRegSeqSplat(
997-
MachineInstr *UseMI, unsigned UseOpIdx, MachineOperand *SplatVal,
1037+
bool SIFoldOperandsImpl::tryFoldRegSeqSplat(
1038+
MachineInstr *UseMI, unsigned UseOpIdx, int64_t SplatVal,
9981039
const TargetRegisterClass *SplatRC) const {
9991040
const MCInstrDesc &Desc = UseMI->getDesc();
10001041
if (UseOpIdx >= Desc.getNumOperands())
1001-
return nullptr;
1042+
return false;
10021043

10031044
// Filter out unhandled pseudos.
10041045
if (!AMDGPU::isSISrcOperand(Desc, UseOpIdx))
1005-
return nullptr;
1046+
return false;
10061047

10071048
int16_t RCID = Desc.operands()[UseOpIdx].RegClass;
10081049
if (RCID == -1)
1009-
return nullptr;
1050+
return false;
1051+
1052+
const TargetRegisterClass *OpRC = TRI->getRegClass(RCID);
10101053

10111054
// Special case 0/-1, since when interpreted as a 64-bit element both halves
1012-
// have the same bits. Effectively this code does not handle 64-bit element
1013-
// operands correctly, as the incoming 64-bit constants are already split into
1014-
// 32-bit sequence elements.
1015-
//
1016-
// TODO: We should try to figure out how to interpret the reg_sequence as a
1017-
// split 64-bit splat constant, or use 64-bit pseudos for materializing f64
1018-
// constants.
1019-
if (SplatVal->getImm() != 0 && SplatVal->getImm() != -1) {
1020-
const TargetRegisterClass *OpRC = TRI->getRegClass(RCID);
1055+
// have the same bits. These are the only cases where a splat has the same
1056+
// interpretation for 32-bit and 64-bit splats.
1057+
if (SplatVal != 0 && SplatVal != -1) {
10211058
// We need to figure out the scalar type read by the operand. e.g. the MFMA
10221059
// operand will be AReg_128, and we want to check if it's compatible with an
10231060
// AReg_32 constant.
@@ -1031,17 +1068,18 @@ MachineOperand *SIFoldOperandsImpl::tryFoldRegSeqSplat(
10311068
OpRC = TRI->getSubRegisterClass(OpRC, AMDGPU::sub0_sub1);
10321069
break;
10331070
default:
1034-
return nullptr;
1071+
return false;
10351072
}
10361073

10371074
if (!TRI->getCommonSubClass(OpRC, SplatRC))
1038-
return nullptr;
1075+
return false;
10391076
}
10401077

1041-
if (!TII->isOperandLegal(*UseMI, UseOpIdx, SplatVal))
1042-
return nullptr;
1078+
MachineOperand TmpOp = MachineOperand::CreateImm(SplatVal);
1079+
if (!TII->isOperandLegal(*UseMI, UseOpIdx, &TmpOp))
1080+
return false;
10431081

1044-
return SplatVal;
1082+
return true;
10451083
}
10461084

10471085
bool SIFoldOperandsImpl::tryToFoldACImm(
@@ -1119,7 +1157,7 @@ void SIFoldOperandsImpl::foldOperand(
11191157
Register RegSeqDstReg = UseMI->getOperand(0).getReg();
11201158
unsigned RegSeqDstSubReg = UseMI->getOperand(UseOpIdx + 1).getImm();
11211159

1122-
MachineOperand *SplatVal;
1160+
int64_t SplatVal;
11231161
const TargetRegisterClass *SplatRC;
11241162
std::tie(SplatVal, SplatRC) = isRegSeqSplat(*UseMI);
11251163

@@ -1130,10 +1168,9 @@ void SIFoldOperandsImpl::foldOperand(
11301168
MachineInstr *RSUseMI = RSUse->getParent();
11311169
unsigned OpNo = RSUseMI->getOperandNo(RSUse);
11321170

1133-
if (SplatVal) {
1134-
if (MachineOperand *Foldable =
1135-
tryFoldRegSeqSplat(RSUseMI, OpNo, SplatVal, SplatRC)) {
1136-
FoldableDef SplatDef(*Foldable, SplatRC);
1171+
if (SplatRC) {
1172+
if (tryFoldRegSeqSplat(RSUseMI, OpNo, SplatVal, SplatRC)) {
1173+
FoldableDef SplatDef(SplatVal, SplatRC);
11371174
appendFoldCandidate(FoldList, RSUseMI, OpNo, SplatDef);
11381175
continue;
11391176
}

llvm/test/CodeGen/AMDGPU/llvm.amdgcn.mfma.gfx90a.ll

Lines changed: 6 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -165,19 +165,9 @@ bb:
165165
}
166166

167167
; GCN-LABEL: {{^}}test_mfma_f64_16x16x4f64_splat_imm_1:
168-
; GCN: v_mov_b32_e32 [[HIGH_BITS:v[0-9]+]], 0x3ff00000
169-
; GCN: v_accvgpr_write_b32 a[[A_HIGH_BITS_0:[0-9]+]], [[HIGH_BITS]]
170-
; GCN: v_accvgpr_write_b32 a[[A_LOW_BITS_0:[0-9]+]], 0{{$}}
171-
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
172-
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
173-
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
174-
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
175-
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
176-
; GCN: v_accvgpr_mov_b32 a[[LAST_CONST_REG:[0-9]+]], a[[A_HIGH_BITS_0]]
177-
178-
; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
168+
; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], 1.0{{$}}
179169
; GFX90A: v_mfma_f64_16x16x4f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 blgp:3
180-
; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
170+
; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], 1.0{{$}}
181171
; GFX942: v_mfma_f64_16x16x4_f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 neg:[1,1,0]
182172
; GCN: global_store_dwordx4
183173
; GCN: global_store_dwordx4
@@ -190,19 +180,9 @@ bb:
190180
}
191181

192182
; GCN-LABEL: {{^}}test_mfma_f64_16x16x4f64_splat_imm_neg1:
193-
; GCN: v_mov_b32_e32 [[HIGH_BITS:v[0-9]+]], 0xbff00000
194-
; GCN: v_accvgpr_write_b32 a[[A_HIGH_BITS_0:[0-9]+]], [[HIGH_BITS]]
195-
; GCN: v_accvgpr_write_b32 a[[A_LOW_BITS_0:[0-9]+]], 0{{$}}
196-
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
197-
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
198-
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
199-
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
200-
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
201-
; GCN: v_accvgpr_mov_b32 a[[LAST_CONST_REG:[0-9]+]], a[[A_HIGH_BITS_0]]
202-
203-
; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
183+
; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], -1.0{{$}}
204184
; GFX90A: v_mfma_f64_16x16x4f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 blgp:3
205-
; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
185+
; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], -1.0{{$}}
206186
; GFX942: v_mfma_f64_16x16x4_f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 neg:[1,1,0]
207187
; GCN: global_store_dwordx4
208188
; GCN: global_store_dwordx4
@@ -215,18 +195,9 @@ bb:
215195
}
216196

217197
; GCN-LABEL: {{^}}test_mfma_f64_16x16x4f64_splat_imm_int_64:
218-
; GCN: v_accvgpr_write_b32 a[[A_LOW_BITS_0:[0-9]+]], 64{{$}}
219-
; GCN: v_accvgpr_write_b32 a[[A_HIGH_BITS_0:[0-9]+]], 0
220-
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
221-
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
222-
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
223-
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_HIGH_BITS_0]]
224-
; GCN: v_accvgpr_mov_b32 a{{[0-9]+}}, a[[A_LOW_BITS_0]]
225-
; GCN: v_accvgpr_mov_b32 a[[LAST_CONST_REG:[0-9]+]], a[[A_HIGH_BITS_0]]
226-
227-
; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
198+
; GFX90A: v_mfma_f64_16x16x4f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], 64{{$}}
228199
; GFX90A: v_mfma_f64_16x16x4f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 blgp:3
229-
; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], a{{\[}}[[A_LOW_BITS_0]]:[[LAST_CONST_REG]]{{\]$}}
200+
; GFX942: v_mfma_f64_16x16x4_f64 [[M1:a\[[0-9]+:[0-9]+\]]], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], 64{{$}}
230201
; GFX942: v_mfma_f64_16x16x4_f64 a[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], v[{{[0-9]+:[0-9]+}}], [[M1]] cbsz:1 abid:2 neg:[1,1,0]
231202
; GCN: global_store_dwordx4
232203
; GCN: global_store_dwordx4

0 commit comments

Comments
 (0)