Skip to content

Commit 6f3efd8

Browse files
committed
[X86] combineTruncatedArithmetic - move more of fold inside combinei64TruncSrlConstant
Let combinei64TruncSrlConstant decide when the fold is invalid instead of splitting so many of the conditions with combineTruncatedArithmetic NOTE: We can probably relax the i32 truncation constraint to <= i32, perform the SRL as i32 and then truncate further. Noticed while triaging #141496
1 parent ac9a466 commit 6f3efd8

File tree

1 file changed

+16
-19
lines changed

1 file changed

+16
-19
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -54214,22 +54214,31 @@ static SDValue combineLRINT_LLRINT(SDNode *N, SelectionDAG &DAG,
5421454214
// cases.
5421554215
static SDValue combinei64TruncSrlConstant(SDValue N, EVT VT, SelectionDAG &DAG,
5421654216
const SDLoc &DL) {
54217+
assert(N.getOpcode() == ISD::SRL && "Unknown shift opcode");
54218+
std::optional<uint64_t> ValidSrlConst = DAG.getValidShiftAmount(N);
54219+
if (!ValidSrlConst)
54220+
return SDValue();
54221+
uint64_t SrlConstVal = *ValidSrlConst;
5421754222

5421854223
SDValue Op = N.getOperand(0);
54219-
APInt OpConst = Op.getConstantOperandAPInt(1);
54220-
APInt SrlConst = N.getConstantOperandAPInt(1);
54221-
uint64_t SrlConstVal = SrlConst.getZExtValue();
5422254224
unsigned Opcode = Op.getOpcode();
54225+
assert(VT == MVT::i32 && Op.getValueType() == MVT::i64 &&
54226+
"Illegal truncation types");
54227+
54228+
if ((Opcode != ISD::ADD && Opcode != ISD::OR && Opcode != ISD::XOR) ||
54229+
!isa<ConstantSDNode>(Op.getOperand(1)))
54230+
return SDValue();
54231+
const APInt &OpConst = Op.getConstantOperandAPInt(1);
5422354232

54224-
if (SrlConst.ule(32) ||
54233+
if (SrlConstVal <= 32 ||
5422554234
(Opcode == ISD::ADD && OpConst.countr_zero() < SrlConstVal))
5422654235
return SDValue();
5422754236

5422854237
SDValue OpLhsSrl =
5422954238
DAG.getNode(ISD::SRL, DL, MVT::i64, Op.getOperand(0), N.getOperand(1));
5423054239
SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, OpLhsSrl);
5423154240

54232-
APInt NewOpConstVal = OpConst.lshr(SrlConst).trunc(VT.getSizeInBits());
54241+
APInt NewOpConstVal = OpConst.lshr(SrlConstVal).trunc(VT.getSizeInBits());
5423354242
SDValue NewOpConst = DAG.getConstant(NewOpConstVal, DL, VT);
5423454243
SDValue NewOpNode = DAG.getNode(Opcode, DL, VT, Trunc, NewOpConst);
5423554244

@@ -54285,20 +54294,8 @@ static SDValue combineTruncatedArithmetic(SDNode *N, SelectionDAG &DAG,
5428554294
if (!Src.hasOneUse())
5428654295
return SDValue();
5428754296

54288-
if (VT == MVT::i32 && SrcVT == MVT::i64 && SrcOpcode == ISD::SRL &&
54289-
isa<ConstantSDNode>(Src.getOperand(1))) {
54290-
54291-
unsigned SrcOpOpcode = Src.getOperand(0).getOpcode();
54292-
if ((SrcOpOpcode != ISD::ADD && SrcOpOpcode != ISD::OR &&
54293-
SrcOpOpcode != ISD::XOR) ||
54294-
!isa<ConstantSDNode>(Src.getOperand(0).getOperand(1)))
54295-
return SDValue();
54296-
54297-
if (SDValue R = combinei64TruncSrlConstant(Src, VT, DAG, DL))
54298-
return R;
54299-
54300-
return SDValue();
54301-
}
54297+
if (VT == MVT::i32 && SrcVT == MVT::i64 && SrcOpcode == ISD::SRL)
54298+
return combinei64TruncSrlConstant(Src, VT, DAG, DL);
5430254299

5430354300
if (!VT.isVector())
5430454301
return SDValue();

0 commit comments

Comments
 (0)