@@ -227,12 +227,12 @@ class SIFoldOperandsImpl {
   getRegSeqInit(SmallVectorImpl<std::pair<MachineOperand *, unsigned>> &Defs,
                 Register UseReg) const;
 
-  std::pair<MachineOperand *, const TargetRegisterClass *>
+  std::pair<int64_t, const TargetRegisterClass *>
   isRegSeqSplat(MachineInstr &RegSeg) const;
 
-  MachineOperand *tryFoldRegSeqSplat(MachineInstr *UseMI, unsigned UseOpIdx,
-                                     MachineOperand *SplatVal,
-                                     const TargetRegisterClass *SplatRC) const;
+  bool tryFoldRegSeqSplat(MachineInstr *UseMI, unsigned UseOpIdx,
+                          int64_t SplatVal,
+                          const TargetRegisterClass *SplatRC) const;
 
   bool tryToFoldACImm(const FoldableDef &OpToFold, MachineInstr *UseMI,
                       unsigned UseOpIdx,
@@ -966,15 +966,15 @@ const TargetRegisterClass *SIFoldOperandsImpl::getRegSeqInit(
   return getRegSeqInit(*Def, Defs);
 }
 
-std::pair<MachineOperand *, const TargetRegisterClass *>
+std::pair<int64_t, const TargetRegisterClass *>
 SIFoldOperandsImpl::isRegSeqSplat(MachineInstr &RegSeq) const {
   SmallVector<std::pair<MachineOperand *, unsigned>, 32> Defs;
   const TargetRegisterClass *SrcRC = getRegSeqInit(RegSeq, Defs);
   if (!SrcRC)
     return {};
 
-  // TODO: Recognize 64-bit splats broken into 32-bit pieces (i.e. recognize
-  // every other other element is 0 for 64-bit immediates)
+  bool TryToMatchSplat64 = false;
+
   int64_t Imm;
   for (unsigned I = 0, E = Defs.size(); I != E; ++I) {
     const MachineOperand *Op = Defs[I].first;
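Note on the new return type: with the immediate returned by value instead of as a MachineOperand pointer, "return {}" value-initializes the pair to {0, nullptr}. Since 0 is itself a legal splat immediate, "not a splat" is only detectable through the register-class half of the pair, which is why the caller change in the last hunk below tests SplatRC. A minimal standalone sketch of that contract (the type names here are stand-ins, not the LLVM definitions):

#include <cstdint>
#include <utility>

struct TargetRegisterClass; // opaque stand-in for the LLVM type

using SplatResult = std::pair<int64_t, const TargetRegisterClass *>;

SplatResult noSplat() { return {}; } // value-initializes to {0, nullptr}

// 0 is a valid splat value, so only a null class signals failure.
bool isValidSplat(const SplatResult &R) { return R.second != nullptr; }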
@@ -986,38 +986,75 @@ SIFoldOperandsImpl::isRegSeqSplat(MachineInstr &RegSeq) const {
       Imm = SubImm;
       continue;
     }
-    if (Imm != SubImm)
+
+    if (Imm != SubImm) {
+      if (I == 1 && (E & 1) == 0) {
+        // If we have an even number of inputs, there's a chance this is a
+        // 64-bit element splat broken into 32-bit pieces.
+        TryToMatchSplat64 = true;
+        break;
+      }
+
       return {}; // Can only fold splat constants
+    }
+  }
+
+  if (!TryToMatchSplat64)
+    return {Defs[0].first->getImm(), SrcRC};
+
+  // Fallback to recognizing 64-bit splats broken into 32-bit pieces
+  // (i.e. recognize every other element is 0 for 64-bit immediates)
+  int64_t SplatVal64;
+  for (unsigned I = 0, E = Defs.size(); I != E; I += 2) {
+    const MachineOperand *Op0 = Defs[I].first;
+    const MachineOperand *Op1 = Defs[I + 1].first;
+
+    if (!Op0->isImm() || !Op1->isImm())
+      return {};
+
+    unsigned SubReg0 = Defs[I].second;
+    unsigned SubReg1 = Defs[I + 1].second;
+
+    // Assume we're going to generally encounter reg_sequences with sorted
+    // subreg indexes, so reject any that aren't consecutive.
+    if (TRI->getChannelFromSubReg(SubReg0) + 1 !=
+        TRI->getChannelFromSubReg(SubReg1))
+      return {};
+
+    int64_t MergedVal = Make_64(Op1->getImm(), Op0->getImm());
+    if (I == 0)
+      SplatVal64 = MergedVal;
+    else if (SplatVal64 != MergedVal)
+      return {};
   }
 
-  return {Defs[0].first, SrcRC};
+  const TargetRegisterClass *RC64 = TRI->getSubRegisterClass(
+      MRI->getRegClass(RegSeq.getOperand(0).getReg()), AMDGPU::sub0_sub1);
+
+  return {SplatVal64, RC64};
 }
 
-MachineOperand *SIFoldOperandsImpl::tryFoldRegSeqSplat(
-    MachineInstr *UseMI, unsigned UseOpIdx, MachineOperand *SplatVal,
+bool SIFoldOperandsImpl::tryFoldRegSeqSplat(
+    MachineInstr *UseMI, unsigned UseOpIdx, int64_t SplatVal,
     const TargetRegisterClass *SplatRC) const {
   const MCInstrDesc &Desc = UseMI->getDesc();
   if (UseOpIdx >= Desc.getNumOperands())
-    return nullptr;
+    return false;
 
   // Filter out unhandled pseudos.
   if (!AMDGPU::isSISrcOperand(Desc, UseOpIdx))
-    return nullptr;
+    return false;
 
   int16_t RCID = Desc.operands()[UseOpIdx].RegClass;
   if (RCID == -1)
-    return nullptr;
+    return false;
+
+  const TargetRegisterClass *OpRC = TRI->getRegClass(RCID);
 
   // Special case 0/-1, since when interpreted as a 64-bit element both halves
-  // have the same bits. Effectively this code does not handle 64-bit element
-  // operands correctly, as the incoming 64-bit constants are already split into
-  // 32-bit sequence elements.
-  //
-  // TODO: We should try to figure out how to interpret the reg_sequence as a
-  // split 64-bit splat constant, or use 64-bit pseudos for materializing f64
-  // constants.
-  if (SplatVal->getImm() != 0 && SplatVal->getImm() != -1) {
-    const TargetRegisterClass *OpRC = TRI->getRegClass(RCID);
+  // have the same bits. These are the only cases where a splat has the same
+  // interpretation for 32-bit and 64-bit splats.
+  if (SplatVal != 0 && SplatVal != -1) {
     // We need to figure out the scalar type read by the operand. e.g. the MFMA
     // operand will be AReg_128, and we want to check if it's compatible with an
     // AReg_32 constant.
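The pairwise merge in the hunk above uses Make_64 from llvm/Support/MathExtras.h, which concatenates a high and a low 32-bit word; because Defs is ordered by channel, element 2*I+1 supplies the high half and element 2*I the low half. A self-contained sketch of the matching step, assuming the elements have already been extracted as raw 32-bit values (matchSplat64 is a hypothetical name, not part of the patch):

#include <cstdint>
#include <optional>
#include <vector>

// Mirrors Make_64(High, Low) from llvm/Support/MathExtras.h.
static uint64_t make64(uint32_t High, uint32_t Low) {
  return ((uint64_t)High << 32) | Low;
}

std::optional<int64_t> matchSplat64(const std::vector<uint32_t> &Elts) {
  if (Elts.empty() || (Elts.size() & 1) != 0)
    return std::nullopt; // Need a nonzero, even number of 32-bit pieces.
  uint64_t Splat = make64(Elts[1], Elts[0]);
  for (size_t I = 2; I != Elts.size(); I += 2)
    if (make64(Elts[I + 1], Elts[I]) != Splat)
      return std::nullopt; // Pairs disagree: not a 64-bit splat.
  return (int64_t)Splat;
}

For example, the four elements {0, 0x3FF00000, 0, 0x3FF00000} match as a splat of 0x3FF0000000000000, i.e. f64 1.0.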
@@ -1031,17 +1068,18 @@ MachineOperand *SIFoldOperandsImpl::tryFoldRegSeqSplat(
       OpRC = TRI->getSubRegisterClass(OpRC, AMDGPU::sub0_sub1);
       break;
     default:
-      return nullptr;
+      return false;
     }
 
     if (!TRI->getCommonSubClass(OpRC, SplatRC))
-      return nullptr;
+      return false;
   }
 
-  if (!TII->isOperandLegal(*UseMI, UseOpIdx, SplatVal))
-    return nullptr;
+  MachineOperand TmpOp = MachineOperand::CreateImm(SplatVal);
+  if (!TII->isOperandLegal(*UseMI, UseOpIdx, &TmpOp))
+    return false;
 
-  return SplatVal;
+  return true;
 }
 
 bool SIFoldOperandsImpl::tryToFoldACImm(
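Two details in this hunk: since SplatVal is now a bare int64_t, the call to TII->isOperandLegal(), which takes a MachineOperand pointer, wraps it in a transient immediate operand built with MachineOperand::CreateImm. And the 0/-1 special case is sound because those are the only values for which a 32-bit splat and its 64-bit reinterpretation (two concatenated copies of the element) denote the same immediate; a quick standalone check:

#include <cassert>
#include <cstdint>

int main() {
  for (int64_t V : {INT64_C(0), INT64_C(-1)}) {
    uint32_t Half = (uint32_t)V;                   // one 32-bit element
    uint64_t As64 = ((uint64_t)Half << 32) | Half; // two concatenated copies
    assert((int64_t)As64 == V); // same immediate under either element size
  }
  // Any other value fails, e.g. 1 widens to 0x0000000100000001 != 1.
}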
@@ -1119,7 +1157,7 @@ void SIFoldOperandsImpl::foldOperand(
     Register RegSeqDstReg = UseMI->getOperand(0).getReg();
     unsigned RegSeqDstSubReg = UseMI->getOperand(UseOpIdx + 1).getImm();
 
-    MachineOperand *SplatVal;
+    int64_t SplatVal;
     const TargetRegisterClass *SplatRC;
     std::tie(SplatVal, SplatRC) = isRegSeqSplat(*UseMI);
 
@@ -1130,10 +1168,9 @@ void SIFoldOperandsImpl::foldOperand(
     MachineInstr *RSUseMI = RSUse->getParent();
     unsigned OpNo = RSUseMI->getOperandNo(RSUse);
 
-    if (SplatVal) {
-      if (MachineOperand *Foldable =
-              tryFoldRegSeqSplat(RSUseMI, OpNo, SplatVal, SplatRC)) {
-        FoldableDef SplatDef(*Foldable, SplatRC);
+    if (SplatRC) {
+      if (tryFoldRegSeqSplat(RSUseMI, OpNo, SplatVal, SplatRC)) {
+        FoldableDef SplatDef(SplatVal, SplatRC);
         appendFoldCandidate(FoldList, RSUseMI, OpNo, SplatDef);
         continue;
       }
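Taken together, the patch lets the pass see through a reg_sequence that materializes one 64-bit constant (or a vector splat of one) out of 32-bit moves. A hypothetical before/after in pseudo-MIR (opcodes, operand layout, and register classes here are illustrative, not taken from the commit):

; Before: f64 1.0 (0x3FF0000000000000) assembled from two 32-bit halves.
%lo:vgpr_32 = V_MOV_B32_e32 0
%hi:vgpr_32 = V_MOV_B32_e32 1072693248   ; 0x3FF00000
%c:vreg_64 = REG_SEQUENCE %lo, %subreg.sub0, %hi, %subreg.sub1
%r:vreg_64 = V_ADD_F64_e64 0, %x, 0, %c, 0, 0

; After: isRegSeqSplat() merges the halves into 0x3FF0000000000000, an
; inlinable f64 constant, so the reg_sequence use folds to an immediate.
%r:vreg_64 = V_ADD_F64_e64 0, %x, 0, 4607182418800017408, 0, 0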